25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.
getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->
getParentOp()))
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : iterArgs)
327 for (
Value v : iterArgs)
333 if (iterArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
396 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
509 for (
auto argOperandType :
510 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
511 Type type = std::get<2>(argOperandType);
512 std::get<0>(argOperandType).type = type;
524 ForOp::ensureTerminator(*body, builder, result.
location);
536 return getBody()->getArguments().drop_front(getNumInductionVars());
540 return getInitArgsMutable();
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
546 bool replaceInitOperandUsesInLoop,
551 auto inits = llvm::to_vector(getInitArgs());
552 inits.append(newInitOperands.begin(), newInitOperands.end());
553 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
561 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
566 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567 assert(newInitOperands.size() == newYieldedValues.size() &&
568 "expected as many new yield values as new iter operands");
570 yieldOp.getResultsMutable().append(newYieldedValues);
576 newLoop.getBody()->getArguments().take_front(
577 getBody()->getNumArguments()));
579 if (replaceInitOperandUsesInLoop) {
582 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
593 newLoop->getResults().take_front(getNumResults()));
594 return cast<LoopLikeOpInterface>(newLoop.getOperation());
598 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
601 assert(ivArg.getOwner() &&
"unlinked block argument");
602 auto *containingOp = ivArg.getOwner()->getParentOp();
603 return dyn_cast_or_null<ForOp>(containingOp);
607 return getInitArgs();
624 for (
auto [lb, ub, step] :
625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
627 if (!tripCount.has_value() || *tripCount != 1)
636 return getBody()->getArguments().drop_front(getRank());
640 return getOutputsMutable();
646 scf::InParallelOp terminator = forallOp.getTerminator();
651 bbArgReplacements.append(forallOp.getOutputs().begin(),
652 forallOp.getOutputs().end());
656 forallOp->getIterator(), bbArgReplacements);
661 results.reserve(forallOp.getResults().size());
662 for (
auto &yieldingOp : terminator.getYieldingOps()) {
663 auto parallelInsertSliceOp =
664 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
666 Value dst = parallelInsertSliceOp.getDest();
667 Value src = parallelInsertSliceOp.getSource();
668 if (llvm::isa<TensorType>(src.
getType())) {
669 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
670 forallOp.getLoc(), dst.
getType(), src, dst,
671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672 parallelInsertSliceOp.getStrides(),
673 parallelInsertSliceOp.getStaticOffsets(),
674 parallelInsertSliceOp.getStaticSizes(),
675 parallelInsertSliceOp.getStaticStrides()));
677 llvm_unreachable(
"unsupported terminator");
692 assert(lbs.size() == ubs.size() &&
693 "expected the same number of lower and upper bounds");
694 assert(lbs.size() == steps.size() &&
695 "expected the same number of lower bounds and steps");
700 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
702 assert(results.size() == iterArgs.size() &&
703 "loop nest body must return as many values as loop has iteration "
705 return LoopNest{{}, std::move(results)};
713 loops.reserve(lbs.size());
714 ivs.reserve(lbs.size());
717 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
718 auto loop = builder.
create<scf::ForOp>(
719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725 currentIterArgs = args;
726 currentLoc = nestedLoc;
732 loops.push_back(loop);
736 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
738 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
745 ? bodyBuilder(builder, currentLoc, ivs,
746 loops.back().getRegionIterArgs())
748 assert(results.size() == iterArgs.size() &&
749 "loop nest body must return as many values as loop has iteration "
752 builder.
create<scf::YieldOp>(loc, results);
756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757 return LoopNest{std::move(loops), std::move(nestResults)};
765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
770 bodyBuilder(nestedBuilder, nestedLoc, ivs);
792 LogicalResult matchAndRewrite(scf::ForOp forOp,
794 bool canonicalize =
false;
801 int64_t numResults = forOp.getNumResults();
803 keepMask.reserve(numResults);
806 newBlockTransferArgs.reserve(1 + numResults);
807 newBlockTransferArgs.push_back(
Value());
808 newIterArgs.reserve(forOp.getInitArgs().size());
809 newYieldValues.reserve(numResults);
810 newResultValues.reserve(numResults);
811 for (
auto it : llvm::zip(forOp.getInitArgs(),
812 forOp.getRegionIterArgs(),
814 forOp.getYieldedValues()
822 bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
823 (std::get<1>(it).use_empty() &&
824 (std::get<0>(it) == std::get<3>(it) ||
825 std::get<2>(it).use_empty())));
826 keepMask.push_back(!forwarded);
827 canonicalize |= forwarded;
829 newBlockTransferArgs.push_back(std::get<0>(it));
830 newResultValues.push_back(std::get<0>(it));
833 newIterArgs.push_back(std::get<0>(it));
834 newYieldValues.push_back(std::get<3>(it));
835 newBlockTransferArgs.push_back(
Value());
836 newResultValues.push_back(
Value());
842 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
843 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
844 forOp.getStep(), newIterArgs);
845 newForOp->
setAttrs(forOp->getAttrs());
846 Block &newBlock = newForOp.getRegion().
front();
850 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
852 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
853 Value &newResultVal = newResultValues[idx];
854 assert((blockTransferArg && newResultVal) ||
855 (!blockTransferArg && !newResultVal));
856 if (!blockTransferArg) {
857 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
858 newResultVal = newForOp.getResult(collapsedIdx++);
864 "unexpected argument size mismatch");
869 if (newIterArgs.empty()) {
870 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
873 rewriter.
replaceOp(forOp, newResultValues);
878 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
882 filteredOperands.reserve(newResultValues.size());
883 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
885 filteredOperands.push_back(mergedTerminator.getOperand(idx));
886 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
890 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
891 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
892 cloneFilteredTerminator(mergedYieldOp);
893 rewriter.
eraseOp(mergedYieldOp);
894 rewriter.
replaceOp(forOp, newResultValues);
902 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
903 IntegerAttr clb, cub;
905 llvm::APInt lbValue = clb.getValue();
906 llvm::APInt ubValue = cub.getValue();
907 return (ubValue - lbValue).getSExtValue();
916 return diff.getSExtValue();
926 LogicalResult matchAndRewrite(ForOp op,
930 if (op.getLowerBound() == op.getUpperBound()) {
931 rewriter.
replaceOp(op, op.getInitArgs());
935 std::optional<int64_t> diff =
936 computeConstDiff(op.getLowerBound(), op.getUpperBound());
942 rewriter.
replaceOp(op, op.getInitArgs());
946 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
952 llvm::APInt stepValue = *maybeStepValue;
953 if (stepValue.sge(*diff)) {
955 blockArgs.reserve(op.getInitArgs().size() + 1);
956 blockArgs.push_back(op.getLowerBound());
957 llvm::append_range(blockArgs, op.getInitArgs());
964 if (!llvm::hasSingleElement(block))
968 if (llvm::any_of(op.getYieldedValues(),
969 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
971 rewriter.
replaceOp(op, op.getYieldedValues());
983 assert(llvm::isa<RankedTensorType>(oldType) &&
984 llvm::isa<RankedTensorType>(newType) &&
985 "expected ranked tensor types");
988 ForOp forOp = cast<ForOp>(operand.
getOwner());
990 "expected an iter OpOperand");
992 "Expected a different type");
994 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
996 newIterOperands.push_back(replacement);
999 newIterOperands.push_back(opOperand.get());
1003 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
1004 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1005 forOp.getStep(), newIterOperands);
1006 newForOp->
setAttrs(forOp->getAttrs());
1007 Block &newBlock = newForOp.getRegion().
front();
1015 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1017 Value castIn = rewriter.
create<tensor::CastOp>(newForOp.getLoc(), oldType,
1019 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
1023 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1026 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1029 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
1031 newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1033 newYieldOperands[yieldIdx] = castOut;
1034 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1035 rewriter.
eraseOp(clonedYieldOp);
1040 newResults[yieldIdx] = rewriter.
create<tensor::CastOp>(
1041 newForOp.getLoc(), oldType, newResults[yieldIdx]);
1075 LogicalResult matchAndRewrite(ForOp op,
1077 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.
getResults())) {
1078 OpOperand &iterOpOperand = std::get<0>(it);
1080 if (!incomingCast ||
1081 incomingCast.getSource().getType() == incomingCast.getType())
1086 incomingCast.getDest().getType(),
1087 incomingCast.getSource().getType()))
1089 if (!std::get<1>(it).hasOneUse())
1094 op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1095 incomingCast.getSource()));
1106 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1110 std::optional<APInt> ForOp::getConstantStep() {
1113 return step.getValue();
1117 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1118 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1124 if (
auto constantStep = getConstantStep())
1125 if (*constantStep == 1)
1138 unsigned numLoops = getRank();
1140 if (getNumResults() != getOutputs().size())
1141 return emitOpError(
"produces ")
1142 << getNumResults() <<
" results, but has only "
1143 << getOutputs().size() <<
" outputs";
1146 auto *body = getBody();
1148 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1149 for (int64_t i = 0; i < numLoops; ++i)
1151 return emitOpError(
"expects ")
1152 << i <<
"-th block argument to be an index";
1153 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1155 return emitOpError(
"type mismatch between ")
1156 << i <<
"-th output and corresponding block argument";
1157 if (getMapping().has_value() && !getMapping()->empty()) {
1158 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1159 return emitOpError() <<
"mapping attribute size must match op rank";
1160 for (
auto map : getMapping()->getValue()) {
1161 if (!isa<DeviceMappingAttrInterface>(map))
1162 return emitOpError()
1170 getStaticLowerBound(),
1171 getDynamicLowerBound())))
1174 getStaticUpperBound(),
1175 getDynamicUpperBound())))
1178 getStaticStep(), getDynamicStep())))
1186 p <<
" (" << getInductionVars();
1187 if (isNormalized()) {
1208 if (!getRegionOutArgs().empty())
1209 p <<
"-> (" << getResultTypes() <<
") ";
1210 p.printRegion(getRegion(),
1212 getNumResults() > 0);
1213 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1214 getStaticLowerBoundAttrName(),
1215 getStaticUpperBoundAttrName(),
1216 getStaticStepAttrName()});
1221 auto indexType = b.getIndexType();
1241 unsigned numLoops = ivs.size();
1276 if (outOperands.size() != result.
types.size())
1278 "mismatch between out operands and types");
1288 std::unique_ptr<Region> region = std::make_unique<Region>();
1289 for (
auto &iv : ivs) {
1290 iv.type = b.getIndexType();
1291 regionArgs.push_back(iv);
1294 auto &out = it.value();
1295 out.type = result.
types[it.index()];
1296 regionArgs.push_back(out);
1302 ForallOp::ensureTerminator(*region, b, result.
location);
1314 {static_cast<int32_t>(dynamicLbs.size()),
1315 static_cast<int32_t>(dynamicUbs.size()),
1316 static_cast<int32_t>(dynamicSteps.size()),
1317 static_cast<int32_t>(outOperands.size())}));
1322 void ForallOp::build(
1326 std::optional<ArrayAttr> mapping,
1347 "operandSegmentSizes",
1349 static_cast<int32_t>(dynamicUbs.size()),
1350 static_cast<int32_t>(dynamicSteps.size()),
1351 static_cast<int32_t>(outputs.size())}));
1352 if (mapping.has_value()) {
1371 if (!bodyBuilderFn) {
1372 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1379 void ForallOp::build(
1382 std::optional<ArrayAttr> mapping,
1384 unsigned numLoops = ubs.size();
1387 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1391 bool ForallOp::isNormalized() {
1395 return intValue.has_value() && intValue == val;
1398 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1407 ForallOp>::ensureTerminator(region, builder, loc);
1414 InParallelOp ForallOp::getTerminator() {
1415 return cast<InParallelOp>(getBody()->getTerminator());
1420 InParallelOp inParallelOp = getTerminator();
1421 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1422 if (
auto parallelInsertSliceOp =
1423 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1424 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1425 storeOps.push_back(parallelInsertSliceOp);
1431 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1436 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1438 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1442 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1444 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1448 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1454 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1457 assert(tidxArg.getOwner() &&
"unlinked block argument");
1458 auto *containingOp = tidxArg.getOwner()->getParentOp();
1459 return dyn_cast<ForallOp>(containingOp);
1467 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1469 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1473 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1476 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1485 LogicalResult matchAndRewrite(ForallOp op,
1500 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1501 op.setStaticLowerBound(staticLowerBound);
1505 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1506 op.setStaticUpperBound(staticUpperBound);
1509 op.getDynamicStepMutable().assign(dynamicStep);
1510 op.setStaticStep(staticStep);
1512 op->
setAttr(ForallOp::getOperandSegmentSizeAttr(),
1514 {static_cast<int32_t>(dynamicLowerBound.size()),
1515 static_cast<int32_t>(dynamicUpperBound.size()),
1516 static_cast<int32_t>(dynamicStep.size()),
1517 static_cast<int32_t>(op.getNumResults())}));
1599 LogicalResult matchAndRewrite(ForallOp forallOp,
1618 for (
OpResult result : forallOp.getResults()) {
1619 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1620 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1621 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1622 resultToDelete.insert(result);
1624 resultToReplace.push_back(result);
1625 newOuts.push_back(opOperand->
get());
1631 if (resultToDelete.empty())
1639 for (
OpResult result : resultToDelete) {
1640 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1641 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1643 forallOp.getCombiningOps(blockArg);
1644 for (
Operation *combiningOp : combiningOps)
1645 rewriter.
eraseOp(combiningOp);
1650 auto newForallOp = rewriter.
create<scf::ForallOp>(
1651 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1652 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1653 forallOp.getMapping(),
1658 Block *loopBody = forallOp.getBody();
1659 Block *newLoopBody = newForallOp.getBody();
1664 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1671 for (
OpResult result : forallOp.getResults()) {
1672 if (resultToDelete.count(result)) {
1673 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1675 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1678 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1682 for (
auto &&[oldResult, newResult] :
1683 llvm::zip(resultToReplace, newForallOp->getResults()))
1689 for (
OpResult oldResult : resultToDelete)
1691 forallOp.getTiedOpOperand(oldResult)->get());
1696 struct ForallOpSingleOrZeroIterationDimsFolder
1700 LogicalResult matchAndRewrite(ForallOp op,
1703 if (op.getMapping().has_value())
1711 for (
auto [lb, ub, step, iv] :
1712 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1713 op.getMixedStep(), op.getInductionVars())) {
1715 if (numIterations.has_value()) {
1717 if (*numIterations == 0) {
1718 rewriter.
replaceOp(op, op.getOutputs());
1723 if (*numIterations == 1) {
1728 newMixedLowerBounds.push_back(lb);
1729 newMixedUpperBounds.push_back(ub);
1730 newMixedSteps.push_back(step);
1733 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1735 op,
"no dimensions have 0 or 1 iterations");
1739 if (newMixedLowerBounds.empty()) {
1746 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1747 newMixedUpperBounds, newMixedSteps,
1748 op.getOutputs(), std::nullopt,
nullptr);
1749 newOp.getBodyRegion().getBlocks().clear();
1754 newOp.getStaticLowerBoundAttrName(),
1755 newOp.getStaticUpperBoundAttrName(),
1756 newOp.getStaticStepAttrName()};
1757 for (
const auto &namedAttr : op->
getAttrs()) {
1758 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1761 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1765 newOp.getRegion().
begin(), mapping);
1766 rewriter.
replaceOp(op, newOp.getResults());
1771 struct FoldTensorCastOfOutputIntoForallOp
1780 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1782 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1785 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1792 castOp.getSource().getType())) {
1796 tensorCastProducers[en.index()] =
1797 TypeCast{castOp.getSource().getType(), castOp.getType()};
1798 newOutputTensors[en.index()] = castOp.getSource();
1801 if (tensorCastProducers.empty())
1806 auto newForallOp = rewriter.
create<ForallOp>(
1807 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1808 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1810 auto castBlockArgs =
1811 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1812 for (
auto [index, cast] : tensorCastProducers) {
1813 Value &oldTypeBBArg = castBlockArgs[index];
1814 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1815 nestedLoc, cast.dstType, oldTypeBBArg);
1820 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1821 ivsBlockArgs.append(castBlockArgs);
1823 bbArgs.front().getParentBlock(), ivsBlockArgs);
1829 auto terminator = newForallOp.getTerminator();
1830 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1831 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1832 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1833 insertSliceOp.getDestMutable().assign(outputBlockArg);
1839 for (
auto &item : tensorCastProducers) {
1840 Value &oldTypeResult = castResults[item.first];
1841 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1844 rewriter.
replaceOp(forallOp, castResults);
1853 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1854 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1855 ForallOpSingleOrZeroIterationDimsFolder>(context);
1884 scf::ForallOp forallOp =
1885 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1887 return this->emitOpError(
"expected forall op parent");
1890 for (
Operation &op : getRegion().front().getOperations()) {
1891 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1892 return this->emitOpError(
"expected only ")
1893 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1897 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1899 if (!llvm::is_contained(regionOutArgs, dest))
1900 return op.
emitOpError(
"may only insert into an output block argument");
1917 std::unique_ptr<Region> region = std::make_unique<Region>();
1921 if (region->empty())
1931 OpResult InParallelOp::getParentResult(int64_t idx) {
1932 return getOperation()->getParentOp()->getResult(idx);
1936 return llvm::to_vector<4>(
1937 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1939 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1940 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1945 return getRegion().front().getOperations();
1953 assert(a &&
"expected non-empty operation");
1954 assert(b &&
"expected non-empty operation");
1959 if (ifOp->isProperAncestor(b))
1962 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1963 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1965 ifOp = ifOp->getParentOfType<IfOp>();
1973 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1974 IfOp::Adaptor adaptor,
1976 if (adaptor.getRegions().empty())
1978 Region *r = &adaptor.getThenRegion();
1981 Block &b = r->front();
1984 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
1987 TypeRange types = yieldOp.getOperandTypes();
1988 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1995 return build(builder, result, resultTypes, cond,
false,
2001 bool addElseBlock) {
2002 assert((!addElseBlock || addThenBlock) &&
2003 "must not create else block w/o then block");
2018 bool withElseRegion) {
2019 build(builder, result,
TypeRange{}, cond, withElseRegion);
2031 if (resultTypes.empty())
2032 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2036 if (withElseRegion) {
2038 if (resultTypes.empty())
2039 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2046 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2053 thenBuilder(builder, result.
location);
2059 elseBuilder(builder, result.
location);
2066 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2068 inferredReturnTypes))) {
2069 result.
addTypes(inferredReturnTypes);
2074 if (getNumResults() != 0 && getElseRegion().empty())
2075 return emitOpError(
"must have an else block if defining values");
2113 bool printBlockTerminators =
false;
2115 p <<
" " << getCondition();
2116 if (!getResults().empty()) {
2117 p <<
" -> (" << getResultTypes() <<
")";
2119 printBlockTerminators =
true;
2124 printBlockTerminators);
2127 auto &elseRegion = getElseRegion();
2128 if (!elseRegion.
empty()) {
2132 printBlockTerminators);
2149 Region *elseRegion = &this->getElseRegion();
2150 if (elseRegion->
empty())
2158 FoldAdaptor adaptor(operands, *
this);
2159 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2160 if (!boolAttr || boolAttr.getValue())
2161 regions.emplace_back(&getThenRegion());
2164 if (!boolAttr || !boolAttr.getValue()) {
2165 if (!getElseRegion().empty())
2166 regions.emplace_back(&getElseRegion());
2168 regions.emplace_back(getResults());
2172 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2175 if (getElseRegion().empty())
2178 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2185 getConditionMutable().assign(xorStmt.getLhs());
2189 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2190 getElseRegion().getBlocks());
2191 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2192 getThenRegion().getBlocks(), thenBlock);
2196 void IfOp::getRegionInvocationBounds(
2199 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2202 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2203 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2206 invocationBounds.assign(2, {0, 1});
2222 llvm::transform(usedResults, std::back_inserter(usedOperands),
2227 [&]() { yieldOp->setOperands(usedOperands); });
2230 LogicalResult matchAndRewrite(IfOp op,
2234 llvm::copy_if(op.
getResults(), std::back_inserter(usedResults),
2235 [](
OpResult result) { return !result.use_empty(); });
2243 llvm::transform(usedResults, std::back_inserter(newTypes),
2248 rewriter.
create<IfOp>(op.
getLoc(), newTypes, op.getCondition());
2254 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2255 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2260 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2269 LogicalResult matchAndRewrite(IfOp op,
2277 else if (!op.getElseRegion().empty())
2291 LogicalResult matchAndRewrite(IfOp op,
2296 auto cond = op.getCondition();
2297 auto thenYieldArgs = op.thenYield().
getOperands();
2298 auto elseYieldArgs = op.elseYield().
getOperands();
2301 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2304 nonHoistable.push_back(trueVal.getType());
2311 IfOp replacement = rewriter.
create<IfOp>(op.
getLoc(), nonHoistable, cond,
2313 if (replacement.thenBlock())
2314 rewriter.
eraseBlock(replacement.thenBlock());
2315 replacement.getThenRegion().takeBody(op.getThenRegion());
2316 replacement.getElseRegion().takeBody(op.getElseRegion());
2319 assert(thenYieldArgs.size() == results.size());
2320 assert(elseYieldArgs.size() == results.size());
2325 for (
const auto &it :
2327 Value trueVal = std::get<0>(it.value());
2328 Value falseVal = std::get<1>(it.value());
2331 results[it.index()] = replacement.getResult(trueYields.size());
2332 trueYields.push_back(trueVal);
2333 falseYields.push_back(falseVal);
2334 }
else if (trueVal == falseVal)
2335 results[it.index()] = trueVal;
2337 results[it.index()] = rewriter.
create<arith::SelectOp>(
2338 op.
getLoc(), cond, trueVal, falseVal);
2368 LogicalResult matchAndRewrite(IfOp op,
2375 bool changed =
false;
2380 Value constantTrue =
nullptr;
2381 Value constantFalse =
nullptr;
2384 llvm::make_early_inc_range(op.getCondition().
getUses())) {
2389 constantTrue = rewriter.
create<arith::ConstantOp>(
2393 [&]() { use.
set(constantTrue); });
2399 constantFalse = rewriter.
create<arith::ConstantOp>(
2403 [&]() { use.
set(constantFalse); });
2407 return success(changed);
2447 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2450 LogicalResult matchAndRewrite(IfOp op,
2457 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2459 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2462 op.getOperation()->getIterator());
2463 bool changed =
false;
2465 for (
auto [trueResult, falseResult, opResult] :
2466 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2468 if (trueResult == falseResult) {
2469 if (!opResult.use_empty()) {
2470 opResult.replaceAllUsesWith(trueResult);
2481 bool trueVal = trueYield.
getValue();
2482 bool falseVal = falseYield.
getValue();
2483 if (!trueVal && falseVal) {
2484 if (!opResult.use_empty()) {
2485 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2487 op.
getLoc(), op.getCondition(),
2497 if (trueVal && !falseVal) {
2498 if (!opResult.use_empty()) {
2499 opResult.replaceAllUsesWith(op.getCondition());
2504 return success(changed);
2532 LogicalResult matchAndRewrite(IfOp nextIf,
2534 Block *parent = nextIf->getBlock();
2535 if (nextIf == &parent->
front())
2538 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2546 Block *nextThen =
nullptr;
2547 Block *nextElse =
nullptr;
2548 if (nextIf.getCondition() == prevIf.getCondition()) {
2549 nextThen = nextIf.thenBlock();
2550 if (!nextIf.getElseRegion().empty())
2551 nextElse = nextIf.elseBlock();
2553 if (arith::XOrIOp notv =
2554 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2555 if (notv.getLhs() == prevIf.getCondition() &&
2557 nextElse = nextIf.thenBlock();
2558 if (!nextIf.getElseRegion().empty())
2559 nextThen = nextIf.elseBlock();
2562 if (arith::XOrIOp notv =
2563 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2564 if (notv.getLhs() == nextIf.getCondition() &&
2566 nextElse = nextIf.thenBlock();
2567 if (!nextIf.getElseRegion().empty())
2568 nextThen = nextIf.elseBlock();
2572 if (!nextThen && !nextElse)
2576 if (!prevIf.getElseRegion().empty())
2577 prevElseYielded = prevIf.elseYield().getOperands();
2580 for (
auto it : llvm::zip(prevIf.getResults(),
2581 prevIf.thenYield().getOperands(), prevElseYielded))
2583 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2587 use.
set(std::get<1>(it));
2592 use.
set(std::get<2>(it));
2598 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2600 IfOp combinedIf = rewriter.
create<IfOp>(
2601 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2602 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2605 combinedIf.getThenRegion(),
2606 combinedIf.getThenRegion().begin());
2609 YieldOp thenYield = combinedIf.thenYield();
2610 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2611 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2615 llvm::append_range(mergedYields, thenYield2.getOperands());
2616 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2622 combinedIf.getElseRegion(),
2623 combinedIf.getElseRegion().begin());
2626 if (combinedIf.getElseRegion().empty()) {
2628 combinedIf.getElseRegion(),
2629 combinedIf.getElseRegion().
begin());
2631 YieldOp elseYield = combinedIf.elseYield();
2632 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2633 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2638 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2640 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2649 if (pair.index() < prevIf.getNumResults())
2650 prevValues.push_back(pair.value());
2652 nextValues.push_back(pair.value());
2664 LogicalResult matchAndRewrite(IfOp ifOp,
2667 if (ifOp.getNumResults())
2669 Block *elseBlock = ifOp.elseBlock();
2670 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2674 newIfOp.getThenRegion().begin());
2699 LogicalResult matchAndRewrite(IfOp op,
2701 auto nestedOps = op.thenBlock()->without_terminator();
2703 if (!llvm::hasSingleElement(nestedOps))
2707 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2710 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2714 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2720 llvm::append_range(elseYield, op.elseYield().
getOperands());
2734 if (tup.value().getDefiningOp() == nestedIf) {
2735 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2736 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2737 elseYield[tup.index()]) {
2742 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2755 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2758 elseYieldsToUpgradeToSelect.push_back(tup.index());
2762 Value newCondition = rewriter.
create<arith::AndIOp>(
2763 loc, op.getCondition(), nestedIf.getCondition());
2768 llvm::append_range(results, newIf.getResults());
2771 for (
auto idx : elseYieldsToUpgradeToSelect)
2772 results[idx] = rewriter.
create<arith::SelectOp>(
2773 op.
getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2775 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2778 if (!elseYield.empty()) {
2781 rewriter.
create<YieldOp>(loc, elseYield);
2792 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2793 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2794 RemoveStaticCondition, RemoveUnusedResults,
2795 ReplaceIfYieldWithConditionOrValue>(context);
2798 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2799 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2800 Block *IfOp::elseBlock() {
2801 Region &r = getElseRegion();
2806 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2812 void ParallelOp::build(
2822 ParallelOp::getOperandSegmentSizeAttr(),
2824 static_cast<int32_t>(upperBounds.size()),
2825 static_cast<int32_t>(steps.size()),
2826 static_cast<int32_t>(initVals.size())}));
2830 unsigned numIVs = steps.size();
2836 if (bodyBuilderFn) {
2838 bodyBuilderFn(builder, result.
location,
2843 if (initVals.empty())
2844 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2847 void ParallelOp::build(
2854 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2857 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2861 wrapper = wrappedBuilderFn;
2863 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2872 if (stepValues.empty())
2874 "needs at least one tuple element for lowerBound, upperBound and step");
2877 for (
Value stepValue : stepValues)
2880 return emitOpError(
"constant step operand must be positive");
2884 Block *body = getBody();
2886 return emitOpError() <<
"expects the same number of induction variables: "
2888 <<
" as bound and step values: " << stepValues.size();
2890 if (!arg.getType().isIndex())
2892 "expects arguments for the induction variable to be of index type");
2895 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2896 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2901 auto resultsSize = getResults().size();
2902 auto reductionsSize = reduceOp.getReductions().size();
2903 auto initValsSize = getInitVals().size();
2904 if (resultsSize != reductionsSize)
2905 return emitOpError() <<
"expects number of results: " << resultsSize
2906 <<
" to be the same as number of reductions: "
2908 if (resultsSize != initValsSize)
2909 return emitOpError() <<
"expects number of results: " << resultsSize
2910 <<
" to be the same as number of initial values: "
2914 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2915 auto resultType = getOperation()->getResult(i).getType();
2916 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2917 if (resultType != reductionOperandType)
2918 return reduceOp.emitOpError()
2919 <<
"expects type of " << i
2920 <<
"-th reduction operand: " << reductionOperandType
2921 <<
" to be the same as the " << i
2922 <<
"-th result type: " << resultType;
2970 for (
auto &iv : ivs)
2977 ParallelOp::getOperandSegmentSizeAttr(),
2979 static_cast<int32_t>(upper.size()),
2980 static_cast<int32_t>(steps.size()),
2981 static_cast<int32_t>(initVals.size())}));
2990 ParallelOp::ensureTerminator(*body, builder, result.
location);
2995 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2996 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2997 if (!getInitVals().empty())
2998 p <<
" init (" << getInitVals() <<
")";
3003 (*this)->getAttrs(),
3004 ParallelOp::getOperandSegmentSizeAttr());
3009 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3013 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3017 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3021 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3026 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3028 return ParallelOp();
3029 assert(ivArg.getOwner() &&
"unlinked block argument");
3030 auto *containingOp = ivArg.getOwner()->getParentOp();
3031 return dyn_cast<ParallelOp>(containingOp);
3036 struct ParallelOpSingleOrZeroIterationDimsFolder
3040 LogicalResult matchAndRewrite(ParallelOp op,
3047 for (
auto [lb, ub, step, iv] :
3048 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3049 op.getInductionVars())) {
3051 if (numIterations.has_value()) {
3053 if (*numIterations == 0) {
3054 rewriter.
replaceOp(op, op.getInitVals());
3059 if (*numIterations == 1) {
3064 newLowerBounds.push_back(lb);
3065 newUpperBounds.push_back(ub);
3066 newSteps.push_back(step);
3069 if (newLowerBounds.size() == op.getLowerBound().size())
3072 if (newLowerBounds.empty()) {
3076 results.reserve(op.getInitVals().size());
3077 for (
auto &bodyOp : op.getBody()->without_terminator())
3078 rewriter.
clone(bodyOp, mapping);
3079 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3080 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3081 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3082 auto initValIndex = results.size();
3083 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3087 rewriter.
clone(reduceBodyOp, mapping);
3090 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3091 results.push_back(result);
3099 rewriter.
create<ParallelOp>(op.
getLoc(), newLowerBounds, newUpperBounds,
3100 newSteps, op.getInitVals(),
nullptr);
3106 newOp.getRegion().
begin(), mapping);
3107 rewriter.
replaceOp(op, newOp.getResults());
3115 LogicalResult matchAndRewrite(ParallelOp op,
3117 Block &outerBody = *op.getBody();
3121 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3126 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3127 llvm::is_contained(innerOp.getUpperBound(), val) ||
3128 llvm::is_contained(innerOp.getStep(), val))
3132 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3137 Block &innerBody = *innerOp.getBody();
3138 assert(iterVals.size() ==
3146 builder.
clone(op, mapping);
3149 auto concatValues = [](
const auto &first,
const auto &second) {
3151 ret.reserve(first.size() + second.size());
3152 ret.assign(first.begin(), first.end());
3153 ret.append(second.begin(), second.end());
3157 auto newLowerBounds =
3158 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3159 auto newUpperBounds =
3160 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3161 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3164 newSteps, std::nullopt,
3175 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3184 void ParallelOp::getSuccessorRegions(
3202 for (
Value v : operands) {
3211 LogicalResult ReduceOp::verifyRegions() {
3214 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3215 auto type = getOperands()[i].getType();
3218 return emitOpError() << i <<
"-th reduction has an empty body";
3221 return arg.getType() != type;
3223 return emitOpError() <<
"expected two block arguments with type " << type
3224 <<
" in the " << i <<
"-th reduction region";
3228 return emitOpError(
"reduction bodies must be terminated with an "
3229 "'scf.reduce.return' op");
3248 Block *reductionBody = getOperation()->getBlock();
3250 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3252 if (expectedResultType != getResult().
getType())
3253 return emitOpError() <<
"must have type " << expectedResultType
3254 <<
" (the type of the reduction inputs)";
3264 ValueRange operands, BodyBuilderFn beforeBuilder,
3265 BodyBuilderFn afterBuilder) {
3273 beforeArgLocs.reserve(operands.size());
3274 for (
Value operand : operands) {
3275 beforeArgLocs.push_back(operand.getLoc());
3280 beforeRegion, {}, operands.getTypes(), beforeArgLocs);
3289 resultTypes, afterArgLocs);
3295 ConditionOp WhileOp::getConditionOp() {
3296 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3299 YieldOp WhileOp::getYieldOp() {
3300 return cast<YieldOp>(getAfterBody()->getTerminator());
3303 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3304 return getYieldOp().getResultsMutable();
3308 return getBeforeBody()->getArguments();
3312 return getAfterBody()->getArguments();
3316 return getBeforeArguments();
3320 assert(point == getBefore() &&
3321 "WhileOp is expected to branch only to the first region");
3329 regions.emplace_back(&getBefore(), getBefore().getArguments());
3333 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3334 "there are only two regions in a WhileOp");
3336 if (point == getAfter()) {
3337 regions.emplace_back(&getBefore(), getBefore().getArguments());
3341 regions.emplace_back(getResults());
3342 regions.emplace_back(&getAfter(), getAfter().getArguments());
3346 return {&getBefore(), &getAfter()};
3367 FunctionType functionType;
3372 result.
addTypes(functionType.getResults());
3374 if (functionType.getNumInputs() != operands.size()) {
3376 <<
"expected as many input types as operands "
3377 <<
"(expected " << operands.size() <<
" got "
3378 << functionType.getNumInputs() <<
")";
3388 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3389 regionArgs[i].type = functionType.getInput(i);
3391 return failure(parser.
parseRegion(*before, regionArgs) ||
3411 template <
typename OpTy>
3414 if (left.size() != right.size())
3415 return op.
emitOpError(
"expects the same number of ") << message;
3417 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3418 if (left[i] != right[i]) {
3421 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3422 <<
" and " << right[i];
3431 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3433 "expects the 'before' region to terminate with 'scf.condition'");
3434 if (!beforeTerminator)
3437 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3439 "expects the 'after' region to terminate with 'scf.yield'");
3440 return success(afterTerminator !=
nullptr);
3466 LogicalResult matchAndRewrite(WhileOp op,
3468 auto term = op.getConditionOp();
3472 Value constantTrue =
nullptr;
3474 bool replaced =
false;
3475 for (
auto yieldedAndBlockArgs :
3476 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3477 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3478 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3480 constantTrue = rewriter.
create<arith::ConstantOp>(
3481 op.
getLoc(), term.getCondition().getType(),
3490 return success(replaced);
3542 struct RemoveLoopInvariantArgsFromBeforeBlock
3546 LogicalResult matchAndRewrite(WhileOp op,
3548 Block &afterBlock = *op.getAfterBody();
3550 ConditionOp condOp = op.getConditionOp();
3555 bool canSimplify =
false;
3556 for (
const auto &it :
3558 auto index =
static_cast<unsigned>(it.index());
3559 auto [initVal, yieldOpArg] = it.value();
3562 if (yieldOpArg == initVal) {
3571 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3572 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3573 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3574 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3587 for (
const auto &it :
3589 auto index =
static_cast<unsigned>(it.index());
3590 auto [initVal, yieldOpArg] = it.value();
3594 if (yieldOpArg == initVal) {
3595 beforeBlockInitValMap.insert({index, initVal});
3603 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3604 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3605 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3606 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3607 beforeBlockInitValMap.insert({index, initVal});
3612 newInitArgs.emplace_back(initVal);
3613 newYieldOpArgs.emplace_back(yieldOpArg);
3614 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3627 &newWhile.getBefore(), {},
3630 Block &beforeBlock = *op.getBeforeBody();
3637 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3640 if (beforeBlockInitValMap.count(i) != 0)
3641 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3643 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3646 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3648 newWhile.getAfter().begin());
3650 rewriter.
replaceOp(op, newWhile.getResults());
3695 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3698 LogicalResult matchAndRewrite(WhileOp op,
3700 Block &beforeBlock = *op.getBeforeBody();
3701 ConditionOp condOp = op.getConditionOp();
3704 bool canSimplify =
false;
3705 for (
Value condOpArg : condOpArgs) {
3725 auto index =
static_cast<unsigned>(it.index());
3726 Value condOpArg = it.value();
3731 condOpInitValMap.insert({index, condOpArg});
3733 newCondOpArgs.emplace_back(condOpArg);
3734 newAfterBlockType.emplace_back(condOpArg.
getType());
3735 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3746 auto newWhile = rewriter.
create<WhileOp>(op.
getLoc(), newAfterBlockType,
3749 Block &newAfterBlock =
3751 newAfterBlockType, newAfterBlockArgLocs);
3753 Block &afterBlock = *op.getAfterBody();
3760 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3761 Value afterBlockArg, result;
3764 if (condOpInitValMap.count(i) != 0) {
3765 afterBlockArg = condOpInitValMap[i];
3766 result = afterBlockArg;
3768 afterBlockArg = newAfterBlock.getArgument(
j);
3769 result = newWhile.getResult(
j);
3772 newAfterBlockArgs[i] = afterBlockArg;
3773 newWhileResults[i] = result;
3776 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3778 newWhile.getBefore().begin());
3780 rewriter.
replaceOp(op, newWhileResults);
3814 LogicalResult matchAndRewrite(WhileOp op,
3816 auto term = op.getConditionOp();
3817 auto afterArgs = op.getAfterArguments();
3818 auto termArgs = term.getArgs();
3825 bool needUpdate =
false;
3826 for (
const auto &it :
3828 auto i =
static_cast<unsigned>(it.index());
3829 Value result = std::get<0>(it.value());
3830 Value afterArg = std::get<1>(it.value());
3831 Value termArg = std::get<2>(it.value());
3835 newResultsIndices.emplace_back(i);
3836 newTermArgs.emplace_back(termArg);
3837 newResultTypes.emplace_back(result.
getType());
3838 newArgLocs.emplace_back(result.
getLoc());
3853 rewriter.
create<WhileOp>(op.
getLoc(), newResultTypes, op.getInits());
3856 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3863 newResults[it.value()] = newWhile.getResult(it.index());
3864 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3868 newWhile.getBefore().begin());
3870 Block &afterBlock = *op.getAfterBody();
3871 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3903 LogicalResult matchAndRewrite(scf::WhileOp op,
3905 using namespace scf;
3906 auto cond = op.getConditionOp();
3907 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3910 bool changed =
false;
3911 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3912 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3913 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3916 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3917 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3921 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3924 if (cmp2.getPredicate() == cmp.getPredicate())
3925 samePredicate =
true;
3926 else if (cmp2.getPredicate() ==
3928 samePredicate =
false;
3938 return success(changed);
3946 LogicalResult matchAndRewrite(WhileOp op,
3949 if (!llvm::any_of(op.getBeforeArguments(),
3950 [](
Value arg) { return arg.use_empty(); }))
3953 YieldOp yield = op.getYieldOp();
3958 llvm::BitVector argsToErase;
3960 size_t argsCount = op.getBeforeArguments().size();
3961 newYields.reserve(argsCount);
3962 newInits.reserve(argsCount);
3963 argsToErase.reserve(argsCount);
3964 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3965 op.getBeforeArguments(), yield.
getOperands(), op.getInits())) {
3966 if (beforeArg.use_empty()) {
3967 argsToErase.push_back(
true);
3969 argsToErase.push_back(
false);
3970 newYields.emplace_back(yieldValue);
3971 newInits.emplace_back(initValue);
3975 Block &beforeBlock = *op.getBeforeBody();
3976 Block &afterBlock = *op.getAfterBody();
3984 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3985 Block &newAfterBlock = *newWhileOp.getAfterBody();
3991 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
3992 newBeforeBlock.getArguments());
3996 rewriter.
replaceOp(op, newWhileOp.getResults());
4005 LogicalResult matchAndRewrite(WhileOp op,
4007 ConditionOp condOp = op.getConditionOp();
4011 for (
Value arg : condOpArgs)
4012 argsSet.insert(arg);
4014 if (argsSet.size() == condOpArgs.size())
4017 llvm::SmallDenseMap<Value, unsigned> argsMap;
4019 argsMap.reserve(condOpArgs.size());
4020 newArgs.reserve(condOpArgs.size());
4021 for (
Value arg : condOpArgs) {
4022 if (!argsMap.count(arg)) {
4023 auto pos =
static_cast<unsigned>(argsMap.size());
4024 argsMap.insert({arg, pos});
4025 newArgs.emplace_back(arg);
4032 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4033 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4035 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4036 Block &newAfterBlock = *newWhileOp.getAfterBody();
4041 auto it = argsMap.find(arg);
4042 assert(it != argsMap.end());
4043 auto pos = it->second;
4044 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4045 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4053 Block &beforeBlock = *op.getBeforeBody();
4054 Block &afterBlock = *op.getAfterBody();
4056 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4057 newBeforeBlock.getArguments());
4058 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4066 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4068 if (args1.size() != args2.size())
4069 return std::nullopt;
4073 auto it = llvm::find(args2, arg1);
4074 if (it == args2.end())
4075 return std::nullopt;
4077 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4084 llvm::SmallDenseSet<Value> set;
4085 for (
Value arg : args) {
4086 if (set.contains(arg))
4101 LogicalResult matchAndRewrite(WhileOp loop,
4103 auto oldBefore = loop.getBeforeBody();
4104 ConditionOp oldTerm = loop.getConditionOp();
4105 ValueRange beforeArgs = oldBefore->getArguments();
4107 if (beforeArgs == termArgs)
4110 if (hasDuplicates(termArgs))
4113 auto mapping = getArgsMapping(beforeArgs, termArgs);
4124 auto oldAfter = loop.getAfterBody();
4128 newResultTypes[
j] = loop.getResult(i).getType();
4130 auto newLoop = rewriter.
create<WhileOp>(
4131 loop.getLoc(), newResultTypes, loop.getInits(),
4133 auto newBefore = newLoop.getBeforeBody();
4134 auto newAfter = newLoop.getAfterBody();
4139 newResults[i] = newLoop.getResult(
j);
4140 newAfterArgs[i] = newAfter->getArgument(
j);
4144 newBefore->getArguments());
4156 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4157 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4158 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4159 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4173 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4176 caseValues.push_back(value);
4185 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4187 p <<
"case " << value <<
' ';
4193 if (getCases().size() != getCaseRegions().size()) {
4194 return emitOpError(
"has ")
4195 << getCaseRegions().size() <<
" case regions but "
4196 << getCases().size() <<
" case values";
4200 for (int64_t value : getCases())
4201 if (!valueSet.insert(value).second)
4202 return emitOpError(
"has duplicate case value: ") << value;
4203 auto verifyRegion = [&](
Region ®ion,
const Twine &name) -> LogicalResult {
4204 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4206 return emitOpError(
"expected region to end with scf.yield, but got ")
4209 if (yield.getNumOperands() != getNumResults()) {
4210 return (emitOpError(
"expected each region to return ")
4211 << getNumResults() <<
" values, but " << name <<
" returns "
4212 << yield.getNumOperands())
4213 .attachNote(yield.getLoc())
4214 <<
"see yield operation here";
4216 for (
auto [idx, result, operand] :
4217 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4218 yield.getOperandTypes())) {
4219 if (result == operand)
4221 return (emitOpError(
"expected result #")
4222 << idx <<
" of each region to be " << result)
4223 .attachNote(yield.getLoc())
4224 << name <<
" returns " << operand <<
" here";
4229 if (failed(verifyRegion(getDefaultRegion(),
"default region")))
4232 if (failed(verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4238 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4240 Block &scf::IndexSwitchOp::getDefaultBlock() {
4241 return getDefaultRegion().
front();
4244 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4245 assert(idx < getNumCases() &&
"case index out-of-bounds");
4246 return getCaseRegions()[idx].front();
4249 void IndexSwitchOp::getSuccessorRegions(
4253 successors.emplace_back(getResults());
4257 llvm::copy(getRegions(), std::back_inserter(successors));
4260 void IndexSwitchOp::getEntrySuccessorRegions(
4263 FoldAdaptor adaptor(operands, *
this);
4266 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4268 llvm::copy(getRegions(), std::back_inserter(successors));
4274 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4275 if (caseValue == arg.getInt()) {
4276 successors.emplace_back(&caseRegion);
4280 successors.emplace_back(&getDefaultRegion());
4283 void IndexSwitchOp::getRegionInvocationBounds(
4285 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4286 if (!operandValue) {
4292 unsigned liveIndex = getNumRegions() - 1;
4293 const auto *it = llvm::find(getCases(), operandValue.getInt());
4294 if (it != getCases().end())
4295 liveIndex = std::distance(getCases().begin(), it);
4296 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4297 bounds.emplace_back(0, i == liveIndex);
4308 if (!maybeCst.has_value())
4310 int64_t cst = *maybeCst;
4311 int64_t caseIdx, e = op.getNumCases();
4312 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4313 if (cst == op.getCases()[caseIdx])
4317 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4318 : op.getDefaultRegion();
4319 Block &source = r.front();
4342 #define GET_OP_CLASSES
4343 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of a AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.