26#include "llvm/ADT/STLExtras.h"
29#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
42 ~ControlFlowInlinerInterface()
override =
default;
46 bool wouldBeCloned)
const final {
49 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
54 void handleTerminator(Operation *op,
Block *newDest)
const final {}
62void ControlFlowDialect::initialize() {
65#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
67 addInterfaces<ControlFlowInlinerInterface>();
68 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
69 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
71 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
79LogicalResult AssertOp::canonicalize(AssertOp op,
PatternRewriter &rewriter) {
89void AssertOp::getEffects(
109 if (std::next(successor->
begin()) != successor->
end())
112 BranchOp successorBranch = dyn_cast<BranchOp>(successor->
getTerminator());
113 if (!successorBranch)
118 if (user != successorBranch)
122 Block *successorDest = successorBranch.getDest();
123 if (successorDest == successor)
126 BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->
getTerminator());
129 Block *nextBranchDest = nextBranch.getDest();
130 if (visited.contains(nextBranchDest))
132 visited.insert(nextBranchDest);
133 nextBranch = dyn_cast<BranchOp>(nextBranchDest->
getTerminator());
140 successor = successorDest;
141 successorOperands = operands;
146 for (
Value operand : operands) {
147 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
148 if (argOperand && argOperand.
getOwner() == successor)
149 argStorage.push_back(successorOperands[argOperand.
getArgNumber()]);
151 argStorage.push_back(operand);
153 successor = successorDest;
154 successorOperands = argStorage;
163 Block *succ = op.getDest();
164 Block *opParent = op->getBlock();
165 if (succ == opParent || !llvm::hasSingleElement(succ->
getPredecessors()))
183 Block *dest = op.getDest();
189 if (dest == op->getBlock() ||
198LogicalResult BranchOp::canonicalize(BranchOp op,
PatternRewriter &rewriter) {
203void BranchOp::setDest(
Block *block) {
return setSuccessor(block); }
205void BranchOp::eraseOperand(
unsigned index) { (*this)->eraseOperand(
index); }
208 assert(
index == 0 &&
"invalid successor index");
227 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
229 LogicalResult matchAndRewrite(CondBranchOp condbr,
230 PatternRewriter &rewriter)
const override {
234 condbr.getTrueOperands());
240 condbr.getFalseOperands());
255struct SimplifyPassThroughCondBranch :
public OpRewritePattern<CondBranchOp> {
256 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
258 LogicalResult matchAndRewrite(CondBranchOp condbr,
259 PatternRewriter &rewriter)
const override {
260 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
261 ValueRange trueDestOperands = condbr.getTrueOperands();
262 ValueRange falseDestOperands = condbr.getFalseOperands();
263 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
266 LogicalResult collapsedTrue =
267 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
268 LogicalResult collapsedFalse =
269 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
275 condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
276 falseDestOperands, condbr.getWeights());
288struct SimplifyCondBranchIdenticalSuccessors
290 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
292 LogicalResult matchAndRewrite(CondBranchOp condbr,
293 PatternRewriter &rewriter)
const override {
296 Block *trueDest = condbr.getTrueDest();
297 if (trueDest != condbr.getFalseDest())
301 OperandRange trueOperands = condbr.getTrueOperands();
302 OperandRange falseOperands = condbr.getFalseOperands();
303 if (trueOperands == falseOperands) {
314 SmallVector<Value, 8> mergedOperands;
315 mergedOperands.reserve(trueOperands.size());
316 Value condition = condbr.getCondition();
317 for (
auto it : llvm::zip(trueOperands, falseOperands)) {
318 if (std::get<0>(it) == std::get<1>(it))
319 mergedOperands.push_back(std::get<0>(it));
321 mergedOperands.push_back(
322 arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
323 std::get<0>(it), std::get<1>(it)));
347struct SimplifyCondBranchFromCondBranchOnSameCondition
349 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
351 LogicalResult matchAndRewrite(CondBranchOp condbr,
352 PatternRewriter &rewriter)
const override {
354 Block *currentBlock = condbr->getBlock();
361 auto predBranch = dyn_cast<CondBranchOp>(predecessor->
getTerminator());
362 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
366 if (currentBlock == predBranch.getTrueDest())
368 condbr.getTrueDestOperands());
371 condbr.getFalseDestOperands());
397 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
399 LogicalResult matchAndRewrite(CondBranchOp condbr,
400 PatternRewriter &rewriter)
const override {
402 bool replaced =
false;
407 Value constantTrue =
nullptr;
408 Value constantFalse =
nullptr;
415 if (condbr.getTrueDest()->getSinglePredecessor()) {
416 for (OpOperand &use :
417 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
418 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
422 constantTrue = arith::ConstantOp::create(
423 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
true));
426 [&] { use.set(constantTrue); });
430 if (condbr.getFalseDest()->getSinglePredecessor()) {
431 for (OpOperand &use :
432 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
433 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
437 constantFalse = arith::ConstantOp::create(
438 rewriter, condbr.getLoc(), ty, rewriter.
getBoolAttr(
false));
441 [&] { use.set(constantFalse); });
452 results.
add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
453 SimplifyCondBranchIdenticalSuccessors,
454 SimplifyCondBranchFromCondBranchOnSameCondition,
455 CondBranchTruthPropagation>(context);
459 assert(
index < getNumSuccessors() &&
"invalid successor index");
461 : getFalseDestOperandsMutable());
465 if (IntegerAttr condAttr =
466 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
467 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
480 build(builder,
result, value, defaultOperands, caseOperands, caseValues,
481 defaultDestination, caseDestinations);
489 if (!caseValues.empty()) {
490 ShapedType caseValueType = VectorType::get(
494 build(builder,
result, value, defaultDestination, defaultOperands,
495 caseValuesAttr, caseDestinations, caseOperands);
503 if (!caseValues.empty()) {
504 ShapedType caseValueType = VectorType::get(
508 build(builder,
result, value, defaultDestination, defaultOperands,
509 caseValuesAttr, caseDestinations, caseOperands);
538 values.push_back(APInt(bitWidth, value,
true));
553 caseDestinations.push_back(destination);
554 caseOperands.emplace_back(operands);
555 caseOperandTypes.emplace_back(operandTypes);
558 if (!values.empty()) {
559 ShapedType caseValueType =
560 VectorType::get(
static_cast<int64_t>(values.size()), flagType);
577 for (
const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
581 p << it.value().getLimitedValue();
584 caseOperands[it.index()]);
589LogicalResult SwitchOp::verify() {
590 auto caseValues = getCaseValues();
591 auto caseDestinations = getCaseDestinations();
593 if (!caseValues && caseDestinations.empty())
596 Type flagType = getFlag().getType();
597 Type caseValueType = caseValues->getType().getElementType();
598 if (caseValueType != flagType)
599 return emitOpError() <<
"'flag' type (" << flagType
600 <<
") should match case value type (" << caseValueType
604 caseValues->size() !=
static_cast<int64_t>(caseDestinations.size()))
605 return emitOpError() <<
"number of case values (" << caseValues->size()
606 <<
") should match number of "
607 "case destinations ("
608 << caseDestinations.size() <<
")";
613 assert(
index < getNumSuccessors() &&
"invalid successor index");
615 : getCaseOperandsMutable(
index - 1));
619 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
622 return getDefaultDestination();
625 if (
auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
626 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
627 if (it.value() == value.getValue())
628 return caseDests[it.index()];
629 return getDefaultDestination();
640 if (!op.getCaseDestinations().empty())
644 op.getDefaultOperands());
663 bool requiresChange =
false;
664 auto caseValues = op.getCaseValues();
665 auto caseDests = op.getCaseDestinations();
667 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
668 if (caseDests[it.index()] == op.getDefaultDestination() &&
669 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
670 requiresChange =
true;
673 newCaseDestinations.push_back(caseDests[it.index()]);
674 newCaseOperands.push_back(op.getCaseOperands(it.index()));
675 newCaseValues.push_back(it.value());
682 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
683 newCaseValues, newCaseDestinations, newCaseOperands);
695 const APInt &caseValue) {
696 auto caseValues = op.getCaseValues();
697 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
698 if (it.value() == caseValue) {
700 op, op.getCaseDestinations()[it.index()],
701 op.getCaseOperands(it.index()));
706 op.getDefaultOperands());
741 auto caseValues = op.getCaseValues();
742 argStorage.reserve(caseValues->size() + 1);
743 auto caseDests = op.getCaseDestinations();
744 bool requiresChange =
false;
745 for (
int64_t i = 0, size = caseValues->size(); i < size; ++i) {
746 Block *caseDest = caseDests[i];
747 ValueRange caseOperands = op.getCaseOperands(i);
748 argStorage.emplace_back();
749 if (succeeded(
collapseBranch(caseDest, caseOperands, argStorage.back())))
750 requiresChange =
true;
752 newCaseDests.push_back(caseDest);
753 newCaseOperands.push_back(caseOperands);
756 Block *defaultDest = op.getDefaultDestination();
757 ValueRange defaultOperands = op.getDefaultOperands();
758 argStorage.emplace_back();
762 requiresChange =
true;
768 defaultOperands, *caseValues,
769 newCaseDests, newCaseOperands);
812 Block *currentBlock = op->getBlock();
820 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
821 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
822 predSwitch.getDefaultDestination() == currentBlock)
827 auto it = llvm::find(predDests, currentBlock);
828 if (it != predDests.end()) {
829 std::optional<DenseIntElementsAttr> predCaseValues =
830 predSwitch.getCaseValues();
832 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
835 op.getDefaultOperands());
864 Block *currentBlock = op->getBlock();
872 auto predSwitch = dyn_cast<SwitchOp>(predecessor->
getTerminator());
873 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
874 predSwitch.getDefaultDestination() != currentBlock)
879 auto predDests = predSwitch.getCaseDestinations();
880 auto predCaseValues = predSwitch.getCaseValues();
881 for (
int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
882 if (currentBlock != predDests[i])
883 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
888 bool requiresChange =
false;
890 auto caseValues = op.getCaseValues();
891 auto caseDests = op.getCaseDestinations();
892 for (
const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
893 if (caseValuesToRemove.contains(it.value())) {
894 requiresChange =
true;
897 newCaseDestinations.push_back(caseDests[it.index()]);
898 newCaseOperands.push_back(op.getCaseOperands(it.index()));
899 newCaseValues.push_back(it.value());
906 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
907 newCaseValues, newCaseDestinations, newCaseOperands);
925#define GET_OP_CLASSES
926#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult simplifySwitchFromSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: switch flag : i32, [ default: ^bb3,...
static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ] -> br ^bb2
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1 ] -> br ^bb1
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, const APInt &caseValue)
Helper for folding a switch with a constant value.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static LogicalResult simplifyPassThroughBr(BranchOp op, PatternRewriter &rewriter)
br ^bb1 ^bb1 br ^bbN(...)
static LogicalResult simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter)
Simplify a branch to a block that has a single predecessor.
static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl< Value > &argStorage)
Given a successor, try to collapse it to a new destination if it only contains a passthrough uncondit...
static LogicalResult dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ] -> switch flag : i32, [ default: ^bb1,...
static LogicalResult simplifyPassThroughSwitch(SwitchOp op, PatternRewriter &rewriter)
switch c_42 : i32, [ default: ^bb1, 42: ^bb2, ] ^bb2: br ^bb3 -> switch c_42 : i32,...
static LogicalResult simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, PatternRewriter &rewriter)
switch flag : i32, [ default: ^bb1, 42: ^bb2 ] ^bb1: switch flag : i32, [ default: ^bb3,...
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type &flagType, Block *&defaultDestination, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &defaultOperands, SmallVectorImpl< Type > &defaultOperandTypes, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= default : bb-id (( ssa-use-and-type-list ))?
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
iterator_range< pred_iterator > getPredecessors()
Block * getSinglePredecessor()
If this block has exactly one predecessor, return it.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Block * getUniquePredecessor()
If this block has a unique predecessor, i.e., all incoming edges originate from one block,...
BoolAttr getBoolAttr(bool value)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
This class helps build Operations.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class provides an abstraction for a range of TypeRange.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.