25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 Block *prevBlock = op->getBlock();
242 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 for (
Block &blk : op.getRegion()) {
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
256 for (
auto res : op.getResults())
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : initArgs)
327 for (
Value v : initArgs)
333 if (initArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
396 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
504 for (
auto [iterArg, type] :
505 llvm::zip_equal(llvm::drop_begin(regionArgs), result.
types))
512 ForOp::ensureTerminator(*body, builder, result.
location);
521 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
522 operands, result.
types)) {
523 Type type = std::get<2>(argOperandType);
524 std::get<0>(argOperandType).type = type;
541 return getBody()->getArguments().drop_front(getNumInductionVars());
545 return getInitArgsMutable();
548 FailureOr<LoopLikeOpInterface>
549 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
551 bool replaceInitOperandUsesInLoop,
556 auto inits = llvm::to_vector(getInitArgs());
557 inits.append(newInitOperands.begin(), newInitOperands.end());
558 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
564 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
566 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
571 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
572 assert(newInitOperands.size() == newYieldedValues.size() &&
573 "expected as many new yield values as new iter operands");
575 yieldOp.getResultsMutable().append(newYieldedValues);
581 newLoop.getBody()->getArguments().take_front(
582 getBody()->getNumArguments()));
584 if (replaceInitOperandUsesInLoop) {
587 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
598 newLoop->getResults().take_front(getNumResults()));
599 return cast<LoopLikeOpInterface>(newLoop.getOperation());
603 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
606 assert(ivArg.getOwner() &&
"unlinked block argument");
607 auto *containingOp = ivArg.getOwner()->getParentOp();
608 return dyn_cast_or_null<ForOp>(containingOp);
612 return getInitArgs();
629 for (
auto [lb, ub, step] :
630 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
632 if (!tripCount.has_value() || *tripCount != 1)
641 return getBody()->getArguments().drop_front(getRank());
645 return getOutputsMutable();
651 scf::InParallelOp terminator = forallOp.getTerminator();
656 bbArgReplacements.append(forallOp.getOutputs().begin(),
657 forallOp.getOutputs().end());
661 forallOp->getIterator(), bbArgReplacements);
666 results.reserve(forallOp.getResults().size());
667 for (
auto &yieldingOp : terminator.getYieldingOps()) {
668 auto parallelInsertSliceOp =
669 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
671 Value dst = parallelInsertSliceOp.getDest();
672 Value src = parallelInsertSliceOp.getSource();
673 if (llvm::isa<TensorType>(src.
getType())) {
674 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
675 forallOp.getLoc(), dst.
getType(), src, dst,
676 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
677 parallelInsertSliceOp.getStrides(),
678 parallelInsertSliceOp.getStaticOffsets(),
679 parallelInsertSliceOp.getStaticSizes(),
680 parallelInsertSliceOp.getStaticStrides()));
682 llvm_unreachable(
"unsupported terminator");
697 assert(lbs.size() == ubs.size() &&
698 "expected the same number of lower and upper bounds");
699 assert(lbs.size() == steps.size() &&
700 "expected the same number of lower bounds and steps");
705 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
707 assert(results.size() == iterArgs.size() &&
708 "loop nest body must return as many values as loop has iteration "
710 return LoopNest{{}, std::move(results)};
718 loops.reserve(lbs.size());
719 ivs.reserve(lbs.size());
722 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
723 auto loop = builder.
create<scf::ForOp>(
724 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
730 currentIterArgs = args;
731 currentLoc = nestedLoc;
737 loops.push_back(loop);
741 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
743 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
750 ? bodyBuilder(builder, currentLoc, ivs,
751 loops.back().getRegionIterArgs())
753 assert(results.size() == iterArgs.size() &&
754 "loop nest body must return as many values as loop has iteration "
757 builder.
create<scf::YieldOp>(loc, results);
761 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
762 return LoopNest{std::move(loops), std::move(nestResults)};
770 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
775 bodyBuilder(nestedBuilder, nestedLoc, ivs);
784 assert(operand.
getOwner() == forOp);
789 "expected an iter OpOperand");
791 "Expected a different type");
793 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
795 newIterOperands.push_back(replacement);
798 newIterOperands.push_back(opOperand.get());
802 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
803 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
804 forOp.getStep(), newIterOperands);
805 newForOp->
setAttrs(forOp->getAttrs());
806 Block &newBlock = newForOp.getRegion().
front();
814 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
816 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
817 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
821 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
824 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
827 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
828 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
829 clonedYieldOp.getOperand(yieldIdx));
831 newYieldOperands[yieldIdx] = castOut;
832 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
833 rewriter.
eraseOp(clonedYieldOp);
838 newResults[yieldIdx] =
839 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
859 LogicalResult matchAndRewrite(scf::ForOp forOp,
861 bool canonicalize =
false;
868 int64_t numResults = forOp.getNumResults();
870 keepMask.reserve(numResults);
873 newBlockTransferArgs.reserve(1 + numResults);
874 newBlockTransferArgs.push_back(
Value());
875 newIterArgs.reserve(forOp.getInitArgs().size());
876 newYieldValues.reserve(numResults);
877 newResultValues.reserve(numResults);
879 for (
auto [init, arg, result, yielded] :
880 llvm::zip(forOp.getInitArgs(),
881 forOp.getRegionIterArgs(),
883 forOp.getYieldedValues()
890 bool forwarded = (arg == yielded) || (init == yielded) ||
891 (arg.use_empty() && result.use_empty());
894 keepMask.push_back(
false);
895 newBlockTransferArgs.push_back(init);
896 newResultValues.push_back(init);
902 if (
auto it = initYieldToArg.find({init, yielded});
903 it != initYieldToArg.end()) {
905 keepMask.push_back(
false);
906 auto [sameArg, sameResult] = it->second;
910 newBlockTransferArgs.push_back(init);
911 newResultValues.push_back(init);
916 initYieldToArg.insert({{init, yielded}, {arg, result}});
917 keepMask.push_back(
true);
918 newIterArgs.push_back(init);
919 newYieldValues.push_back(yielded);
920 newBlockTransferArgs.push_back(
Value());
921 newResultValues.push_back(
Value());
927 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
928 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
929 forOp.getStep(), newIterArgs);
930 newForOp->
setAttrs(forOp->getAttrs());
931 Block &newBlock = newForOp.getRegion().
front();
935 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
937 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
938 Value &newResultVal = newResultValues[idx];
939 assert((blockTransferArg && newResultVal) ||
940 (!blockTransferArg && !newResultVal));
941 if (!blockTransferArg) {
942 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
943 newResultVal = newForOp.getResult(collapsedIdx++);
949 "unexpected argument size mismatch");
954 if (newIterArgs.empty()) {
955 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
958 rewriter.
replaceOp(forOp, newResultValues);
963 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
967 filteredOperands.reserve(newResultValues.size());
968 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
970 filteredOperands.push_back(mergedTerminator.getOperand(idx));
971 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
975 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
976 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
977 cloneFilteredTerminator(mergedYieldOp);
978 rewriter.
eraseOp(mergedYieldOp);
979 rewriter.
replaceOp(forOp, newResultValues);
987 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
988 IntegerAttr clb, cub;
990 llvm::APInt lbValue = clb.getValue();
991 llvm::APInt ubValue = cub.getValue();
992 return (ubValue - lbValue).getSExtValue();
1001 return diff.getSExtValue();
1002 return std::nullopt;
1011 LogicalResult matchAndRewrite(ForOp op,
1015 if (op.getLowerBound() == op.getUpperBound()) {
1016 rewriter.
replaceOp(op, op.getInitArgs());
1020 std::optional<int64_t> diff =
1021 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1027 rewriter.
replaceOp(op, op.getInitArgs());
1031 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1032 if (!maybeStepValue)
1037 llvm::APInt stepValue = *maybeStepValue;
1038 if (stepValue.sge(*diff)) {
1040 blockArgs.reserve(op.getInitArgs().size() + 1);
1041 blockArgs.push_back(op.getLowerBound());
1042 llvm::append_range(blockArgs, op.getInitArgs());
1049 if (!llvm::hasSingleElement(block))
1053 if (llvm::any_of(op.getYieldedValues(),
1054 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1056 rewriter.
replaceOp(op, op.getYieldedValues());
1090 LogicalResult matchAndRewrite(ForOp op,
1092 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1093 OpOperand &iterOpOperand = std::get<0>(it);
1095 if (!incomingCast ||
1096 incomingCast.getSource().getType() == incomingCast.getType())
1101 incomingCast.getDest().getType(),
1102 incomingCast.getSource().getType()))
1104 if (!std::get<1>(it).hasOneUse())
1110 rewriter, op, iterOpOperand, incomingCast.getSource(),
1112 return b.create<tensor::CastOp>(loc, type, source);
1124 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1128 std::optional<APInt> ForOp::getConstantStep() {
1131 return step.getValue();
1135 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1136 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1142 if (
auto constantStep = getConstantStep())
1143 if (*constantStep == 1)
1156 unsigned numLoops = getRank();
1158 if (getNumResults() != getOutputs().size())
1159 return emitOpError(
"produces ")
1160 << getNumResults() <<
" results, but has only "
1161 << getOutputs().size() <<
" outputs";
1164 auto *body = getBody();
1166 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1167 for (int64_t i = 0; i < numLoops; ++i)
1169 return emitOpError(
"expects ")
1170 << i <<
"-th block argument to be an index";
1171 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1173 return emitOpError(
"type mismatch between ")
1174 << i <<
"-th output and corresponding block argument";
1175 if (getMapping().has_value() && !getMapping()->empty()) {
1176 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1177 return emitOpError() <<
"mapping attribute size must match op rank";
1178 for (
auto map : getMapping()->getValue()) {
1179 if (!isa<DeviceMappingAttrInterface>(map))
1180 return emitOpError()
1188 getStaticLowerBound(),
1189 getDynamicLowerBound())))
1192 getStaticUpperBound(),
1193 getDynamicUpperBound())))
1196 getStaticStep(), getDynamicStep())))
1204 p <<
" (" << getInductionVars();
1205 if (isNormalized()) {
1226 if (!getRegionOutArgs().empty())
1227 p <<
"-> (" << getResultTypes() <<
") ";
1228 p.printRegion(getRegion(),
1230 getNumResults() > 0);
1231 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1232 getStaticLowerBoundAttrName(),
1233 getStaticUpperBoundAttrName(),
1234 getStaticStepAttrName()});
1239 auto indexType = b.getIndexType();
1259 unsigned numLoops = ivs.size();
1294 if (outOperands.size() != result.
types.size())
1296 "mismatch between out operands and types");
1306 std::unique_ptr<Region> region = std::make_unique<Region>();
1307 for (
auto &iv : ivs) {
1308 iv.type = b.getIndexType();
1309 regionArgs.push_back(iv);
1312 auto &out = it.value();
1313 out.type = result.
types[it.index()];
1314 regionArgs.push_back(out);
1320 ForallOp::ensureTerminator(*region, b, result.
location);
1332 {static_cast<int32_t>(dynamicLbs.size()),
1333 static_cast<int32_t>(dynamicUbs.size()),
1334 static_cast<int32_t>(dynamicSteps.size()),
1335 static_cast<int32_t>(outOperands.size())}));
1340 void ForallOp::build(
1344 std::optional<ArrayAttr> mapping,
1365 "operandSegmentSizes",
1367 static_cast<int32_t>(dynamicUbs.size()),
1368 static_cast<int32_t>(dynamicSteps.size()),
1369 static_cast<int32_t>(outputs.size())}));
1370 if (mapping.has_value()) {
1389 if (!bodyBuilderFn) {
1390 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1397 void ForallOp::build(
1400 std::optional<ArrayAttr> mapping,
1402 unsigned numLoops = ubs.size();
1405 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1409 bool ForallOp::isNormalized() {
1413 return intValue.has_value() && intValue == val;
1416 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1419 InParallelOp ForallOp::getTerminator() {
1420 return cast<InParallelOp>(getBody()->getTerminator());
1425 InParallelOp inParallelOp = getTerminator();
1426 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1427 if (
auto parallelInsertSliceOp =
1428 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1429 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1430 storeOps.push_back(parallelInsertSliceOp);
1436 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1441 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1443 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1447 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1449 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1453 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1459 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1462 assert(tidxArg.getOwner() &&
"unlinked block argument");
1463 auto *containingOp = tidxArg.getOwner()->getParentOp();
1464 return dyn_cast<ForallOp>(containingOp);
1472 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1474 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1478 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1481 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1490 LogicalResult matchAndRewrite(ForallOp op,
1505 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1506 op.setStaticLowerBound(staticLowerBound);
1510 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1511 op.setStaticUpperBound(staticUpperBound);
1514 op.getDynamicStepMutable().assign(dynamicStep);
1515 op.setStaticStep(staticStep);
1517 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1519 {static_cast<int32_t>(dynamicLowerBound.size()),
1520 static_cast<int32_t>(dynamicUpperBound.size()),
1521 static_cast<int32_t>(dynamicStep.size()),
1522 static_cast<int32_t>(op.getNumResults())}));
1604 LogicalResult matchAndRewrite(ForallOp forallOp,
1623 for (
OpResult result : forallOp.getResults()) {
1624 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1625 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1626 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1627 resultToDelete.insert(result);
1629 resultToReplace.push_back(result);
1630 newOuts.push_back(opOperand->
get());
1636 if (resultToDelete.empty())
1644 for (
OpResult result : resultToDelete) {
1645 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1646 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1648 forallOp.getCombiningOps(blockArg);
1649 for (
Operation *combiningOp : combiningOps)
1650 rewriter.
eraseOp(combiningOp);
1655 auto newForallOp = rewriter.
create<scf::ForallOp>(
1656 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1657 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1658 forallOp.getMapping(),
1663 Block *loopBody = forallOp.getBody();
1664 Block *newLoopBody = newForallOp.getBody();
1669 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1676 for (
OpResult result : forallOp.getResults()) {
1677 if (resultToDelete.count(result)) {
1678 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1680 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1683 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1687 for (
auto &&[oldResult, newResult] :
1688 llvm::zip(resultToReplace, newForallOp->getResults()))
1694 for (
OpResult oldResult : resultToDelete)
1696 forallOp.getTiedOpOperand(oldResult)->get());
1701 struct ForallOpSingleOrZeroIterationDimsFolder
1705 LogicalResult matchAndRewrite(ForallOp op,
1708 if (op.getMapping().has_value() && !op.getMapping()->empty())
1716 for (
auto [lb, ub, step, iv] :
1717 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1718 op.getMixedStep(), op.getInductionVars())) {
1720 if (numIterations.has_value()) {
1722 if (*numIterations == 0) {
1723 rewriter.
replaceOp(op, op.getOutputs());
1728 if (*numIterations == 1) {
1733 newMixedLowerBounds.push_back(lb);
1734 newMixedUpperBounds.push_back(ub);
1735 newMixedSteps.push_back(step);
1739 if (newMixedLowerBounds.empty()) {
1745 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1747 op,
"no dimensions have 0 or 1 iterations");
1752 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1753 newMixedUpperBounds, newMixedSteps,
1754 op.getOutputs(), std::nullopt,
nullptr);
1755 newOp.getBodyRegion().getBlocks().clear();
1760 newOp.getStaticLowerBoundAttrName(),
1761 newOp.getStaticUpperBoundAttrName(),
1762 newOp.getStaticStepAttrName()};
1763 for (
const auto &namedAttr : op->getAttrs()) {
1764 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1767 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1771 newOp.getRegion().begin(), mapping);
1772 rewriter.
replaceOp(op, newOp.getResults());
1778 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1781 LogicalResult matchAndRewrite(ForallOp op,
1785 for (
auto [lb, ub, step, iv] :
1786 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1787 op.getMixedStep(), op.getInductionVars())) {
1788 if (iv.getUses().begin() == iv.getUses().end())
1791 if (!numIterations.has_value() || numIterations.value() != 1) {
1802 struct FoldTensorCastOfOutputIntoForallOp
1811 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1813 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1816 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1823 castOp.getSource().getType())) {
1827 tensorCastProducers[en.index()] =
1828 TypeCast{castOp.getSource().getType(), castOp.getType()};
1829 newOutputTensors[en.index()] = castOp.getSource();
1832 if (tensorCastProducers.empty())
1837 auto newForallOp = rewriter.
create<ForallOp>(
1838 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1839 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1841 auto castBlockArgs =
1842 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1843 for (
auto [index, cast] : tensorCastProducers) {
1844 Value &oldTypeBBArg = castBlockArgs[index];
1845 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1846 nestedLoc, cast.dstType, oldTypeBBArg);
1851 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1852 ivsBlockArgs.append(castBlockArgs);
1854 bbArgs.front().getParentBlock(), ivsBlockArgs);
1860 auto terminator = newForallOp.getTerminator();
1861 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1862 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1863 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1864 insertSliceOp.getDestMutable().assign(outputBlockArg);
1870 for (
auto &item : tensorCastProducers) {
1871 Value &oldTypeResult = castResults[item.first];
1872 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1875 rewriter.
replaceOp(forallOp, castResults);
1884 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1885 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1886 ForallOpSingleOrZeroIterationDimsFolder,
1887 ForallOpReplaceConstantInductionVar>(context);
1916 scf::ForallOp forallOp =
1917 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1919 return this->emitOpError(
"expected forall op parent");
1922 for (
Operation &op : getRegion().front().getOperations()) {
1923 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1924 return this->emitOpError(
"expected only ")
1925 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1929 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1931 if (!llvm::is_contained(regionOutArgs, dest))
1932 return op.emitOpError(
"may only insert into an output block argument");
1949 std::unique_ptr<Region> region = std::make_unique<Region>();
1953 if (region->empty())
1963 OpResult InParallelOp::getParentResult(int64_t idx) {
1964 return getOperation()->getParentOp()->getResult(idx);
1968 return llvm::to_vector<4>(
1969 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1971 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1972 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1977 return getRegion().front().getOperations();
1985 assert(a &&
"expected non-empty operation");
1986 assert(b &&
"expected non-empty operation");
1991 if (ifOp->isProperAncestor(b))
1994 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1995 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1997 ifOp = ifOp->getParentOfType<IfOp>();
2005 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2006 IfOp::Adaptor adaptor,
2008 if (adaptor.getRegions().empty())
2010 Region *r = &adaptor.getThenRegion();
2013 Block &b = r->front();
2016 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2019 TypeRange types = yieldOp.getOperandTypes();
2020 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2027 return build(builder, result, resultTypes, cond,
false,
2033 bool addElseBlock) {
2034 assert((!addElseBlock || addThenBlock) &&
2035 "must not create else block w/o then block");
2050 bool withElseRegion) {
2051 build(builder, result,
TypeRange{}, cond, withElseRegion);
2063 if (resultTypes.empty())
2064 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2068 if (withElseRegion) {
2070 if (resultTypes.empty())
2071 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2078 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2085 thenBuilder(builder, result.
location);
2091 elseBuilder(builder, result.
location);
2098 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2100 inferredReturnTypes))) {
2101 result.
addTypes(inferredReturnTypes);
2106 if (getNumResults() != 0 && getElseRegion().empty())
2107 return emitOpError(
"must have an else block if defining values");
2145 bool printBlockTerminators =
false;
2147 p <<
" " << getCondition();
2148 if (!getResults().empty()) {
2149 p <<
" -> (" << getResultTypes() <<
")";
2151 printBlockTerminators =
true;
2156 printBlockTerminators);
2159 auto &elseRegion = getElseRegion();
2160 if (!elseRegion.
empty()) {
2164 printBlockTerminators);
2181 Region *elseRegion = &this->getElseRegion();
2182 if (elseRegion->
empty())
2190 FoldAdaptor adaptor(operands, *
this);
2191 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2192 if (!boolAttr || boolAttr.getValue())
2193 regions.emplace_back(&getThenRegion());
2196 if (!boolAttr || !boolAttr.getValue()) {
2197 if (!getElseRegion().empty())
2198 regions.emplace_back(&getElseRegion());
2200 regions.emplace_back(getResults());
2204 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2207 if (getElseRegion().empty())
2210 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2217 getConditionMutable().assign(xorStmt.getLhs());
2221 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2222 getElseRegion().getBlocks());
2223 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2224 getThenRegion().getBlocks(), thenBlock);
2228 void IfOp::getRegionInvocationBounds(
2231 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2234 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2235 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2238 invocationBounds.assign(2, {0, 1});
2254 llvm::transform(usedResults, std::back_inserter(usedOperands),
2259 [&]() { yieldOp->setOperands(usedOperands); });
2262 LogicalResult matchAndRewrite(IfOp op,
2266 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2267 [](
OpResult result) { return !result.use_empty(); });
2270 if (usedResults.size() == op.getNumResults())
2275 llvm::transform(usedResults, std::back_inserter(newTypes),
2280 rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2286 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2287 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2292 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2301 LogicalResult matchAndRewrite(IfOp op,
2309 else if (!op.getElseRegion().empty())
2323 LogicalResult matchAndRewrite(IfOp op,
2325 if (op->getNumResults() == 0)
2328 auto cond = op.getCondition();
2329 auto thenYieldArgs = op.thenYield().getOperands();
2330 auto elseYieldArgs = op.elseYield().getOperands();
2333 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2334 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2335 &op.getElseRegion() == falseVal.getParentRegion())
2336 nonHoistable.push_back(trueVal.getType());
2340 if (nonHoistable.size() == op->getNumResults())
2343 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond,
2345 if (replacement.thenBlock())
2346 rewriter.
eraseBlock(replacement.thenBlock());
2347 replacement.getThenRegion().takeBody(op.getThenRegion());
2348 replacement.getElseRegion().takeBody(op.getElseRegion());
2351 assert(thenYieldArgs.size() == results.size());
2352 assert(elseYieldArgs.size() == results.size());
2357 for (
const auto &it :
2359 Value trueVal = std::get<0>(it.value());
2360 Value falseVal = std::get<1>(it.value());
2363 results[it.index()] = replacement.getResult(trueYields.size());
2364 trueYields.push_back(trueVal);
2365 falseYields.push_back(falseVal);
2366 }
else if (trueVal == falseVal)
2367 results[it.index()] = trueVal;
2369 results[it.index()] = rewriter.
create<arith::SelectOp>(
2370 op.getLoc(), cond, trueVal, falseVal);
2400 LogicalResult matchAndRewrite(IfOp op,
2412 Value constantTrue =
nullptr;
2413 Value constantFalse =
nullptr;
2416 llvm::make_early_inc_range(op.getCondition().getUses())) {
2421 constantTrue = rewriter.
create<arith::ConstantOp>(
2425 [&]() { use.
set(constantTrue); });
2426 }
else if (op.getElseRegion().isAncestor(
2431 constantFalse = rewriter.
create<arith::ConstantOp>(
2435 [&]() { use.
set(constantFalse); });
2479 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2482 LogicalResult matchAndRewrite(IfOp op,
2485 if (op.getNumResults() == 0)
2489 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2491 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2494 op.getOperation()->getIterator());
2497 for (
auto [trueResult, falseResult, opResult] :
2498 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2500 if (trueResult == falseResult) {
2501 if (!opResult.use_empty()) {
2502 opResult.replaceAllUsesWith(trueResult);
2513 bool trueVal = trueYield.
getValue();
2514 bool falseVal = falseYield.
getValue();
2515 if (!trueVal && falseVal) {
2516 if (!opResult.use_empty()) {
2517 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2519 op.getLoc(), op.getCondition(),
2529 if (trueVal && !falseVal) {
2530 if (!opResult.use_empty()) {
2531 opResult.replaceAllUsesWith(op.getCondition());
2564 LogicalResult matchAndRewrite(IfOp nextIf,
2566 Block *parent = nextIf->getBlock();
2567 if (nextIf == &parent->
front())
2570 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2578 Block *nextThen =
nullptr;
2579 Block *nextElse =
nullptr;
2580 if (nextIf.getCondition() == prevIf.getCondition()) {
2581 nextThen = nextIf.thenBlock();
2582 if (!nextIf.getElseRegion().empty())
2583 nextElse = nextIf.elseBlock();
2585 if (arith::XOrIOp notv =
2586 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2587 if (notv.getLhs() == prevIf.getCondition() &&
2589 nextElse = nextIf.thenBlock();
2590 if (!nextIf.getElseRegion().empty())
2591 nextThen = nextIf.elseBlock();
2594 if (arith::XOrIOp notv =
2595 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2596 if (notv.getLhs() == nextIf.getCondition() &&
2598 nextElse = nextIf.thenBlock();
2599 if (!nextIf.getElseRegion().empty())
2600 nextThen = nextIf.elseBlock();
2604 if (!nextThen && !nextElse)
2608 if (!prevIf.getElseRegion().empty())
2609 prevElseYielded = prevIf.elseYield().getOperands();
2612 for (
auto it : llvm::zip(prevIf.getResults(),
2613 prevIf.thenYield().getOperands(), prevElseYielded))
2615 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2619 use.
set(std::get<1>(it));
2624 use.
set(std::get<2>(it));
2630 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2632 IfOp combinedIf = rewriter.
create<IfOp>(
2633 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2634 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2637 combinedIf.getThenRegion(),
2638 combinedIf.getThenRegion().begin());
2641 YieldOp thenYield = combinedIf.thenYield();
2642 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2643 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2647 llvm::append_range(mergedYields, thenYield2.getOperands());
2648 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2654 combinedIf.getElseRegion(),
2655 combinedIf.getElseRegion().begin());
2658 if (combinedIf.getElseRegion().empty()) {
2660 combinedIf.getElseRegion(),
2661 combinedIf.getElseRegion().
begin());
2663 YieldOp elseYield = combinedIf.elseYield();
2664 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2665 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2670 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2672 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2681 if (pair.index() < prevIf.getNumResults())
2682 prevValues.push_back(pair.value());
2684 nextValues.push_back(pair.value());
2696 LogicalResult matchAndRewrite(IfOp ifOp,
2699 if (ifOp.getNumResults())
2701 Block *elseBlock = ifOp.elseBlock();
2702 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2706 newIfOp.getThenRegion().begin());
2731 LogicalResult matchAndRewrite(IfOp op,
2733 auto nestedOps = op.thenBlock()->without_terminator();
2735 if (!llvm::hasSingleElement(nestedOps))
2739 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2742 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2746 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2752 llvm::append_range(elseYield, op.elseYield().getOperands());
2766 if (tup.value().getDefiningOp() == nestedIf) {
2767 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2768 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2769 elseYield[tup.index()]) {
2774 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2787 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2790 elseYieldsToUpgradeToSelect.push_back(tup.index());
2794 Value newCondition = rewriter.
create<arith::AndIOp>(
2795 loc, op.getCondition(), nestedIf.getCondition());
2796 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2800 llvm::append_range(results, newIf.getResults());
2803 for (
auto idx : elseYieldsToUpgradeToSelect)
2804 results[idx] = rewriter.
create<arith::SelectOp>(
2805 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2807 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2810 if (!elseYield.empty()) {
2813 rewriter.
create<YieldOp>(loc, elseYield);
2824 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2825 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2826 RemoveStaticCondition, RemoveUnusedResults,
2827 ReplaceIfYieldWithConditionOrValue>(context);
2830 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2831 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2832 Block *IfOp::elseBlock() {
2833 Region &r = getElseRegion();
2838 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2844 void ParallelOp::build(
2854 ParallelOp::getOperandSegmentSizeAttr(),
2856 static_cast<int32_t>(upperBounds.size()),
2857 static_cast<int32_t>(steps.size()),
2858 static_cast<int32_t>(initVals.size())}));
2862 unsigned numIVs = steps.size();
2868 if (bodyBuilderFn) {
2870 bodyBuilderFn(builder, result.
location,
2875 if (initVals.empty())
2876 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2879 void ParallelOp::build(
2886 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2889 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2893 wrapper = wrappedBuilderFn;
2895 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2904 if (stepValues.empty())
2906 "needs at least one tuple element for lowerBound, upperBound and step");
2909 for (
Value stepValue : stepValues)
2912 return emitOpError(
"constant step operand must be positive");
2916 Block *body = getBody();
2918 return emitOpError() <<
"expects the same number of induction variables: "
2920 <<
" as bound and step values: " << stepValues.size();
2922 if (!arg.getType().isIndex())
2924 "expects arguments for the induction variable to be of index type");
2927 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2928 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2933 auto resultsSize = getResults().size();
2934 auto reductionsSize = reduceOp.getReductions().size();
2935 auto initValsSize = getInitVals().size();
2936 if (resultsSize != reductionsSize)
2937 return emitOpError() <<
"expects number of results: " << resultsSize
2938 <<
" to be the same as number of reductions: "
2940 if (resultsSize != initValsSize)
2941 return emitOpError() <<
"expects number of results: " << resultsSize
2942 <<
" to be the same as number of initial values: "
2946 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2947 auto resultType = getOperation()->getResult(i).getType();
2948 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2949 if (resultType != reductionOperandType)
2950 return reduceOp.emitOpError()
2951 <<
"expects type of " << i
2952 <<
"-th reduction operand: " << reductionOperandType
2953 <<
" to be the same as the " << i
2954 <<
"-th result type: " << resultType;
3002 for (
auto &iv : ivs)
3009 ParallelOp::getOperandSegmentSizeAttr(),
3011 static_cast<int32_t>(upper.size()),
3012 static_cast<int32_t>(steps.size()),
3013 static_cast<int32_t>(initVals.size())}));
3022 ParallelOp::ensureTerminator(*body, builder, result.
location);
3027 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3028 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3029 if (!getInitVals().empty())
3030 p <<
" init (" << getInitVals() <<
")";
3035 (*this)->getAttrs(),
3036 ParallelOp::getOperandSegmentSizeAttr());
3041 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3045 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3049 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3053 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3058 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3060 return ParallelOp();
3061 assert(ivArg.getOwner() &&
"unlinked block argument");
3062 auto *containingOp = ivArg.getOwner()->getParentOp();
3063 return dyn_cast<ParallelOp>(containingOp);
3068 struct ParallelOpSingleOrZeroIterationDimsFolder
3072 LogicalResult matchAndRewrite(ParallelOp op,
3079 for (
auto [lb, ub, step, iv] :
3080 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3081 op.getInductionVars())) {
3083 if (numIterations.has_value()) {
3085 if (*numIterations == 0) {
3086 rewriter.
replaceOp(op, op.getInitVals());
3091 if (*numIterations == 1) {
3096 newLowerBounds.push_back(lb);
3097 newUpperBounds.push_back(ub);
3098 newSteps.push_back(step);
3101 if (newLowerBounds.size() == op.getLowerBound().size())
3104 if (newLowerBounds.empty()) {
3108 results.reserve(op.getInitVals().size());
3109 for (
auto &bodyOp : op.getBody()->without_terminator())
3110 rewriter.
clone(bodyOp, mapping);
3111 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3112 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3113 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3114 auto initValIndex = results.size();
3115 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3119 rewriter.
clone(reduceBodyOp, mapping);
3122 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3123 results.push_back(result);
3131 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3132 newSteps, op.getInitVals(),
nullptr);
3138 newOp.getRegion().begin(), mapping);
3139 rewriter.
replaceOp(op, newOp.getResults());
3147 LogicalResult matchAndRewrite(ParallelOp op,
3149 Block &outerBody = *op.getBody();
3153 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3158 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3159 llvm::is_contained(innerOp.getUpperBound(), val) ||
3160 llvm::is_contained(innerOp.getStep(), val))
3164 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3169 Block &innerBody = *innerOp.getBody();
3170 assert(iterVals.size() ==
3178 builder.
clone(op, mapping);
3181 auto concatValues = [](
const auto &first,
const auto &second) {
3183 ret.reserve(first.size() + second.size());
3184 ret.assign(first.begin(), first.end());
3185 ret.append(second.begin(), second.end());
3189 auto newLowerBounds =
3190 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3191 auto newUpperBounds =
3192 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3193 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3196 newSteps, std::nullopt,
3207 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3216 void ParallelOp::getSuccessorRegions(
3234 for (
Value v : operands) {
3243 LogicalResult ReduceOp::verifyRegions() {
3246 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3247 auto type = getOperands()[i].getType();
3250 return emitOpError() << i <<
"-th reduction has an empty body";
3253 return arg.getType() != type;
3255 return emitOpError() <<
"expected two block arguments with type " << type
3256 <<
" in the " << i <<
"-th reduction region";
3260 return emitOpError(
"reduction bodies must be terminated with an "
3261 "'scf.reduce.return' op");
3280 Block *reductionBody = getOperation()->getBlock();
3282 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3284 if (expectedResultType != getResult().
getType())
3285 return emitOpError() <<
"must have type " << expectedResultType
3286 <<
" (the type of the reduction inputs)";
3296 ValueRange inits, BodyBuilderFn beforeBuilder,
3297 BodyBuilderFn afterBuilder) {
3305 beforeArgLocs.reserve(inits.size());
3306 for (
Value operand : inits) {
3307 beforeArgLocs.push_back(operand.getLoc());
3312 inits.getTypes(), beforeArgLocs);
3321 resultTypes, afterArgLocs);
3327 ConditionOp WhileOp::getConditionOp() {
3328 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3331 YieldOp WhileOp::getYieldOp() {
3332 return cast<YieldOp>(getAfterBody()->getTerminator());
3335 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3336 return getYieldOp().getResultsMutable();
3340 return getBeforeBody()->getArguments();
3344 return getAfterBody()->getArguments();
3348 return getBeforeArguments();
3352 assert(point == getBefore() &&
3353 "WhileOp is expected to branch only to the first region");
3361 regions.emplace_back(&getBefore(), getBefore().getArguments());
3365 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3366 "there are only two regions in a WhileOp");
3368 if (point == getAfter()) {
3369 regions.emplace_back(&getBefore(), getBefore().getArguments());
3373 regions.emplace_back(getResults());
3374 regions.emplace_back(&getAfter(), getAfter().getArguments());
3378 return {&getBefore(), &getAfter()};
3399 FunctionType functionType;
3404 result.
addTypes(functionType.getResults());
3406 if (functionType.getNumInputs() != operands.size()) {
3408 <<
"expected as many input types as operands "
3409 <<
"(expected " << operands.size() <<
" got "
3410 << functionType.getNumInputs() <<
")";
3420 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3421 regionArgs[i].type = functionType.getInput(i);
3423 return failure(parser.
parseRegion(*before, regionArgs) ||
3443 template <
typename OpTy>
3446 if (left.size() != right.size())
3447 return op.emitOpError(
"expects the same number of ") << message;
3449 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3450 if (left[i] != right[i]) {
3453 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3454 <<
" and " << right[i];
3463 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3465 "expects the 'before' region to terminate with 'scf.condition'");
3466 if (!beforeTerminator)
3469 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3471 "expects the 'after' region to terminate with 'scf.yield'");
3472 return success(afterTerminator !=
nullptr);
3498 LogicalResult matchAndRewrite(WhileOp op,
3500 auto term = op.getConditionOp();
3504 Value constantTrue =
nullptr;
3506 bool replaced =
false;
3507 for (
auto yieldedAndBlockArgs :
3508 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3509 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3510 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3512 constantTrue = rewriter.
create<arith::ConstantOp>(
3513 op.getLoc(), term.getCondition().getType(),
3522 return success(replaced);
3574 struct RemoveLoopInvariantArgsFromBeforeBlock
3578 LogicalResult matchAndRewrite(WhileOp op,
3580 Block &afterBlock = *op.getAfterBody();
3582 ConditionOp condOp = op.getConditionOp();
3587 bool canSimplify =
false;
3588 for (
const auto &it :
3590 auto index =
static_cast<unsigned>(it.index());
3591 auto [initVal, yieldOpArg] = it.value();
3594 if (yieldOpArg == initVal) {
3603 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3604 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3605 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3606 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3619 for (
const auto &it :
3621 auto index =
static_cast<unsigned>(it.index());
3622 auto [initVal, yieldOpArg] = it.value();
3626 if (yieldOpArg == initVal) {
3627 beforeBlockInitValMap.insert({index, initVal});
3635 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3636 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3637 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3638 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3639 beforeBlockInitValMap.insert({index, initVal});
3644 newInitArgs.emplace_back(initVal);
3645 newYieldOpArgs.emplace_back(yieldOpArg);
3646 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3656 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3659 &newWhile.getBefore(), {},
3662 Block &beforeBlock = *op.getBeforeBody();
3669 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3672 if (beforeBlockInitValMap.count(i) != 0)
3673 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3675 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3678 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3680 newWhile.getAfter().begin());
3682 rewriter.
replaceOp(op, newWhile.getResults());
3727 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3730 LogicalResult matchAndRewrite(WhileOp op,
3732 Block &beforeBlock = *op.getBeforeBody();
3733 ConditionOp condOp = op.getConditionOp();
3736 bool canSimplify =
false;
3737 for (
Value condOpArg : condOpArgs) {
3757 auto index =
static_cast<unsigned>(it.index());
3758 Value condOpArg = it.value();
3763 condOpInitValMap.insert({index, condOpArg});
3765 newCondOpArgs.emplace_back(condOpArg);
3766 newAfterBlockType.emplace_back(condOpArg.
getType());
3767 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3778 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3781 Block &newAfterBlock =
3783 newAfterBlockType, newAfterBlockArgLocs);
3785 Block &afterBlock = *op.getAfterBody();
3792 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3793 Value afterBlockArg, result;
3796 if (condOpInitValMap.count(i) != 0) {
3797 afterBlockArg = condOpInitValMap[i];
3798 result = afterBlockArg;
3800 afterBlockArg = newAfterBlock.getArgument(
j);
3801 result = newWhile.getResult(
j);
3804 newAfterBlockArgs[i] = afterBlockArg;
3805 newWhileResults[i] = result;
3808 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3810 newWhile.getBefore().begin());
3812 rewriter.
replaceOp(op, newWhileResults);
3846 LogicalResult matchAndRewrite(WhileOp op,
3848 auto term = op.getConditionOp();
3849 auto afterArgs = op.getAfterArguments();
3850 auto termArgs = term.getArgs();
3857 bool needUpdate =
false;
3858 for (
const auto &it :
3860 auto i =
static_cast<unsigned>(it.index());
3861 Value result = std::get<0>(it.value());
3862 Value afterArg = std::get<1>(it.value());
3863 Value termArg = std::get<2>(it.value());
3867 newResultsIndices.emplace_back(i);
3868 newTermArgs.emplace_back(termArg);
3869 newResultTypes.emplace_back(result.
getType());
3870 newArgLocs.emplace_back(result.
getLoc());
3885 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3888 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3895 newResults[it.value()] = newWhile.getResult(it.index());
3896 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3900 newWhile.getBefore().begin());
3902 Block &afterBlock = *op.getAfterBody();
3903 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3935 LogicalResult matchAndRewrite(scf::WhileOp op,
3937 using namespace scf;
3938 auto cond = op.getConditionOp();
3939 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3943 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3944 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3945 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3948 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3949 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3953 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3956 if (cmp2.getPredicate() == cmp.getPredicate())
3957 samePredicate =
true;
3958 else if (cmp2.getPredicate() ==
3960 samePredicate =
false;
3978 LogicalResult matchAndRewrite(WhileOp op,
3981 if (!llvm::any_of(op.getBeforeArguments(),
3982 [](
Value arg) { return arg.use_empty(); }))
3985 YieldOp yield = op.getYieldOp();
3990 llvm::BitVector argsToErase;
3992 size_t argsCount = op.getBeforeArguments().size();
3993 newYields.reserve(argsCount);
3994 newInits.reserve(argsCount);
3995 argsToErase.reserve(argsCount);
3996 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3997 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3998 if (beforeArg.use_empty()) {
3999 argsToErase.push_back(
true);
4001 argsToErase.push_back(
false);
4002 newYields.emplace_back(yieldValue);
4003 newInits.emplace_back(initValue);
4007 Block &beforeBlock = *op.getBeforeBody();
4008 Block &afterBlock = *op.getAfterBody();
4014 rewriter.
create<WhileOp>(loc, op.getResultTypes(), newInits,
4016 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4017 Block &newAfterBlock = *newWhileOp.getAfterBody();
4023 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4024 newBeforeBlock.getArguments());
4028 rewriter.
replaceOp(op, newWhileOp.getResults());
4037 LogicalResult matchAndRewrite(WhileOp op,
4039 ConditionOp condOp = op.getConditionOp();
4044 if (argsSet.size() == condOpArgs.size())
4047 llvm::SmallDenseMap<Value, unsigned> argsMap;
4049 argsMap.reserve(condOpArgs.size());
4050 newArgs.reserve(condOpArgs.size());
4051 for (
Value arg : condOpArgs) {
4052 if (!argsMap.count(arg)) {
4053 auto pos =
static_cast<unsigned>(argsMap.size());
4054 argsMap.insert({arg, pos});
4055 newArgs.emplace_back(arg);
4062 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4063 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4065 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4066 Block &newAfterBlock = *newWhileOp.getAfterBody();
4071 auto it = argsMap.find(arg);
4072 assert(it != argsMap.end());
4073 auto pos = it->second;
4074 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4075 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4083 Block &beforeBlock = *op.getBeforeBody();
4084 Block &afterBlock = *op.getAfterBody();
4086 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4087 newBeforeBlock.getArguments());
4088 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4096 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4098 if (args1.size() != args2.size())
4099 return std::nullopt;
4103 auto it = llvm::find(args2, arg1);
4104 if (it == args2.end())
4105 return std::nullopt;
4107 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4114 llvm::SmallDenseSet<Value> set;
4115 for (
Value arg : args) {
4116 if (!set.insert(arg).second)
4129 LogicalResult matchAndRewrite(WhileOp loop,
4131 auto oldBefore = loop.getBeforeBody();
4132 ConditionOp oldTerm = loop.getConditionOp();
4133 ValueRange beforeArgs = oldBefore->getArguments();
4135 if (beforeArgs == termArgs)
4138 if (hasDuplicates(termArgs))
4141 auto mapping = getArgsMapping(beforeArgs, termArgs);
4152 auto oldAfter = loop.getAfterBody();
4156 newResultTypes[
j] = loop.getResult(i).getType();
4158 auto newLoop = rewriter.
create<WhileOp>(
4159 loop.getLoc(), newResultTypes, loop.getInits(),
4161 auto newBefore = newLoop.getBeforeBody();
4162 auto newAfter = newLoop.getAfterBody();
4167 newResults[i] = newLoop.getResult(
j);
4168 newAfterArgs[i] = newAfter->getArgument(
j);
4172 newBefore->getArguments());
4184 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4185 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4186 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4187 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4201 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4204 caseValues.push_back(value);
4213 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4215 p <<
"case " << value <<
' ';
4221 if (getCases().size() != getCaseRegions().size()) {
4222 return emitOpError(
"has ")
4223 << getCaseRegions().size() <<
" case regions but "
4224 << getCases().size() <<
" case values";
4228 for (int64_t value : getCases())
4229 if (!valueSet.insert(value).second)
4230 return emitOpError(
"has duplicate case value: ") << value;
4232 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4234 return emitOpError(
"expected region to end with scf.yield, but got ")
4237 if (yield.getNumOperands() != getNumResults()) {
4238 return (emitOpError(
"expected each region to return ")
4239 << getNumResults() <<
" values, but " << name <<
" returns "
4240 << yield.getNumOperands())
4241 .attachNote(yield.getLoc())
4242 <<
"see yield operation here";
4244 for (
auto [idx, result, operand] :
4245 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4246 yield.getOperandTypes())) {
4247 if (result == operand)
4249 return (emitOpError(
"expected result #")
4250 << idx <<
" of each region to be " << result)
4251 .attachNote(yield.getLoc())
4252 << name <<
" returns " << operand <<
" here";
4257 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4260 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4266 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4268 Block &scf::IndexSwitchOp::getDefaultBlock() {
4269 return getDefaultRegion().
front();
4272 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4273 assert(idx < getNumCases() &&
"case index out-of-bounds");
4274 return getCaseRegions()[idx].front();
4277 void IndexSwitchOp::getSuccessorRegions(
4281 successors.emplace_back(getResults());
4285 llvm::copy(getRegions(), std::back_inserter(successors));
4288 void IndexSwitchOp::getEntrySuccessorRegions(
4291 FoldAdaptor adaptor(operands, *
this);
4294 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4296 llvm::copy(getRegions(), std::back_inserter(successors));
4302 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4303 if (caseValue == arg.getInt()) {
4304 successors.emplace_back(&caseRegion);
4308 successors.emplace_back(&getDefaultRegion());
4311 void IndexSwitchOp::getRegionInvocationBounds(
4313 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4314 if (!operandValue) {
4320 unsigned liveIndex = getNumRegions() - 1;
4321 const auto *it = llvm::find(getCases(), operandValue.getInt());
4322 if (it != getCases().end())
4323 liveIndex = std::distance(getCases().begin(), it);
4324 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4325 bounds.emplace_back(0, i == liveIndex);
4336 if (!maybeCst.has_value())
4338 int64_t cst = *maybeCst;
4339 int64_t caseIdx, e = op.getNumCases();
4340 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4341 if (cst == op.getCases()[caseIdx])
4345 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4346 : op.getDefaultRegion();
4347 Block &source = r.front();
4370 #define GET_OP_CLASSES
4371 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of a AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.