26 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc" 51 auto retValOp = dyn_cast<scf::YieldOp>(op);
55 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
56 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
66 void SCFDialect::initialize() {
69 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" 71 addInterfaces<SCFInlinerInterface>();
76 builder.
create<scf::YieldOp>(loc);
87 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
133 if (getRegion().empty())
134 return emitOpError(
"region needs to have at least one block");
135 if (getRegion().front().getNumArguments() > 0)
136 return emitOpError(
"region cannot have any arguments");
159 if (!llvm::hasSingleElement(op.getRegion()))
208 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
211 Block *prevBlock = op->getBlock();
215 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
217 for (
Block &blk : op.getRegion()) {
218 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
220 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
221 yieldOp.getResults());
229 for (
auto res : op.getResults())
247 void ExecuteRegionOp::getSuccessorRegions(
267 return getArgsMutable();
276 BodyBuilderFn bodyBuilder) {
279 for (
Value v : iterArgs)
285 for (
Value v : iterArgs)
291 if (iterArgs.empty() && !bodyBuilder) {
292 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
293 }
else if (bodyBuilder) {
302 if (
auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
303 if (cst.value() <= 0)
304 return emitOpError(
"constant step operand must be positive");
306 auto opNumResults = getNumResults();
307 if (opNumResults == 0)
312 if (getNumIterOperands() != opNumResults)
314 "mismatch in number of loop-carried values and defined values");
321 auto *body = getBody();
322 if (!body->getArgument(0).getType().isIndex())
324 "expected body first argument to be an index argument for " 325 "the induction variable");
327 auto opNumResults = getNumResults();
328 if (opNumResults == 0)
331 if (getNumRegionIterArgs() != opNumResults)
333 "mismatch in number of basic block args and defined values");
335 auto iterOperands = getIterOperands();
336 auto iterArgs = getRegionIterArgs();
337 auto opResults = getResults();
339 for (
auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
340 if (std::get<0>(e).getType() != std::get<2>(e).getType())
341 return emitOpError() <<
"types mismatch between " << i
342 <<
"th iter operand and defined value";
343 if (std::get<1>(e).getType() != std::get<2>(e).getType())
344 return emitOpError() <<
"types mismatch between " << i
345 <<
"th iter region arg and defined value";
/// Returns the loop's induction variable, i.e. the value produced by
/// `getInductionVar()`, wrapped in an `Optional<Value>`. As written here the
/// returned optional always holds a value.
352 Optional<Value> ForOp::getSingleInductionVar() {
return getInductionVar(); }
373 StringRef prefix =
"") {
374 assert(blocksArgs.size() == initializers.size() &&
375 "expected same length of arguments and initializers");
376 if (initializers.empty())
380 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
381 p << std::get<0>(it) <<
" = " << std::get<1>(it);
387 p <<
" " << getInductionVar() <<
" = " << getLowerBound() <<
" to " 388 << getUpperBound() <<
" step " << getStep();
392 if (!getIterOperands().empty())
393 p <<
" -> (" << getIterOperands().getTypes() <<
')';
406 inductionVariable.
type = indexType;
423 regionArgs.push_back(inductionVariable);
432 for (
auto argOperandType :
433 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
434 Type type = std::get<2>(argOperandType);
435 std::get<0>(argOperandType).type = type;
442 if (regionArgs.size() != result.
types.size() + 1)
445 "mismatch in number of loop-carried values and defined values");
452 ForOp::ensureTerminator(*body, builder, result.
location);
/// Returns the region holding the loop body; this simply forwards to
/// `getRegion()`.
461 Region &ForOp::getLoopBody() {
return getRegion(); }
467 assert(ivArg.getOwner() &&
"unlinked block argument");
468 auto *containingOp = ivArg.getOwner()->getParentOp();
469 return dyn_cast_or_null<ForOp>(containingOp);
477 assert(index && *index == 0 &&
"invalid region index");
481 return getInitArgs();
495 regions.push_back(
RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
500 assert(*index == 0 &&
"expected loop region");
501 regions.push_back(
RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
510 assert(lbs.size() == ubs.size() &&
511 "expected the same number of lower and upper bounds");
512 assert(lbs.size() == steps.size() &&
513 "expected the same number of lower bounds and steps");
518 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
520 assert(results.size() == iterArgs.size() &&
521 "loop nest body must return as many values as loop has iteration " 531 loops.reserve(lbs.size());
532 ivs.reserve(lbs.size());
535 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
536 auto loop = builder.
create<scf::ForOp>(
537 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
543 currentIterArgs = args;
544 currentLoc = nestedLoc;
550 loops.push_back(loop);
554 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
556 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
563 ? bodyBuilder(builder, currentLoc, ivs,
564 loops.back().getRegionIterArgs())
566 assert(results.size() == iterArgs.size() &&
567 "loop nest body must return as many values as loop has iteration " 570 builder.
create<scf::YieldOp>(loc, results);
574 res.loops.assign(loops.begin(), loops.end());
583 return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
588 bodyBuilder(nestedBuilder, nestedLoc, ivs);
612 bool canonicalize =
false;
622 keepMask.reserve(yieldOp.getNumOperands());
625 newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
626 newBlockTransferArgs.push_back(
Value());
627 newIterArgs.reserve(forOp.getNumIterOperands());
628 newYieldValues.reserve(yieldOp.getNumOperands());
629 newResultValues.reserve(forOp.getNumResults());
630 for (
auto it : llvm::zip(forOp.getIterOperands(),
631 forOp.getRegionIterArgs(),
633 yieldOp.getOperands()
641 bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
642 (std::get<1>(it).use_empty() &&
643 (std::get<0>(it) == std::get<3>(it) ||
644 std::get<2>(it).use_empty())));
645 keepMask.push_back(!forwarded);
646 canonicalize |= forwarded;
648 newBlockTransferArgs.push_back(std::get<0>(it));
649 newResultValues.push_back(std::get<0>(it));
652 newIterArgs.push_back(std::get<0>(it));
653 newYieldValues.push_back(std::get<3>(it));
654 newBlockTransferArgs.push_back(
Value());
655 newResultValues.push_back(
Value());
661 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
662 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
663 forOp.getStep(), newIterArgs);
664 newForOp->
setAttrs(forOp->getAttrs());
665 Block &newBlock = newForOp.getRegion().
front();
669 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
671 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
672 Value &newResultVal = newResultValues[idx];
673 assert((blockTransferArg && newResultVal) ||
674 (!blockTransferArg && !newResultVal));
675 if (!blockTransferArg) {
676 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
677 newResultVal = newForOp.getResult(collapsedIdx++);
683 "unexpected argument size mismatch");
688 if (newIterArgs.empty()) {
689 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
692 rewriter.
replaceOp(forOp, newResultValues);
697 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
701 filteredOperands.reserve(newResultValues.size());
702 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
704 filteredOperands.push_back(mergedTerminator.getOperand(idx));
705 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
709 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
710 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
711 cloneFilteredTerminator(mergedYieldOp);
712 rewriter.
eraseOp(mergedYieldOp);
713 rewriter.
replaceOp(forOp, newResultValues);
728 if (op.getLowerBound() == op.getUpperBound()) {
729 rewriter.
replaceOp(op, op.getIterOperands());
733 auto lb = op.getLowerBound().
getDefiningOp<arith::ConstantOp>();
734 auto ub = op.getUpperBound().
getDefiningOp<arith::ConstantOp>();
739 llvm::APInt lbValue = lb.getValue().
cast<IntegerAttr>().getValue();
740 llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
741 if (lbValue.sge(ubValue)) {
742 rewriter.
replaceOp(op, op.getIterOperands());
752 llvm::APInt stepValue = step.getValue().
cast<IntegerAttr>().getValue();
753 if ((lbValue + stepValue).sge(ubValue)) {
755 blockArgs.reserve(op.getNumIterOperands() + 1);
756 blockArgs.push_back(op.getLowerBound());
757 llvm::append_range(blockArgs, op.getIterOperands());
764 if (!llvm::hasSingleElement(block))
769 auto yieldOperands = yieldOp.getOperands();
770 if (llvm::any_of(yieldOperands,
771 [&](
Value v) {
return !op.isDefinedOutsideOfLoop(v); }))
785 assert(oldType.
isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
786 "expected ranked tensor types");
789 ForOp forOp = cast<ForOp>(operand.
getOwner());
791 "expected an iter OpOperand");
795 for (
OpOperand &opOperand : forOp.getIterOpOperands()) {
797 newIterOperands.push_back(replacement);
800 newIterOperands.push_back(opOperand.get());
804 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
805 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806 forOp.getStep(), newIterOperands);
807 newForOp->
setAttrs(forOp->getAttrs());
808 Block &newBlock = newForOp.getRegion().
front();
816 BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
818 Value castIn = rewriter.
create<tensor::CastOp>(newForOp.getLoc(), oldType,
820 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
824 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
827 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
830 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
832 newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
834 newYieldOperands[yieldIdx] = castOut;
835 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
836 rewriter.
eraseOp(clonedYieldOp);
841 newResults[yieldIdx] = rewriter.
create<tensor::CastOp>(
842 newForOp.getLoc(), oldType, newResults[yieldIdx]);
878 for (
auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
879 OpOperand &iterOpOperand = std::get<0>(it);
883 if (!std::get<1>(it).hasOneUse())
885 auto outgoingCastOp =
886 dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
891 if (outgoingCastOp.getResult().getType() !=
892 incomingCast.getSource().getType())
896 auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
897 incomingCast.getSource());
904 replacements[returnIdx] = rewriter.
create<tensor::CastOp>(
905 op.getLoc(), incomingCast.getDest().getType(),
906 replacements[returnIdx]);
969 assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
970 "unexpected multiple blocks");
975 unsigned idx = bbArg.getArgNumber() - 1;
977 cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
978 Value yieldVal = yieldOp->getOperand(idx);
979 auto tensorLoadOp = yieldVal.
getDefiningOp<bufferization::ToTensorOp>();
980 bool isTensor = bbArg.getType().isa<
TensorType>();
982 bufferization::ToMemrefOp tensorToMemref;
984 if (bbArg.hasOneUse())
986 dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
987 if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
990 if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
997 if (tensorLoadOp->getNextNode() != yieldOp)
1001 if (tensorToMemref) {
1004 tensorToMemref, tensorToMemref.getMemref().getType(),
1005 tensorToMemref.getTensor());
1010 Value newTensorLoad = rewriter.
create<bufferization::ToTensorOp>(
1011 loc, tensorLoadOp.getMemref());
1012 Value forOpResult = forOp.getResult(bbArg.getArgNumber() - 1);
1013 replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1018 yieldOp.setOperand(idx, bbArg);
1021 if (replacements.empty())
1029 newResults.reserve(forOp.getNumResults());
1030 for (
Value v : forOp.getResults()) {
1031 auto it = replacements.find(v);
1032 newResults.push_back((it != replacements.end()) ? it->second : v);
1036 return op.
get() != newResults[idx++];
1045 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1046 LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1059 auto *body = getBody();
1060 if (body->getNumArguments() != getRank())
1061 return emitOpError(
"region expects ") << getRank() <<
" arguments";
1064 auto terminatorTypes = getTerminator().getYieldedTypes();
1065 auto opResults = getResults();
1066 if (opResults.size() != terminatorTypes.size())
1067 return emitOpError(
"produces ")
1068 << opResults.size() <<
" results, but its terminator yields " 1069 << terminatorTypes.size() <<
" value(s)";
1071 for (
auto e : llvm::zip(terminatorTypes, opResults)) {
1072 if (std::get<0>(e) != std::get<1>(e).getType())
1073 return emitOpError() <<
"type mismatch between result " << i <<
" (" 1074 << std::get<1>(e).getType() <<
") and terminator (" 1075 << std::get<0>(e) <<
")";
1083 llvm::interleaveComma(getThreadIndices(), p);
1085 llvm::interleaveComma(getNumThreads(), p);
1086 p <<
") -> (" << getResultTypes() <<
") ";
1089 getNumResults() > 0);
1117 std::unique_ptr<Region> region = std::make_unique<Region>();
1118 for (
auto &idx : threadIndices)
1125 ForeachThreadOp::ensureTerminator(*region, b, result.
location);
1155 ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.
location);
1161 void ForeachThreadOp::build(
1182 llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.
getTerminator());
1183 assert(terminator &&
1184 "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1185 result.
addTypes(terminator.getYieldedTypes());
1191 void ForeachThreadOp::ensureTerminator(
Region ®ion,
OpBuilder &builder,
1194 ForeachThreadOp>::ensureTerminator(region, builder, loc);
1197 if (terminator.getRegion().empty())
1201 PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
1202 return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
1208 return ForeachThreadOp();
1209 assert(tidxArg.getOwner() &&
"unlinked block argument");
1210 auto *containingOp = tidxArg.getOwner()->getParentOp();
1211 return dyn_cast<ForeachThreadOp>(containingOp);
1227 for (
const Operation &op : getRegion().front().getOperations()) {
1228 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1229 return this->emitOpError(
"expected only ")
1230 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1249 std::unique_ptr<Region> region = std::make_unique<Region>();
1253 if (region->empty())
1263 OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
1264 return getOperation()->getParentOp()->getResult(idx);
1268 return llvm::to_vector<4>(
1269 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1270 auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
1271 return insertSliceOp ? insertSliceOp.yieldedType() :
Type();
1276 return getRegion().front().getOperations();
1284 assert(a &&
"expected non-empty operation");
1285 assert(b &&
"expected non-empty operation");
1290 if (ifOp->isProperAncestor(b))
1293 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1294 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1296 ifOp = ifOp->getParentOfType<IfOp>();
1304 bool withElseRegion) {
1305 build(builder, result, llvm::None, cond, withElseRegion);
1311 if (resultTypes.empty())
1316 build(builder, result, resultTypes, cond, addTerminator,
1317 withElseRegion ? addTerminator
1325 assert(thenBuilder &&
"the builder callback for 'then' must be present");
1333 thenBuilder(builder, result.
location);
1340 elseBuilder(builder, result.
location);
1346 build(builder, result,
TypeRange(), cond, thenBuilder, elseBuilder);
1350 if (getNumResults() != 0 && getElseRegion().empty())
1351 return emitOpError(
"must have an else block if defining values");
1389 bool printBlockTerminators =
false;
1391 p <<
" " << getCondition();
1392 if (!getResults().empty()) {
1393 p <<
" -> (" << getResultTypes() <<
")";
1395 printBlockTerminators =
true;
1400 printBlockTerminators);
1403 auto &elseRegion = getElseRegion();
1404 if (!elseRegion.empty()) {
1408 printBlockTerminators);
1429 Region *elseRegion = &this->getElseRegion();
1430 if (elseRegion->
empty())
1431 elseRegion =
nullptr;
1435 if (
auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1436 condition = condAttr.getValue().isOneValue();
1447 regions.push_back(
RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1453 if (getElseRegion().empty())
1456 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1463 getConditionMutable().assign(xorStmt.getLhs());
1467 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
1468 getElseRegion().getBlocks());
1469 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
1470 getThenRegion().getBlocks(), thenBlock);
1474 void IfOp::getRegionInvocationBounds(
1477 if (
auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1480 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1481 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1484 invocationBounds.assign(2, {0, 1});
1500 llvm::transform(usedResults, std::back_inserter(usedOperands),
1505 [&]() { yieldOp->setOperands(usedOperands); });
1512 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1513 [](
OpResult result) {
return !result.use_empty(); });
1516 if (usedResults.size() == op.getNumResults())
1521 llvm::transform(usedResults, std::back_inserter(newTypes),
1526 auto newOp = rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1527 emptyBuilder, emptyBuilder);
1531 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1532 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1537 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1548 auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1552 if (constant.getValue().cast<
BoolAttr>().getValue())
1554 else if (!op.getElseRegion().empty())
1570 if (op->getNumResults() == 0)
1573 auto cond = op.getCondition();
1574 auto thenYieldArgs = op.thenYield().getOperands();
1575 auto elseYieldArgs = op.elseYield().getOperands();
1578 for (
const auto &it :
1580 Value trueVal = std::get<0>(it.value());
1581 Value falseVal = std::get<1>(it.value());
1584 nonHoistable.push_back(trueVal.
getType());
1588 if (nonHoistable.size() == op->getNumResults())
1591 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond);
1592 if (replacement.thenBlock())
1593 rewriter.
eraseBlock(replacement.thenBlock());
1594 replacement.getThenRegion().takeBody(op.getThenRegion());
1595 replacement.getElseRegion().takeBody(op.getElseRegion());
1598 assert(thenYieldArgs.size() == results.size());
1599 assert(elseYieldArgs.size() == results.size());
1604 for (
const auto &it :
1606 Value trueVal = std::get<0>(it.value());
1607 Value falseVal = std::get<1>(it.value());
1610 results[it.index()] = replacement.getResult(trueYields.size());
1611 trueYields.push_back(trueVal);
1612 falseYields.push_back(falseVal);
1613 }
else if (trueVal == falseVal)
1614 results[it.index()] = trueVal;
1616 results[it.index()] = rewriter.
create<arith::SelectOp>(
1617 op.getLoc(), cond, trueVal, falseVal);
1651 if (op.getCondition().getDefiningOp<arith::ConstantOp>())
1654 bool changed =
false;
1659 Value constantTrue =
nullptr;
1660 Value constantFalse =
nullptr;
1663 llvm::make_early_inc_range(op.getCondition().getUses())) {
1664 if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1668 constantTrue = rewriter.
create<arith::ConstantOp>(
1672 [&]() { use.set(constantTrue); });
1673 }
else if (op.getElseRegion().isAncestor(
1674 use.getOwner()->getParentRegion())) {
1678 constantFalse = rewriter.
create<arith::ConstantOp>(
1682 [&]() { use.set(constantFalse); });
1726 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
1732 if (op.getNumResults() == 0)
1736 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1738 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1741 op.getOperation()->getIterator());
1742 bool changed =
false;
1744 for (
auto [trueResult, falseResult, opResult] :
1745 llvm::zip(trueYield.getResults(), falseYield.getResults(),
1747 if (trueResult == falseResult) {
1748 if (!opResult.use_empty()) {
1749 opResult.replaceAllUsesWith(trueResult);
1755 auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1759 if (!trueYield.getType().isInteger(1))
1762 auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1766 bool trueVal = trueYield.getValue().cast<
BoolAttr>().getValue();
1767 bool falseVal = falseYield.getValue().cast<
BoolAttr>().getValue();
1768 if (!trueVal && falseVal) {
1769 if (!opResult.use_empty()) {
1771 op.getLoc(), op.getCondition(),
1772 rewriter.
create<arith::ConstantOp>(
1774 opResult.replaceAllUsesWith(notCond);
1778 if (trueVal && !falseVal) {
1779 if (!opResult.use_empty()) {
1780 opResult.replaceAllUsesWith(op.getCondition());
1815 Block *parent = nextIf->getBlock();
1816 if (nextIf == &parent->
front())
1819 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
1827 Block *nextThen =
nullptr;
1828 Block *nextElse =
nullptr;
1829 if (nextIf.getCondition() == prevIf.getCondition()) {
1830 nextThen = nextIf.thenBlock();
1831 if (!nextIf.getElseRegion().empty())
1832 nextElse = nextIf.elseBlock();
1834 if (arith::XOrIOp notv =
1835 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
1836 if (notv.getLhs() == prevIf.getCondition() &&
1838 nextElse = nextIf.thenBlock();
1839 if (!nextIf.getElseRegion().empty())
1840 nextThen = nextIf.elseBlock();
1843 if (arith::XOrIOp notv =
1844 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
1845 if (notv.getLhs() == nextIf.getCondition() &&
1847 nextElse = nextIf.thenBlock();
1848 if (!nextIf.getElseRegion().empty())
1849 nextThen = nextIf.elseBlock();
1853 if (!nextThen && !nextElse)
1857 if (!prevIf.getElseRegion().empty())
1858 prevElseYielded = prevIf.elseYield().getOperands();
1861 for (
auto it : llvm::zip(prevIf.getResults(),
1862 prevIf.thenYield().getOperands(), prevElseYielded))
1864 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
1866 use.getOwner()->getParentRegion())) {
1868 use.set(std::get<1>(it));
1871 use.getOwner()->getParentRegion())) {
1873 use.set(std::get<2>(it));
1879 llvm::append_range(mergedTypes, nextIf.getResultTypes());
1881 IfOp combinedIf = rewriter.
create<IfOp>(
1882 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
1883 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
1886 combinedIf.getThenRegion(),
1887 combinedIf.getThenRegion().begin());
1890 YieldOp thenYield = combinedIf.thenYield();
1891 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
1892 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
1896 llvm::append_range(mergedYields, thenYield2.getOperands());
1897 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
1903 combinedIf.getElseRegion(),
1904 combinedIf.getElseRegion().begin());
1907 if (combinedIf.getElseRegion().empty()) {
1909 combinedIf.getElseRegion(),
1910 combinedIf.getElseRegion().begin());
1912 YieldOp elseYield = combinedIf.elseYield();
1913 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
1914 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
1919 llvm::append_range(mergedElseYields, elseYield2.getOperands());
1921 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
1930 if (pair.index() < prevIf.getNumResults())
1931 prevValues.push_back(pair.value());
1933 nextValues.push_back(pair.value());
1948 if (ifOp.getNumResults())
1950 Block *elseBlock = ifOp.elseBlock();
1951 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
1955 newIfOp.getThenRegion().begin());
1982 auto nestedOps = op.thenBlock()->without_terminator();
1984 if (!llvm::hasSingleElement(nestedOps))
1988 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
1991 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
1995 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2001 llvm::append_range(elseYield, op.elseYield().getOperands());
2015 if (tup.value().getDefiningOp() == nestedIf) {
2016 auto nestedIdx = tup.value().cast<
OpResult>().getResultNumber();
2017 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2018 elseYield[tup.index()]) {
2023 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2036 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2039 elseYieldsToUpgradeToSelect.push_back(tup.index());
2043 Value newCondition = rewriter.
create<arith::AndIOp>(
2044 loc, op.getCondition(), nestedIf.getCondition());
2045 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2048 llvm::append_range(results, newIf.getResults());
2051 for (
auto idx : elseYieldsToUpgradeToSelect)
2052 results[idx] = rewriter.
create<arith::SelectOp>(
2053 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2055 Block *newIfBlock = newIf.thenBlock();
2057 rewriter.
eraseOp(newIfBlock->getTerminator());
2059 newIfBlock = rewriter.
createBlock(&newIf.getThenRegion());
2060 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2063 if (!elseYield.empty()) {
2066 rewriter.
create<YieldOp>(loc, elseYield);
2077 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2078 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2079 RemoveStaticCondition, RemoveUnusedResults,
2080 ReplaceIfYieldWithConditionOrValue>(context);
/// Returns a pointer to the `back()` block of the "then" region.
/// NOTE(review): calls `back()` without checking that the region is non-empty;
/// presumably the then region always has a block — confirm.
2083 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
/// Returns the `back()` operation of the block given by `thenBlock()`, as a
/// `YieldOp`. Uses `cast<>` rather than `dyn_cast<>`, so that operation is
/// expected to already be a `YieldOp`.
2084 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2085 Block *IfOp::elseBlock() {
2086 Region &r = getElseRegion();
/// Returns the `back()` operation of the block given by `elseBlock()`, as a
/// `YieldOp`. Uses `cast<>` rather than `dyn_cast<>`, so that operation is
/// expected to already be a `YieldOp`.
/// NOTE(review): `elseBlock()` is dereferenced here without a null check,
/// while other code in this file tests its result for null before use.
/// Presumably this must only be called when an else block exists — confirm.
2091 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2097 void ParallelOp::build(
2107 ParallelOp::getOperandSegmentSizeAttr(),
2109 static_cast<int32_t>(upperBounds.size()),
2110 static_cast<int32_t>(steps.size()),
2111 static_cast<int32_t>(initVals.size())}));
2115 unsigned numIVs = steps.size();
2121 if (bodyBuilderFn) {
2123 bodyBuilderFn(builder, result.
location,
2127 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2130 void ParallelOp::build(
2137 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2140 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2144 wrapper = wrappedBuilderFn;
2146 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2155 if (stepValues.empty())
2157 "needs at least one tuple element for lowerBound, upperBound and step");
2160 for (
Value stepValue : stepValues)
2162 if (cst.value() <= 0)
2163 return emitOpError(
"constant step operand must be positive");
2167 Block *body = getBody();
2169 return emitOpError() <<
"expects the same number of induction variables: " 2171 <<
" as bound and step values: " << stepValues.size();
2173 if (!arg.getType().isIndex())
2175 "expects arguments for the induction variable to be of index type");
2180 return yield->
emitOpError() <<
"not allowed to have operands inside '" 2181 << ParallelOp::getOperationName() <<
"'";
2185 auto resultsSize = getResults().size();
2186 auto reductionsSize = reductions.size();
2187 auto initValsSize = getInitVals().size();
2188 if (resultsSize != reductionsSize)
2189 return emitOpError() <<
"expects number of results: " << resultsSize
2190 <<
" to be the same as number of reductions: " 2192 if (resultsSize != initValsSize)
2193 return emitOpError() <<
"expects number of results: " << resultsSize
2194 <<
" to be the same as number of initial values: " 2198 for (
auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2199 auto resultType = std::get<0>(resultAndReduce).getType();
2200 auto reduceOp = std::get<1>(resultAndReduce);
2201 auto reduceType = reduceOp.getOperand().getType();
2202 if (resultType != reduceType)
2203 return reduceOp.emitOpError()
2204 <<
"expects type of reduce: " << reduceType
2205 <<
" to be the same as result type: " << resultType;
2253 for (
auto &iv : ivs)
2260 ParallelOp::getOperandSegmentSizeAttr(),
2262 static_cast<int32_t>(upper.size()),
2263 static_cast<int32_t>(steps.size()),
2264 static_cast<int32_t>(initVals.size())}));
2273 ForOp::ensureTerminator(*body, builder, result.
location);
2278 p <<
" (" << getBody()->getArguments() <<
") = (" << getLowerBound()
2279 <<
") to (" << getUpperBound() <<
") step (" << getStep() <<
")";
2280 if (!getInitVals().empty())
2281 p <<
" init (" << getInitVals() <<
")";
2286 (*this)->getAttrs(),
2287 ParallelOp::getOperandSegmentSizeAttr());
/// Returns the region holding the parallel loop body; this simply forwards to
/// `getRegion()`.
2290 Region &ParallelOp::getLoopBody() {
return getRegion(); }
2295 return ParallelOp();
2296 assert(ivArg.getOwner() &&
"unlinked block argument");
2297 auto *containingOp = ivArg.getOwner()->getParentOp();
2298 return dyn_cast<ParallelOp>(containingOp);
2303 struct CollapseSingleIterationLoops :
public OpRewritePattern<ParallelOp> {
2313 newLowerBounds.reserve(op.getLowerBound().size());
2314 newUpperBounds.reserve(op.getUpperBound().size());
2315 newSteps.reserve(op.getStep().size());
2316 for (
auto [lowerBound, upperBound, step, iv] :
2317 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2318 op.getInductionVars())) {
2320 auto lowerBoundConstant =
2321 dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
2322 auto upperBoundConstant =
2323 dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
2325 dyn_cast_or_null<arith::ConstantIndexOp>(step.
getDefiningOp());
2328 if (lowerBoundConstant && upperBoundConstant && stepConstant &&
2329 (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
2330 (upperBoundConstant.value() - lowerBoundConstant.value()) <=
2331 stepConstant.value()) {
2332 mapping.
map(iv, lowerBound);
2334 newLowerBounds.push_back(lowerBound);
2335 newUpperBounds.push_back(upperBound);
2336 newSteps.push_back(step);
2340 if (newLowerBounds.size() == op.getLowerBound().size())
2343 if (newLowerBounds.empty()) {
2347 results.reserve(op.getInitVals().size());
2348 for (
auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2349 auto reduce = dyn_cast<ReduceOp>(bodyOp);
2351 rewriter.
clone(bodyOp, mapping);
2354 Block &reduceBlock = reduce.getReductionOperator().
front();
2355 auto initValIndex = results.size();
2356 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
2360 rewriter.
clone(reduceBodyOp, mapping);
2363 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
2364 results.push_back(result);
2371 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2372 newSteps, op.getInitVals(),
nullptr);
2376 newOp.getRegion().begin(), mapping);
2377 rewriter.
replaceOp(op, newOp.getResults());
2389 for (
auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2390 if (std::get<0>(dim) == std::get<1>(dim)) {
2391 rewriter.
replaceOp(op, op.getInitVals());
2408 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
2413 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2414 llvm::is_contained(innerOp.getUpperBound(), val) ||
2415 llvm::is_contained(innerOp.getStep(), val))
2419 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2424 Block &innerBody = innerOp.getLoopBody().
front();
2425 assert(iterVals.size() ==
2433 builder.
clone(op, mapping);
2436 auto concatValues = [](
const auto &first,
const auto &second) {
2438 ret.reserve(first.size() + second.size());
2439 ret.assign(first.begin(), first.end());
2440 ret.append(second.begin(), second.end());
2444 auto newLowerBounds =
2445 concatValues(op.getLowerBound(), innerOp.getLowerBound());
2446 auto newUpperBounds =
2447 concatValues(op.getUpperBound(), innerOp.getUpperBound());
2448 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2451 newSteps, llvm::None, bodyBuilder);
2460 results.
add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2461 MergeNestedParallelLoops>(context);
2468 void ReduceOp::build(
2471 auto type = operand.
getType();
2485 auto type = getOperand().getType();
2486 Block &block = getReductionOperator().
front();
2488 return emitOpError(
"the block inside reduce should not be empty");
2491 return arg.getType() != type;
2493 return emitOpError() <<
"expects two arguments to reduce block of type " 2498 return emitOpError(
"the block inside reduce should be terminated with a " 2499 "'scf.reduce.return' op");
2526 p <<
"(" << getOperand() <<
") ";
2527 p <<
" : " << getOperand().getType() <<
' ';
2538 auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2539 Type reduceType = reduceOp.getOperand().getType();
2540 if (reduceType != getResult().getType())
2541 return emitOpError() <<
"needs to have type " << reduceType
2542 <<
" (the type of the enclosing ReduceOp)";
2551 assert(index && *index == 0 &&
2552 "WhileOp is expected to branch only to the first region");
2557 ConditionOp WhileOp::getConditionOp() {
2558 return cast<ConditionOp>(getBefore().front().getTerminator());
2561 YieldOp WhileOp::getYieldOp() {
2562 return cast<YieldOp>(getAfter().front().getTerminator());
2566 return getBefore().front().getArguments();
2570 return getAfter().front().getArguments();
2578 regions.emplace_back(&getBefore(), getBefore().getArguments());
2582 assert(*index < 2 &&
"there are only two regions in a WhileOp");
2585 regions.emplace_back(&getBefore(), getBefore().getArguments());
2590 assert(!operands.empty() &&
"expected at least one operand");
2591 auto cond = operands[0].dyn_cast_or_null<
BoolAttr>();
2592 if (!cond || !cond.getValue())
2593 regions.emplace_back(getResults());
2594 if (!cond || cond.getValue())
2595 regions.emplace_back(&getAfter(), getAfter().getArguments());
2616 FunctionType functionType;
2621 result.
addTypes(functionType.getResults());
2623 if (functionType.getNumInputs() != operands.size()) {
2625 <<
"expected as many input types as operands " 2626 <<
"(expected " << operands.size() <<
" got " 2627 << functionType.getNumInputs() <<
")";
2637 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
2638 regionArgs[i].type = functionType.getInput(i);
2661 template <
typename OpTy>
2664 if (left.size() != right.size())
2665 return op.emitOpError(
"expects the same number of ") << message;
2667 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
2668 if (left[i] != right[i]) {
2671 diag.
attachNote() <<
"for argument " << i <<
", found " << left[i]
2672 <<
" and " << right[i];
2682 template <
typename TerminatorTy>
2683 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op,
Region &region,
2684 StringRef errorMessage) {
2686 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2689 auto diag = op.emitOpError(errorMessage);
2690 if (terminatorOperation)
2691 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
2696 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2698 "expects the 'before' region to terminate with 'scf.condition'");
2699 if (!beforeTerminator)
2702 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2704 "expects the 'after' region to terminate with 'scf.yield'");
2705 return success(afterTerminator !=
nullptr);
2733 auto term = op.getConditionOp();
2737 Value constantTrue =
nullptr;
2739 bool replaced =
false;
2740 for (
auto yieldedAndBlockArgs :
2741 llvm::zip(term.getArgs(), op.getAfterArguments())) {
2742 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2743 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2745 constantTrue = rewriter.
create<arith::ConstantOp>(
2746 op.getLoc(), term.getCondition().getType(),
2749 std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2806 struct RemoveLoopInvariantArgsFromBeforeBlock
2814 ConditionOp condOp = op.getConditionOp();
2819 bool canSimplify =
false;
2820 for (
const auto &it :
2822 auto index =
static_cast<unsigned>(it.index());
2823 auto [initVal, yieldOpArg] = it.value();
2826 if (yieldOpArg == initVal) {
2835 auto yieldOpBlockArg = yieldOpArg.dyn_cast<
BlockArgument>();
2836 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2837 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2838 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2851 for (
const auto &it :
2853 auto index =
static_cast<unsigned>(it.index());
2854 auto [initVal, yieldOpArg] = it.value();
2858 if (yieldOpArg == initVal) {
2859 beforeBlockInitValMap.insert({index, initVal});
2867 auto yieldOpBlockArg = yieldOpArg.dyn_cast<
BlockArgument>();
2868 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2869 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2870 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2871 beforeBlockInitValMap.insert({index, initVal});
2876 newInitArgs.emplace_back(initVal);
2877 newYieldOpArgs.emplace_back(yieldOpArg);
2878 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
2888 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
2891 &newWhile.getBefore(), {},
2901 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
2904 if (beforeBlockInitValMap.count(i) != 0)
2905 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
2907 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
2910 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
2912 newWhile.getAfter().begin());
2914 rewriter.
replaceOp(op, newWhile.getResults());
2959 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
2965 ConditionOp condOp = op.getConditionOp();
2968 bool canSimplify =
false;
2969 for (
Value condOpArg : condOpArgs) {
2973 if (condOpArg.getParentBlock() != &beforeBlock) {
2989 auto index =
static_cast<unsigned>(it.index());
2990 Value condOpArg = it.value();
2995 condOpInitValMap.insert({index, condOpArg});
2997 newCondOpArgs.emplace_back(condOpArg);
2998 newAfterBlockType.emplace_back(condOpArg.
getType());
2999 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3010 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3013 Block &newAfterBlock =
3015 newAfterBlockType, newAfterBlockArgLocs);
3024 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3025 Value afterBlockArg, result;
3028 if (condOpInitValMap.count(i) != 0) {
3029 afterBlockArg = condOpInitValMap[i];
3030 result = afterBlockArg;
3032 afterBlockArg = newAfterBlock.getArgument(
j);
3033 result = newWhile.getResult(
j);
3036 newAfterBlockArgs[i] = afterBlockArg;
3037 newWhileResults[i] = result;
3040 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3042 newWhile.getBefore().begin());
3044 rewriter.
replaceOp(op, newWhileResults);
3080 auto term = op.getConditionOp();
3081 auto afterArgs = op.getAfterArguments();
3082 auto termArgs = term.getArgs();
3089 bool needUpdate =
false;
3090 for (
const auto &it :
3092 auto i =
static_cast<unsigned>(it.index());
3093 Value result = std::get<0>(it.value());
3094 Value afterArg = std::get<1>(it.value());
3095 Value termArg = std::get<2>(it.value());
3096 if (result.
use_empty() && afterArg.use_empty()) {
3099 newResultsIndices.emplace_back(i);
3100 newTermArgs.emplace_back(termArg);
3101 newResultTypes.emplace_back(result.
getType());
3102 newArgLocs.emplace_back(result.
getLoc());
3117 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3120 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3127 newResults[it.value()] = newWhile.getResult(it.index());
3128 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3132 newWhile.getBefore().begin());
3135 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3169 using namespace scf;
3170 auto cond = op.getConditionOp();
3171 auto cmp = cond.getCondition().
getDefiningOp<arith::CmpIOp>();
3174 bool changed =
false;
3176 llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3177 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3178 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3181 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3182 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3186 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3189 if (cmp2.getPredicate() == cmp.getPredicate())
3190 samePredicate =
true;
3191 else if (cmp2.getPredicate() ==
3193 samePredicate =
false;
3213 if (!llvm::any_of(op.getBeforeArguments(),
3214 [](
Value arg) {
return arg.use_empty(); }))
3217 YieldOp yield = op.getYieldOp();
3224 op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3225 Value beforeArg = std::get<0>(it.value());
3226 Value yieldValue = std::get<1>(it.value());
3227 Value initValue = std::get<2>(it.value());
3229 argsToErase.push_back(it.index());
3231 newYields.emplace_back(yieldValue);
3232 newInits.emplace_back(initValue);
3236 if (argsToErase.empty())
3240 op.getBefore().front().eraseArguments(argsToErase);
3243 WhileOp replacement =
3244 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3245 replacement.getBefore().takeBody(op.getBefore());
3246 replacement.getAfter().takeBody(op.getAfter());
3247 rewriter.
replaceOp(op, replacement.getResults());
3258 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
3259 RemoveLoopInvariantValueYielded, WhileConditionTruth,
3260 WhileCmpCond, WhileUnusedResult>(context);
3267 #define GET_OP_CLASSES 3268 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Include the generated interface declarations.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static std::string diag(llvm::Value &v)
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is a basic unit of execution within MLIR.
This is a value defined by a result of an operation.
Specialization of arith.constant op that returns an integer value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
This class represents a diagnostic that is inflight and set to be reported.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
void push_back(Block *block)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
std::vector< Value > ValueVector
An owning vector of values, handy to return from functions.
unsigned getNumOperands()
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp...
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
This is the representation of an operand reference.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
unsigned getArgNumber() const
Returns the number of this argument.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
BlockArgument getArgument(unsigned i)
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
Region * getParentRegion()
Return the Region in which this Value is defined.
void addOperands(ValueRange newOperands)
IntegerAttr getIntegerAttr(Type type, int64_t value)
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
unsigned getNumArguments()
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
type_range getTypes() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DialectInlinerInterface(Dialect *dialect)
unsigned getResultNumber() const
Returns the number of this result.
Block * getParentBlock()
Return the Block in which this Value is defined.
IntegerType getIntegerType(unsigned width)
virtual ParseResult parseRParen()=0
Parse a ) token.
This is the interface that must be implemented by the dialects of operations to be inlined...
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
This class provides an abstraction over the various different ranges of value types.
void addTypes(ArrayRef< Type > newTypes)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
This class provides a mutable adaptor for a range of operands.
Location getLoc()
The source location the operation was defined or derived from.
IRValueT get() const
Return the current value being used by this operand.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val)
Returns the ForeachThreadOp parent of a thread index variable.
Parens surrounding zero or more operands.
BlockArgListType getArguments()
This class represents an argument of a Block.
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
Location getLoc() const
Return the location of this value.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
bool use_empty() const
Returns true if this value has no uses.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class implements Optional functionality for ParseResult.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Operation * getTerminator()
Get the terminator operation of this block.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
Type getType() const
Return the type of this value.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Specialization of arith.constant op that returns an integer of index type.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
ParseResult value() const
Access the internal ParseResult value.
BoolAttr getBoolAttr(bool value)
Operation * getOwner() const
Return the owner of this operand.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MLIRContext is the top-level object for a collection of MLIR operations.
Block * lookupOrDefault(Block *from) const
Lookup a mapped value within the map.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.
This class represents an operand of an operation.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class implements the operand iterators for the Operation class.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseEqual()=0
Parse a = token.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class represents success/failure for parsing-like operations that find it important to chain tog...
void setAttrs(DictionaryAttr newAttrs)
Set the attribute dictionary on this operation.
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.