25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 Block *prevBlock = op->getBlock();
242 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 for (
Block &blk : op.getRegion()) {
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
256 for (
auto res : op.getResults())
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : initArgs)
327 for (
Value v : initArgs)
333 if (initArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
396 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
504 for (
auto [iterArg, type] :
505 llvm::zip_equal(llvm::drop_begin(regionArgs), result.
types))
512 ForOp::ensureTerminator(*body, builder, result.
location);
521 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
522 operands, result.
types)) {
523 Type type = std::get<2>(argOperandType);
524 std::get<0>(argOperandType).type = type;
541 return getBody()->getArguments().drop_front(getNumInductionVars());
545 return getInitArgsMutable();
548 FailureOr<LoopLikeOpInterface>
549 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
551 bool replaceInitOperandUsesInLoop,
556 auto inits = llvm::to_vector(getInitArgs());
557 inits.append(newInitOperands.begin(), newInitOperands.end());
558 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
564 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
566 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
571 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
572 assert(newInitOperands.size() == newYieldedValues.size() &&
573 "expected as many new yield values as new iter operands");
575 yieldOp.getResultsMutable().append(newYieldedValues);
581 newLoop.getBody()->getArguments().take_front(
582 getBody()->getNumArguments()));
584 if (replaceInitOperandUsesInLoop) {
587 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
598 newLoop->getResults().take_front(getNumResults()));
599 return cast<LoopLikeOpInterface>(newLoop.getOperation());
603 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
606 assert(ivArg.getOwner() &&
"unlinked block argument");
607 auto *containingOp = ivArg.getOwner()->getParentOp();
608 return dyn_cast_or_null<ForOp>(containingOp);
612 return getInitArgs();
629 for (
auto [lb, ub, step] :
630 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
632 if (!tripCount.has_value() || *tripCount != 1)
641 return getBody()->getArguments().drop_front(getRank());
645 return getOutputsMutable();
651 scf::InParallelOp terminator = forallOp.getTerminator();
656 bbArgReplacements.append(forallOp.getOutputs().begin(),
657 forallOp.getOutputs().end());
661 forallOp->getIterator(), bbArgReplacements);
666 results.reserve(forallOp.getResults().size());
667 for (
auto &yieldingOp : terminator.getYieldingOps()) {
668 auto parallelInsertSliceOp =
669 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
671 Value dst = parallelInsertSliceOp.getDest();
672 Value src = parallelInsertSliceOp.getSource();
673 if (llvm::isa<TensorType>(src.
getType())) {
674 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
675 forallOp.getLoc(), dst.
getType(), src, dst,
676 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
677 parallelInsertSliceOp.getStrides(),
678 parallelInsertSliceOp.getStaticOffsets(),
679 parallelInsertSliceOp.getStaticSizes(),
680 parallelInsertSliceOp.getStaticStrides()));
682 llvm_unreachable(
"unsupported terminator");
697 assert(lbs.size() == ubs.size() &&
698 "expected the same number of lower and upper bounds");
699 assert(lbs.size() == steps.size() &&
700 "expected the same number of lower bounds and steps");
705 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
707 assert(results.size() == iterArgs.size() &&
708 "loop nest body must return as many values as loop has iteration "
710 return LoopNest{{}, std::move(results)};
718 loops.reserve(lbs.size());
719 ivs.reserve(lbs.size());
722 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
723 auto loop = builder.
create<scf::ForOp>(
724 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
730 currentIterArgs = args;
731 currentLoc = nestedLoc;
737 loops.push_back(loop);
741 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
743 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
750 ? bodyBuilder(builder, currentLoc, ivs,
751 loops.back().getRegionIterArgs())
753 assert(results.size() == iterArgs.size() &&
754 "loop nest body must return as many values as loop has iteration "
757 builder.
create<scf::YieldOp>(loc, results);
761 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
762 return LoopNest{std::move(loops), std::move(nestResults)};
770 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
775 bodyBuilder(nestedBuilder, nestedLoc, ivs);
784 assert(operand.
getOwner() == forOp);
789 "expected an iter OpOperand");
791 "Expected a different type");
793 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
795 newIterOperands.push_back(replacement);
798 newIterOperands.push_back(opOperand.get());
802 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
803 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
804 forOp.getStep(), newIterOperands);
805 newForOp->
setAttrs(forOp->getAttrs());
806 Block &newBlock = newForOp.getRegion().
front();
814 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
816 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
817 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
821 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
824 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
827 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
828 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
829 clonedYieldOp.getOperand(yieldIdx));
831 newYieldOperands[yieldIdx] = castOut;
832 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
833 rewriter.
eraseOp(clonedYieldOp);
838 newResults[yieldIdx] =
839 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
859 LogicalResult matchAndRewrite(scf::ForOp forOp,
861 bool canonicalize =
false;
868 int64_t numResults = forOp.getNumResults();
870 keepMask.reserve(numResults);
873 newBlockTransferArgs.reserve(1 + numResults);
874 newBlockTransferArgs.push_back(
Value());
875 newIterArgs.reserve(forOp.getInitArgs().size());
876 newYieldValues.reserve(numResults);
877 newResultValues.reserve(numResults);
879 for (
auto [init, arg, result, yielded] :
880 llvm::zip(forOp.getInitArgs(),
881 forOp.getRegionIterArgs(),
883 forOp.getYieldedValues()
890 bool forwarded = (arg == yielded) || (init == yielded) ||
891 (arg.use_empty() && result.use_empty());
894 keepMask.push_back(
false);
895 newBlockTransferArgs.push_back(init);
896 newResultValues.push_back(init);
902 if (
auto it = initYieldToArg.find({init, yielded});
903 it != initYieldToArg.end()) {
905 keepMask.push_back(
false);
906 auto [sameArg, sameResult] = it->second;
910 newBlockTransferArgs.push_back(init);
911 newResultValues.push_back(init);
916 initYieldToArg.insert({{init, yielded}, {arg, result}});
917 keepMask.push_back(
true);
918 newIterArgs.push_back(init);
919 newYieldValues.push_back(yielded);
920 newBlockTransferArgs.push_back(
Value());
921 newResultValues.push_back(
Value());
927 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
928 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
929 forOp.getStep(), newIterArgs);
930 newForOp->
setAttrs(forOp->getAttrs());
931 Block &newBlock = newForOp.getRegion().
front();
935 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
937 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
938 Value &newResultVal = newResultValues[idx];
939 assert((blockTransferArg && newResultVal) ||
940 (!blockTransferArg && !newResultVal));
941 if (!blockTransferArg) {
942 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
943 newResultVal = newForOp.getResult(collapsedIdx++);
949 "unexpected argument size mismatch");
954 if (newIterArgs.empty()) {
955 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
958 rewriter.
replaceOp(forOp, newResultValues);
963 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
967 filteredOperands.reserve(newResultValues.size());
968 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
970 filteredOperands.push_back(mergedTerminator.getOperand(idx));
971 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
975 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
976 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
977 cloneFilteredTerminator(mergedYieldOp);
978 rewriter.
eraseOp(mergedYieldOp);
979 rewriter.
replaceOp(forOp, newResultValues);
987 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
988 IntegerAttr clb, cub;
990 llvm::APInt lbValue = clb.getValue();
991 llvm::APInt ubValue = cub.getValue();
992 return (ubValue - lbValue).getSExtValue();
1001 return diff.getSExtValue();
1002 return std::nullopt;
1011 LogicalResult matchAndRewrite(ForOp op,
1015 if (op.getLowerBound() == op.getUpperBound()) {
1016 rewriter.
replaceOp(op, op.getInitArgs());
1020 std::optional<int64_t> diff =
1021 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1027 rewriter.
replaceOp(op, op.getInitArgs());
1031 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1032 if (!maybeStepValue)
1037 llvm::APInt stepValue = *maybeStepValue;
1038 if (stepValue.sge(*diff)) {
1040 blockArgs.reserve(op.getInitArgs().size() + 1);
1041 blockArgs.push_back(op.getLowerBound());
1042 llvm::append_range(blockArgs, op.getInitArgs());
1049 if (!llvm::hasSingleElement(block))
1053 if (llvm::any_of(op.getYieldedValues(),
1054 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1056 rewriter.
replaceOp(op, op.getYieldedValues());
1090 LogicalResult matchAndRewrite(ForOp op,
1092 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1093 OpOperand &iterOpOperand = std::get<0>(it);
1095 if (!incomingCast ||
1096 incomingCast.getSource().getType() == incomingCast.getType())
1101 incomingCast.getDest().getType(),
1102 incomingCast.getSource().getType()))
1104 if (!std::get<1>(it).hasOneUse())
1110 rewriter, op, iterOpOperand, incomingCast.getSource(),
1112 return b.create<tensor::CastOp>(loc, type, source);
1124 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1128 std::optional<APInt> ForOp::getConstantStep() {
1131 return step.getValue();
1135 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1136 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1142 if (
auto constantStep = getConstantStep())
1143 if (*constantStep == 1)
1156 unsigned numLoops = getRank();
1158 if (getNumResults() != getOutputs().size())
1159 return emitOpError(
"produces ")
1160 << getNumResults() <<
" results, but has only "
1161 << getOutputs().size() <<
" outputs";
1164 auto *body = getBody();
1166 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1167 for (int64_t i = 0; i < numLoops; ++i)
1169 return emitOpError(
"expects ")
1170 << i <<
"-th block argument to be an index";
1171 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1173 return emitOpError(
"type mismatch between ")
1174 << i <<
"-th output and corresponding block argument";
1175 if (getMapping().has_value() && !getMapping()->empty()) {
1176 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1177 return emitOpError() <<
"mapping attribute size must match op rank";
1178 for (
auto map : getMapping()->getValue()) {
1179 if (!isa<DeviceMappingAttrInterface>(map))
1180 return emitOpError()
1188 getStaticLowerBound(),
1189 getDynamicLowerBound())))
1192 getStaticUpperBound(),
1193 getDynamicUpperBound())))
1196 getStaticStep(), getDynamicStep())))
1204 p <<
" (" << getInductionVars();
1205 if (isNormalized()) {
1226 if (!getRegionOutArgs().empty())
1227 p <<
"-> (" << getResultTypes() <<
") ";
1228 p.printRegion(getRegion(),
1230 getNumResults() > 0);
1231 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1232 getStaticLowerBoundAttrName(),
1233 getStaticUpperBoundAttrName(),
1234 getStaticStepAttrName()});
1239 auto indexType = b.getIndexType();
1259 unsigned numLoops = ivs.size();
1294 if (outOperands.size() != result.
types.size())
1296 "mismatch between out operands and types");
1306 std::unique_ptr<Region> region = std::make_unique<Region>();
1307 for (
auto &iv : ivs) {
1308 iv.type = b.getIndexType();
1309 regionArgs.push_back(iv);
1312 auto &out = it.value();
1313 out.type = result.
types[it.index()];
1314 regionArgs.push_back(out);
1320 ForallOp::ensureTerminator(*region, b, result.
location);
1332 {static_cast<int32_t>(dynamicLbs.size()),
1333 static_cast<int32_t>(dynamicUbs.size()),
1334 static_cast<int32_t>(dynamicSteps.size()),
1335 static_cast<int32_t>(outOperands.size())}));
1340 void ForallOp::build(
1344 std::optional<ArrayAttr> mapping,
1365 "operandSegmentSizes",
1367 static_cast<int32_t>(dynamicUbs.size()),
1368 static_cast<int32_t>(dynamicSteps.size()),
1369 static_cast<int32_t>(outputs.size())}));
1370 if (mapping.has_value()) {
1389 if (!bodyBuilderFn) {
1390 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1397 void ForallOp::build(
1400 std::optional<ArrayAttr> mapping,
1402 unsigned numLoops = ubs.size();
1405 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1409 bool ForallOp::isNormalized() {
1413 return intValue.has_value() && intValue == val;
1416 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1419 InParallelOp ForallOp::getTerminator() {
1420 return cast<InParallelOp>(getBody()->getTerminator());
1425 InParallelOp inParallelOp = getTerminator();
1426 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1427 if (
auto parallelInsertSliceOp =
1428 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1429 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1430 storeOps.push_back(parallelInsertSliceOp);
1436 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1441 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1443 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1447 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1449 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1453 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1459 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1462 assert(tidxArg.getOwner() &&
"unlinked block argument");
1463 auto *containingOp = tidxArg.getOwner()->getParentOp();
1464 return dyn_cast<ForallOp>(containingOp);
1472 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1474 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1478 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1481 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1490 LogicalResult matchAndRewrite(ForallOp op,
1505 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1506 op.setStaticLowerBound(staticLowerBound);
1510 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1511 op.setStaticUpperBound(staticUpperBound);
1514 op.getDynamicStepMutable().assign(dynamicStep);
1515 op.setStaticStep(staticStep);
1517 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1519 {static_cast<int32_t>(dynamicLowerBound.size()),
1520 static_cast<int32_t>(dynamicUpperBound.size()),
1521 static_cast<int32_t>(dynamicStep.size()),
1522 static_cast<int32_t>(op.getNumResults())}));
1604 LogicalResult matchAndRewrite(ForallOp forallOp,
1623 for (
OpResult result : forallOp.getResults()) {
1624 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1625 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1626 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1627 resultToDelete.insert(result);
1629 resultToReplace.push_back(result);
1630 newOuts.push_back(opOperand->
get());
1636 if (resultToDelete.empty())
1644 for (
OpResult result : resultToDelete) {
1645 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1646 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1648 forallOp.getCombiningOps(blockArg);
1649 for (
Operation *combiningOp : combiningOps)
1650 rewriter.
eraseOp(combiningOp);
1655 auto newForallOp = rewriter.
create<scf::ForallOp>(
1656 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1657 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1658 forallOp.getMapping(),
1663 Block *loopBody = forallOp.getBody();
1664 Block *newLoopBody = newForallOp.getBody();
1669 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1676 for (
OpResult result : forallOp.getResults()) {
1677 if (resultToDelete.count(result)) {
1678 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1680 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1683 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1687 for (
auto &&[oldResult, newResult] :
1688 llvm::zip(resultToReplace, newForallOp->getResults()))
1694 for (
OpResult oldResult : resultToDelete)
1696 forallOp.getTiedOpOperand(oldResult)->get());
1701 struct ForallOpSingleOrZeroIterationDimsFolder
1705 LogicalResult matchAndRewrite(ForallOp op,
1708 if (op.getMapping().has_value() && !op.getMapping()->empty())
1716 for (
auto [lb, ub, step, iv] :
1717 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1718 op.getMixedStep(), op.getInductionVars())) {
1720 if (numIterations.has_value()) {
1722 if (*numIterations == 0) {
1723 rewriter.
replaceOp(op, op.getOutputs());
1728 if (*numIterations == 1) {
1733 newMixedLowerBounds.push_back(lb);
1734 newMixedUpperBounds.push_back(ub);
1735 newMixedSteps.push_back(step);
1739 if (newMixedLowerBounds.empty()) {
1745 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1747 op,
"no dimensions have 0 or 1 iterations");
1752 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1753 newMixedUpperBounds, newMixedSteps,
1754 op.getOutputs(), std::nullopt,
nullptr);
1755 newOp.getBodyRegion().getBlocks().clear();
1760 newOp.getStaticLowerBoundAttrName(),
1761 newOp.getStaticUpperBoundAttrName(),
1762 newOp.getStaticStepAttrName()};
1763 for (
const auto &namedAttr : op->getAttrs()) {
1764 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1767 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1771 newOp.getRegion().begin(), mapping);
1772 rewriter.
replaceOp(op, newOp.getResults());
1778 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1781 LogicalResult matchAndRewrite(ForallOp op,
1785 for (
auto [lb, ub, step, iv] :
1786 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1787 op.getMixedStep(), op.getInductionVars())) {
1788 if (iv.getUses().begin() == iv.getUses().end())
1791 if (!numIterations.has_value() || numIterations.value() != 1) {
1802 struct FoldTensorCastOfOutputIntoForallOp
1811 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1813 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1816 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1823 castOp.getSource().getType())) {
1827 tensorCastProducers[en.index()] =
1828 TypeCast{castOp.getSource().getType(), castOp.getType()};
1829 newOutputTensors[en.index()] = castOp.getSource();
1832 if (tensorCastProducers.empty())
1837 auto newForallOp = rewriter.
create<ForallOp>(
1838 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1839 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1841 auto castBlockArgs =
1842 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1843 for (
auto [index, cast] : tensorCastProducers) {
1844 Value &oldTypeBBArg = castBlockArgs[index];
1845 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1846 nestedLoc, cast.dstType, oldTypeBBArg);
1851 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1852 ivsBlockArgs.append(castBlockArgs);
1854 bbArgs.front().getParentBlock(), ivsBlockArgs);
1860 auto terminator = newForallOp.getTerminator();
1861 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1862 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1863 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1864 insertSliceOp.getDestMutable().assign(outputBlockArg);
1870 for (
auto &item : tensorCastProducers) {
1871 Value &oldTypeResult = castResults[item.first];
1872 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1875 rewriter.
replaceOp(forallOp, castResults);
1884 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1885 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1886 ForallOpSingleOrZeroIterationDimsFolder,
1887 ForallOpReplaceConstantInductionVar>(context);
1916 scf::ForallOp forallOp =
1917 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1919 return this->emitOpError(
"expected forall op parent");
1922 for (
Operation &op : getRegion().front().getOperations()) {
1923 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1924 return this->emitOpError(
"expected only ")
1925 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1929 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1931 if (!llvm::is_contained(regionOutArgs, dest))
1932 return op.emitOpError(
"may only insert into an output block argument");
1949 std::unique_ptr<Region> region = std::make_unique<Region>();
1953 if (region->empty())
1963 OpResult InParallelOp::getParentResult(int64_t idx) {
1964 return getOperation()->getParentOp()->getResult(idx);
1968 return llvm::to_vector<4>(
1969 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1971 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1972 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1977 return getRegion().front().getOperations();
1985 assert(a &&
"expected non-empty operation");
1986 assert(b &&
"expected non-empty operation");
1991 if (ifOp->isProperAncestor(b))
1994 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1995 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1997 ifOp = ifOp->getParentOfType<IfOp>();
2005 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2006 IfOp::Adaptor adaptor,
2008 if (adaptor.getRegions().empty())
2010 Region *r = &adaptor.getThenRegion();
2013 Block &b = r->front();
2016 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2019 TypeRange types = yieldOp.getOperandTypes();
2020 llvm::append_range(inferredReturnTypes, types);
2026 return build(builder, result, resultTypes, cond,
false,
2032 bool addElseBlock) {
2033 assert((!addElseBlock || addThenBlock) &&
2034 "must not create else block w/o then block");
2049 bool withElseRegion) {
2050 build(builder, result,
TypeRange{}, cond, withElseRegion);
2062 if (resultTypes.empty())
2063 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2067 if (withElseRegion) {
2069 if (resultTypes.empty())
2070 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2077 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2084 thenBuilder(builder, result.
location);
2090 elseBuilder(builder, result.
location);
2097 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2099 inferredReturnTypes))) {
2100 result.
addTypes(inferredReturnTypes);
2105 if (getNumResults() != 0 && getElseRegion().empty())
2106 return emitOpError(
"must have an else block if defining values");
2144 bool printBlockTerminators =
false;
2146 p <<
" " << getCondition();
2147 if (!getResults().empty()) {
2148 p <<
" -> (" << getResultTypes() <<
")";
2150 printBlockTerminators =
true;
2155 printBlockTerminators);
2158 auto &elseRegion = getElseRegion();
2159 if (!elseRegion.
empty()) {
2163 printBlockTerminators);
2180 Region *elseRegion = &this->getElseRegion();
2181 if (elseRegion->
empty())
2189 FoldAdaptor adaptor(operands, *
this);
2190 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2191 if (!boolAttr || boolAttr.getValue())
2192 regions.emplace_back(&getThenRegion());
2195 if (!boolAttr || !boolAttr.getValue()) {
2196 if (!getElseRegion().empty())
2197 regions.emplace_back(&getElseRegion());
2199 regions.emplace_back(getResults());
2203 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2206 if (getElseRegion().empty())
2209 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2216 getConditionMutable().assign(xorStmt.getLhs());
2220 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2221 getElseRegion().getBlocks());
2222 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2223 getThenRegion().getBlocks(), thenBlock);
2227 void IfOp::getRegionInvocationBounds(
2230 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2233 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2234 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2237 invocationBounds.assign(2, {0, 1});
2253 llvm::transform(usedResults, std::back_inserter(usedOperands),
2258 [&]() { yieldOp->setOperands(usedOperands); });
2261 LogicalResult matchAndRewrite(IfOp op,
2265 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2266 [](
OpResult result) { return !result.use_empty(); });
2269 if (usedResults.size() == op.getNumResults())
2274 llvm::transform(usedResults, std::back_inserter(newTypes),
2279 rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2285 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2286 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2291 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2300 LogicalResult matchAndRewrite(IfOp op,
2308 else if (!op.getElseRegion().empty())
2322 LogicalResult matchAndRewrite(IfOp op,
2324 if (op->getNumResults() == 0)
2327 auto cond = op.getCondition();
2328 auto thenYieldArgs = op.thenYield().getOperands();
2329 auto elseYieldArgs = op.elseYield().getOperands();
2332 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2333 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2334 &op.getElseRegion() == falseVal.getParentRegion())
2335 nonHoistable.push_back(trueVal.getType());
2339 if (nonHoistable.size() == op->getNumResults())
2342 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond,
2344 if (replacement.thenBlock())
2345 rewriter.
eraseBlock(replacement.thenBlock());
2346 replacement.getThenRegion().takeBody(op.getThenRegion());
2347 replacement.getElseRegion().takeBody(op.getElseRegion());
2350 assert(thenYieldArgs.size() == results.size());
2351 assert(elseYieldArgs.size() == results.size());
2356 for (
const auto &it :
2358 Value trueVal = std::get<0>(it.value());
2359 Value falseVal = std::get<1>(it.value());
2362 results[it.index()] = replacement.getResult(trueYields.size());
2363 trueYields.push_back(trueVal);
2364 falseYields.push_back(falseVal);
2365 }
else if (trueVal == falseVal)
2366 results[it.index()] = trueVal;
2368 results[it.index()] = rewriter.
create<arith::SelectOp>(
2369 op.getLoc(), cond, trueVal, falseVal);
2399 LogicalResult matchAndRewrite(IfOp op,
2411 Value constantTrue =
nullptr;
2412 Value constantFalse =
nullptr;
2415 llvm::make_early_inc_range(op.getCondition().getUses())) {
2420 constantTrue = rewriter.
create<arith::ConstantOp>(
2424 [&]() { use.
set(constantTrue); });
2425 }
else if (op.getElseRegion().isAncestor(
2430 constantFalse = rewriter.
create<arith::ConstantOp>(
2434 [&]() { use.
set(constantFalse); });
2478 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2481 LogicalResult matchAndRewrite(IfOp op,
2484 if (op.getNumResults() == 0)
2488 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2490 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2493 op.getOperation()->getIterator());
2496 for (
auto [trueResult, falseResult, opResult] :
2497 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2499 if (trueResult == falseResult) {
2500 if (!opResult.use_empty()) {
2501 opResult.replaceAllUsesWith(trueResult);
2512 bool trueVal = trueYield.
getValue();
2513 bool falseVal = falseYield.
getValue();
2514 if (!trueVal && falseVal) {
2515 if (!opResult.use_empty()) {
2516 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2518 op.getLoc(), op.getCondition(),
2528 if (trueVal && !falseVal) {
2529 if (!opResult.use_empty()) {
2530 opResult.replaceAllUsesWith(op.getCondition());
2563 LogicalResult matchAndRewrite(IfOp nextIf,
2565 Block *parent = nextIf->getBlock();
2566 if (nextIf == &parent->
front())
2569 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2577 Block *nextThen =
nullptr;
2578 Block *nextElse =
nullptr;
2579 if (nextIf.getCondition() == prevIf.getCondition()) {
2580 nextThen = nextIf.thenBlock();
2581 if (!nextIf.getElseRegion().empty())
2582 nextElse = nextIf.elseBlock();
2584 if (arith::XOrIOp notv =
2585 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2586 if (notv.getLhs() == prevIf.getCondition() &&
2588 nextElse = nextIf.thenBlock();
2589 if (!nextIf.getElseRegion().empty())
2590 nextThen = nextIf.elseBlock();
2593 if (arith::XOrIOp notv =
2594 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2595 if (notv.getLhs() == nextIf.getCondition() &&
2597 nextElse = nextIf.thenBlock();
2598 if (!nextIf.getElseRegion().empty())
2599 nextThen = nextIf.elseBlock();
2603 if (!nextThen && !nextElse)
2607 if (!prevIf.getElseRegion().empty())
2608 prevElseYielded = prevIf.elseYield().getOperands();
2611 for (
auto it : llvm::zip(prevIf.getResults(),
2612 prevIf.thenYield().getOperands(), prevElseYielded))
2614 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2618 use.
set(std::get<1>(it));
2623 use.
set(std::get<2>(it));
2629 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2631 IfOp combinedIf = rewriter.
create<IfOp>(
2632 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2633 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2636 combinedIf.getThenRegion(),
2637 combinedIf.getThenRegion().begin());
2640 YieldOp thenYield = combinedIf.thenYield();
2641 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2642 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2646 llvm::append_range(mergedYields, thenYield2.getOperands());
2647 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2653 combinedIf.getElseRegion(),
2654 combinedIf.getElseRegion().begin());
2657 if (combinedIf.getElseRegion().empty()) {
2659 combinedIf.getElseRegion(),
2660 combinedIf.getElseRegion().
begin());
2662 YieldOp elseYield = combinedIf.elseYield();
2663 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2664 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2669 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2671 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2680 if (pair.index() < prevIf.getNumResults())
2681 prevValues.push_back(pair.value());
2683 nextValues.push_back(pair.value());
2695 LogicalResult matchAndRewrite(IfOp ifOp,
2698 if (ifOp.getNumResults())
2700 Block *elseBlock = ifOp.elseBlock();
2701 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2705 newIfOp.getThenRegion().begin());
2730 LogicalResult matchAndRewrite(IfOp op,
2732 auto nestedOps = op.thenBlock()->without_terminator();
2734 if (!llvm::hasSingleElement(nestedOps))
2738 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2741 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2745 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2751 llvm::append_range(elseYield, op.elseYield().getOperands());
2765 if (tup.value().getDefiningOp() == nestedIf) {
2766 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2767 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2768 elseYield[tup.index()]) {
2773 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2786 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2789 elseYieldsToUpgradeToSelect.push_back(tup.index());
2793 Value newCondition = rewriter.
create<arith::AndIOp>(
2794 loc, op.getCondition(), nestedIf.getCondition());
2795 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2799 llvm::append_range(results, newIf.getResults());
2802 for (
auto idx : elseYieldsToUpgradeToSelect)
2803 results[idx] = rewriter.
create<arith::SelectOp>(
2804 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2806 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2809 if (!elseYield.empty()) {
2812 rewriter.
create<YieldOp>(loc, elseYield);
2823 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2824 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2825 RemoveStaticCondition, RemoveUnusedResults,
2826 ReplaceIfYieldWithConditionOrValue>(context);
2829 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2830 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2831 Block *IfOp::elseBlock() {
2832 Region &r = getElseRegion();
2837 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2843 void ParallelOp::build(
2853 ParallelOp::getOperandSegmentSizeAttr(),
2855 static_cast<int32_t>(upperBounds.size()),
2856 static_cast<int32_t>(steps.size()),
2857 static_cast<int32_t>(initVals.size())}));
2861 unsigned numIVs = steps.size();
2867 if (bodyBuilderFn) {
2869 bodyBuilderFn(builder, result.
location,
2874 if (initVals.empty())
2875 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2878 void ParallelOp::build(
2885 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2888 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2892 wrapper = wrappedBuilderFn;
2894 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2903 if (stepValues.empty())
2905 "needs at least one tuple element for lowerBound, upperBound and step");
2908 for (
Value stepValue : stepValues)
2911 return emitOpError(
"constant step operand must be positive");
2915 Block *body = getBody();
2917 return emitOpError() <<
"expects the same number of induction variables: "
2919 <<
" as bound and step values: " << stepValues.size();
2921 if (!arg.getType().isIndex())
2923 "expects arguments for the induction variable to be of index type");
2926 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2927 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2932 auto resultsSize = getResults().size();
2933 auto reductionsSize = reduceOp.getReductions().size();
2934 auto initValsSize = getInitVals().size();
2935 if (resultsSize != reductionsSize)
2936 return emitOpError() <<
"expects number of results: " << resultsSize
2937 <<
" to be the same as number of reductions: "
2939 if (resultsSize != initValsSize)
2940 return emitOpError() <<
"expects number of results: " << resultsSize
2941 <<
" to be the same as number of initial values: "
2945 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2946 auto resultType = getOperation()->getResult(i).getType();
2947 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2948 if (resultType != reductionOperandType)
2949 return reduceOp.emitOpError()
2950 <<
"expects type of " << i
2951 <<
"-th reduction operand: " << reductionOperandType
2952 <<
" to be the same as the " << i
2953 <<
"-th result type: " << resultType;
3001 for (
auto &iv : ivs)
3008 ParallelOp::getOperandSegmentSizeAttr(),
3010 static_cast<int32_t>(upper.size()),
3011 static_cast<int32_t>(steps.size()),
3012 static_cast<int32_t>(initVals.size())}));
3021 ParallelOp::ensureTerminator(*body, builder, result.
location);
3026 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3027 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3028 if (!getInitVals().empty())
3029 p <<
" init (" << getInitVals() <<
")";
3034 (*this)->getAttrs(),
3035 ParallelOp::getOperandSegmentSizeAttr());
3040 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3044 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3048 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3052 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3057 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3059 return ParallelOp();
3060 assert(ivArg.getOwner() &&
"unlinked block argument");
3061 auto *containingOp = ivArg.getOwner()->getParentOp();
3062 return dyn_cast<ParallelOp>(containingOp);
3067 struct ParallelOpSingleOrZeroIterationDimsFolder
3071 LogicalResult matchAndRewrite(ParallelOp op,
3078 for (
auto [lb, ub, step, iv] :
3079 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3080 op.getInductionVars())) {
3082 if (numIterations.has_value()) {
3084 if (*numIterations == 0) {
3085 rewriter.
replaceOp(op, op.getInitVals());
3090 if (*numIterations == 1) {
3095 newLowerBounds.push_back(lb);
3096 newUpperBounds.push_back(ub);
3097 newSteps.push_back(step);
3100 if (newLowerBounds.size() == op.getLowerBound().size())
3103 if (newLowerBounds.empty()) {
3107 results.reserve(op.getInitVals().size());
3108 for (
auto &bodyOp : op.getBody()->without_terminator())
3109 rewriter.
clone(bodyOp, mapping);
3110 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3111 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3112 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3113 auto initValIndex = results.size();
3114 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3118 rewriter.
clone(reduceBodyOp, mapping);
3121 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3122 results.push_back(result);
3130 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3131 newSteps, op.getInitVals(),
nullptr);
3137 newOp.getRegion().begin(), mapping);
3138 rewriter.
replaceOp(op, newOp.getResults());
3146 LogicalResult matchAndRewrite(ParallelOp op,
3148 Block &outerBody = *op.getBody();
3152 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3157 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3158 llvm::is_contained(innerOp.getUpperBound(), val) ||
3159 llvm::is_contained(innerOp.getStep(), val))
3163 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3168 Block &innerBody = *innerOp.getBody();
3169 assert(iterVals.size() ==
3177 builder.
clone(op, mapping);
3180 auto concatValues = [](
const auto &first,
const auto &second) {
3182 ret.reserve(first.size() + second.size());
3183 ret.assign(first.begin(), first.end());
3184 ret.append(second.begin(), second.end());
3188 auto newLowerBounds =
3189 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3190 auto newUpperBounds =
3191 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3192 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3195 newSteps, std::nullopt,
3206 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3215 void ParallelOp::getSuccessorRegions(
3233 for (
Value v : operands) {
3242 LogicalResult ReduceOp::verifyRegions() {
3245 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3246 auto type = getOperands()[i].getType();
3249 return emitOpError() << i <<
"-th reduction has an empty body";
3252 return arg.getType() != type;
3254 return emitOpError() <<
"expected two block arguments with type " << type
3255 <<
" in the " << i <<
"-th reduction region";
3259 return emitOpError(
"reduction bodies must be terminated with an "
3260 "'scf.reduce.return' op");
3279 Block *reductionBody = getOperation()->getBlock();
3281 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3283 if (expectedResultType != getResult().
getType())
3284 return emitOpError() <<
"must have type " << expectedResultType
3285 <<
" (the type of the reduction inputs)";
3295 ValueRange inits, BodyBuilderFn beforeBuilder,
3296 BodyBuilderFn afterBuilder) {
3304 beforeArgLocs.reserve(inits.size());
3305 for (
Value operand : inits) {
3306 beforeArgLocs.push_back(operand.getLoc());
3311 inits.getTypes(), beforeArgLocs);
3320 resultTypes, afterArgLocs);
3326 ConditionOp WhileOp::getConditionOp() {
3327 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3330 YieldOp WhileOp::getYieldOp() {
3331 return cast<YieldOp>(getAfterBody()->getTerminator());
3334 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3335 return getYieldOp().getResultsMutable();
3339 return getBeforeBody()->getArguments();
3343 return getAfterBody()->getArguments();
3347 return getBeforeArguments();
3351 assert(point == getBefore() &&
3352 "WhileOp is expected to branch only to the first region");
3360 regions.emplace_back(&getBefore(), getBefore().getArguments());
3364 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3365 "there are only two regions in a WhileOp");
3367 if (point == getAfter()) {
3368 regions.emplace_back(&getBefore(), getBefore().getArguments());
3372 regions.emplace_back(getResults());
3373 regions.emplace_back(&getAfter(), getAfter().getArguments());
3377 return {&getBefore(), &getAfter()};
3398 FunctionType functionType;
3403 result.
addTypes(functionType.getResults());
3405 if (functionType.getNumInputs() != operands.size()) {
3407 <<
"expected as many input types as operands "
3408 <<
"(expected " << operands.size() <<
" got "
3409 << functionType.getNumInputs() <<
")";
3419 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3420 regionArgs[i].type = functionType.getInput(i);
3422 return failure(parser.
parseRegion(*before, regionArgs) ||
3442 template <
typename OpTy>
3445 if (left.size() != right.size())
3446 return op.emitOpError(
"expects the same number of ") << message;
3448 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3449 if (left[i] != right[i]) {
3452 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3453 <<
" and " << right[i];
3462 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3464 "expects the 'before' region to terminate with 'scf.condition'");
3465 if (!beforeTerminator)
3468 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3470 "expects the 'after' region to terminate with 'scf.yield'");
3471 return success(afterTerminator !=
nullptr);
3497 LogicalResult matchAndRewrite(WhileOp op,
3499 auto term = op.getConditionOp();
3503 Value constantTrue =
nullptr;
3505 bool replaced =
false;
3506 for (
auto yieldedAndBlockArgs :
3507 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3508 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3509 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3511 constantTrue = rewriter.
create<arith::ConstantOp>(
3512 op.getLoc(), term.getCondition().getType(),
3521 return success(replaced);
3573 struct RemoveLoopInvariantArgsFromBeforeBlock
3577 LogicalResult matchAndRewrite(WhileOp op,
3579 Block &afterBlock = *op.getAfterBody();
3581 ConditionOp condOp = op.getConditionOp();
3586 bool canSimplify =
false;
3587 for (
const auto &it :
3589 auto index =
static_cast<unsigned>(it.index());
3590 auto [initVal, yieldOpArg] = it.value();
3593 if (yieldOpArg == initVal) {
3602 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3603 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3604 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3605 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3618 for (
const auto &it :
3620 auto index =
static_cast<unsigned>(it.index());
3621 auto [initVal, yieldOpArg] = it.value();
3625 if (yieldOpArg == initVal) {
3626 beforeBlockInitValMap.insert({index, initVal});
3634 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3635 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3636 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3637 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3638 beforeBlockInitValMap.insert({index, initVal});
3643 newInitArgs.emplace_back(initVal);
3644 newYieldOpArgs.emplace_back(yieldOpArg);
3645 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3655 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3658 &newWhile.getBefore(), {},
3661 Block &beforeBlock = *op.getBeforeBody();
3668 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3671 if (beforeBlockInitValMap.count(i) != 0)
3672 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3674 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3677 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3679 newWhile.getAfter().begin());
3681 rewriter.
replaceOp(op, newWhile.getResults());
3726 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3729 LogicalResult matchAndRewrite(WhileOp op,
3731 Block &beforeBlock = *op.getBeforeBody();
3732 ConditionOp condOp = op.getConditionOp();
3735 bool canSimplify =
false;
3736 for (
Value condOpArg : condOpArgs) {
3756 auto index =
static_cast<unsigned>(it.index());
3757 Value condOpArg = it.value();
3762 condOpInitValMap.insert({index, condOpArg});
3764 newCondOpArgs.emplace_back(condOpArg);
3765 newAfterBlockType.emplace_back(condOpArg.
getType());
3766 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3777 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3780 Block &newAfterBlock =
3782 newAfterBlockType, newAfterBlockArgLocs);
3784 Block &afterBlock = *op.getAfterBody();
3791 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3792 Value afterBlockArg, result;
3795 if (condOpInitValMap.count(i) != 0) {
3796 afterBlockArg = condOpInitValMap[i];
3797 result = afterBlockArg;
3799 afterBlockArg = newAfterBlock.getArgument(
j);
3800 result = newWhile.getResult(
j);
3803 newAfterBlockArgs[i] = afterBlockArg;
3804 newWhileResults[i] = result;
3807 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3809 newWhile.getBefore().begin());
3811 rewriter.
replaceOp(op, newWhileResults);
3845 LogicalResult matchAndRewrite(WhileOp op,
3847 auto term = op.getConditionOp();
3848 auto afterArgs = op.getAfterArguments();
3849 auto termArgs = term.getArgs();
3856 bool needUpdate =
false;
3857 for (
const auto &it :
3859 auto i =
static_cast<unsigned>(it.index());
3860 Value result = std::get<0>(it.value());
3861 Value afterArg = std::get<1>(it.value());
3862 Value termArg = std::get<2>(it.value());
3866 newResultsIndices.emplace_back(i);
3867 newTermArgs.emplace_back(termArg);
3868 newResultTypes.emplace_back(result.
getType());
3869 newArgLocs.emplace_back(result.
getLoc());
3884 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3887 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3894 newResults[it.value()] = newWhile.getResult(it.index());
3895 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3899 newWhile.getBefore().begin());
3901 Block &afterBlock = *op.getAfterBody();
3902 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3934 LogicalResult matchAndRewrite(scf::WhileOp op,
3936 using namespace scf;
3937 auto cond = op.getConditionOp();
3938 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3942 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3943 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3944 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3947 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3948 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3952 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3955 if (cmp2.getPredicate() == cmp.getPredicate())
3956 samePredicate =
true;
3957 else if (cmp2.getPredicate() ==
3959 samePredicate =
false;
3977 LogicalResult matchAndRewrite(WhileOp op,
3980 if (!llvm::any_of(op.getBeforeArguments(),
3981 [](
Value arg) { return arg.use_empty(); }))
3984 YieldOp yield = op.getYieldOp();
3989 llvm::BitVector argsToErase;
3991 size_t argsCount = op.getBeforeArguments().size();
3992 newYields.reserve(argsCount);
3993 newInits.reserve(argsCount);
3994 argsToErase.reserve(argsCount);
3995 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3996 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3997 if (beforeArg.use_empty()) {
3998 argsToErase.push_back(
true);
4000 argsToErase.push_back(
false);
4001 newYields.emplace_back(yieldValue);
4002 newInits.emplace_back(initValue);
4006 Block &beforeBlock = *op.getBeforeBody();
4007 Block &afterBlock = *op.getAfterBody();
4013 rewriter.
create<WhileOp>(loc, op.getResultTypes(), newInits,
4015 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4016 Block &newAfterBlock = *newWhileOp.getAfterBody();
4022 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4023 newBeforeBlock.getArguments());
4027 rewriter.
replaceOp(op, newWhileOp.getResults());
4036 LogicalResult matchAndRewrite(WhileOp op,
4038 ConditionOp condOp = op.getConditionOp();
4043 if (argsSet.size() == condOpArgs.size())
4046 llvm::SmallDenseMap<Value, unsigned> argsMap;
4048 argsMap.reserve(condOpArgs.size());
4049 newArgs.reserve(condOpArgs.size());
4050 for (
Value arg : condOpArgs) {
4051 if (!argsMap.count(arg)) {
4052 auto pos =
static_cast<unsigned>(argsMap.size());
4053 argsMap.insert({arg, pos});
4054 newArgs.emplace_back(arg);
4061 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4062 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4064 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4065 Block &newAfterBlock = *newWhileOp.getAfterBody();
4070 auto it = argsMap.find(arg);
4071 assert(it != argsMap.end());
4072 auto pos = it->second;
4073 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4074 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4082 Block &beforeBlock = *op.getBeforeBody();
4083 Block &afterBlock = *op.getAfterBody();
4085 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4086 newBeforeBlock.getArguments());
4087 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4095 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4097 if (args1.size() != args2.size())
4098 return std::nullopt;
4102 auto it = llvm::find(args2, arg1);
4103 if (it == args2.end())
4104 return std::nullopt;
4106 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4113 llvm::SmallDenseSet<Value> set;
4114 for (
Value arg : args) {
4115 if (!set.insert(arg).second)
4128 LogicalResult matchAndRewrite(WhileOp loop,
4130 auto oldBefore = loop.getBeforeBody();
4131 ConditionOp oldTerm = loop.getConditionOp();
4132 ValueRange beforeArgs = oldBefore->getArguments();
4134 if (beforeArgs == termArgs)
4137 if (hasDuplicates(termArgs))
4140 auto mapping = getArgsMapping(beforeArgs, termArgs);
4151 auto oldAfter = loop.getAfterBody();
4155 newResultTypes[
j] = loop.getResult(i).getType();
4157 auto newLoop = rewriter.
create<WhileOp>(
4158 loop.getLoc(), newResultTypes, loop.getInits(),
4160 auto newBefore = newLoop.getBeforeBody();
4161 auto newAfter = newLoop.getAfterBody();
4166 newResults[i] = newLoop.getResult(
j);
4167 newAfterArgs[i] = newAfter->getArgument(
j);
4171 newBefore->getArguments());
4183 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4184 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4185 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4186 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4200 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4203 caseValues.push_back(value);
4212 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4214 p <<
"case " << value <<
' ';
4220 if (getCases().size() != getCaseRegions().size()) {
4221 return emitOpError(
"has ")
4222 << getCaseRegions().size() <<
" case regions but "
4223 << getCases().size() <<
" case values";
4227 for (int64_t value : getCases())
4228 if (!valueSet.insert(value).second)
4229 return emitOpError(
"has duplicate case value: ") << value;
4231 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4233 return emitOpError(
"expected region to end with scf.yield, but got ")
4236 if (yield.getNumOperands() != getNumResults()) {
4237 return (emitOpError(
"expected each region to return ")
4238 << getNumResults() <<
" values, but " << name <<
" returns "
4239 << yield.getNumOperands())
4240 .attachNote(yield.getLoc())
4241 <<
"see yield operation here";
4243 for (
auto [idx, result, operand] :
4244 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4245 yield.getOperandTypes())) {
4246 if (result == operand)
4248 return (emitOpError(
"expected result #")
4249 << idx <<
" of each region to be " << result)
4250 .attachNote(yield.getLoc())
4251 << name <<
" returns " << operand <<
" here";
4256 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4259 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4265 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4267 Block &scf::IndexSwitchOp::getDefaultBlock() {
4268 return getDefaultRegion().
front();
4271 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4272 assert(idx < getNumCases() &&
"case index out-of-bounds");
4273 return getCaseRegions()[idx].front();
4276 void IndexSwitchOp::getSuccessorRegions(
4280 successors.emplace_back(getResults());
4284 llvm::copy(getRegions(), std::back_inserter(successors));
4287 void IndexSwitchOp::getEntrySuccessorRegions(
4290 FoldAdaptor adaptor(operands, *
this);
4293 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4295 llvm::copy(getRegions(), std::back_inserter(successors));
4301 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4302 if (caseValue == arg.getInt()) {
4303 successors.emplace_back(&caseRegion);
4307 successors.emplace_back(&getDefaultRegion());
4310 void IndexSwitchOp::getRegionInvocationBounds(
4312 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4313 if (!operandValue) {
4319 unsigned liveIndex = getNumRegions() - 1;
4320 const auto *it = llvm::find(getCases(), operandValue.getInt());
4321 if (it != getCases().end())
4322 liveIndex = std::distance(getCases().begin(), it);
4323 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4324 bounds.emplace_back(0, i == liveIndex);
4335 if (!maybeCst.has_value())
4337 int64_t cst = *maybeCst;
4338 int64_t caseIdx, e = op.getNumCases();
4339 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4340 if (cst == op.getCases()[caseIdx])
4344 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4345 : op.getDefaultRegion();
4346 Block &source = r.front();
4369 #define GET_OP_CLASSES
4370 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that 'values' has as many elements as the number of entries in 'attr' for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.