25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 Block *prevBlock = op->getBlock();
242 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 for (
Block &blk : op.getRegion()) {
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
256 for (
auto res : op.getResults())
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : initArgs)
327 for (
Value v : initArgs)
333 if (initArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
396 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
509 for (
auto argOperandType :
510 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
511 Type type = std::get<2>(argOperandType);
512 std::get<0>(argOperandType).type = type;
524 ForOp::ensureTerminator(*body, builder, result.
location);
536 return getBody()->getArguments().drop_front(getNumInductionVars());
540 return getInitArgsMutable();
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
546 bool replaceInitOperandUsesInLoop,
551 auto inits = llvm::to_vector(getInitArgs());
552 inits.append(newInitOperands.begin(), newInitOperands.end());
553 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
561 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
566 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567 assert(newInitOperands.size() == newYieldedValues.size() &&
568 "expected as many new yield values as new iter operands");
570 yieldOp.getResultsMutable().append(newYieldedValues);
576 newLoop.getBody()->getArguments().take_front(
577 getBody()->getNumArguments()));
579 if (replaceInitOperandUsesInLoop) {
582 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
593 newLoop->getResults().take_front(getNumResults()));
594 return cast<LoopLikeOpInterface>(newLoop.getOperation());
598 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
601 assert(ivArg.getOwner() &&
"unlinked block argument");
602 auto *containingOp = ivArg.getOwner()->getParentOp();
603 return dyn_cast_or_null<ForOp>(containingOp);
607 return getInitArgs();
624 for (
auto [lb, ub, step] :
625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
627 if (!tripCount.has_value() || *tripCount != 1)
636 return getBody()->getArguments().drop_front(getRank());
640 return getOutputsMutable();
646 scf::InParallelOp terminator = forallOp.getTerminator();
651 bbArgReplacements.append(forallOp.getOutputs().begin(),
652 forallOp.getOutputs().end());
656 forallOp->getIterator(), bbArgReplacements);
661 results.reserve(forallOp.getResults().size());
662 for (
auto &yieldingOp : terminator.getYieldingOps()) {
663 auto parallelInsertSliceOp =
664 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
666 Value dst = parallelInsertSliceOp.getDest();
667 Value src = parallelInsertSliceOp.getSource();
668 if (llvm::isa<TensorType>(src.
getType())) {
669 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
670 forallOp.getLoc(), dst.
getType(), src, dst,
671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672 parallelInsertSliceOp.getStrides(),
673 parallelInsertSliceOp.getStaticOffsets(),
674 parallelInsertSliceOp.getStaticSizes(),
675 parallelInsertSliceOp.getStaticStrides()));
677 llvm_unreachable(
"unsupported terminator");
692 assert(lbs.size() == ubs.size() &&
693 "expected the same number of lower and upper bounds");
694 assert(lbs.size() == steps.size() &&
695 "expected the same number of lower bounds and steps");
700 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
702 assert(results.size() == iterArgs.size() &&
703 "loop nest body must return as many values as loop has iteration "
705 return LoopNest{{}, std::move(results)};
713 loops.reserve(lbs.size());
714 ivs.reserve(lbs.size());
717 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
718 auto loop = builder.
create<scf::ForOp>(
719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725 currentIterArgs = args;
726 currentLoc = nestedLoc;
732 loops.push_back(loop);
736 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
738 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
745 ? bodyBuilder(builder, currentLoc, ivs,
746 loops.back().getRegionIterArgs())
748 assert(results.size() == iterArgs.size() &&
749 "loop nest body must return as many values as loop has iteration "
752 builder.
create<scf::YieldOp>(loc, results);
756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757 return LoopNest{std::move(loops), std::move(nestResults)};
765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
770 bodyBuilder(nestedBuilder, nestedLoc, ivs);
779 assert(operand.
getOwner() == forOp);
784 "expected an iter OpOperand");
786 "Expected a different type");
788 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
790 newIterOperands.push_back(replacement);
793 newIterOperands.push_back(opOperand.get());
797 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
798 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799 forOp.getStep(), newIterOperands);
800 newForOp->
setAttrs(forOp->getAttrs());
801 Block &newBlock = newForOp.getRegion().
front();
809 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
811 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
816 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
819 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
822 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
823 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824 clonedYieldOp.getOperand(yieldIdx));
826 newYieldOperands[yieldIdx] = castOut;
827 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828 rewriter.
eraseOp(clonedYieldOp);
833 newResults[yieldIdx] =
834 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
855 LogicalResult matchAndRewrite(scf::ForOp forOp,
857 bool canonicalize =
false;
864 int64_t numResults = forOp.getNumResults();
866 keepMask.reserve(numResults);
869 newBlockTransferArgs.reserve(1 + numResults);
870 newBlockTransferArgs.push_back(
Value());
871 newIterArgs.reserve(forOp.getInitArgs().size());
872 newYieldValues.reserve(numResults);
873 newResultValues.reserve(numResults);
874 for (
auto [init, arg, result, yielded] :
875 llvm::zip(forOp.getInitArgs(),
876 forOp.getRegionIterArgs(),
878 forOp.getYieldedValues()
885 bool forwarded = (arg == yielded) || (init == yielded) ||
886 (arg.use_empty() && result.use_empty());
887 keepMask.push_back(!forwarded);
888 canonicalize |= forwarded;
890 newBlockTransferArgs.push_back(init);
891 newResultValues.push_back(init);
894 newIterArgs.push_back(init);
895 newYieldValues.push_back(yielded);
896 newBlockTransferArgs.push_back(
Value());
897 newResultValues.push_back(
Value());
903 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
904 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
905 forOp.getStep(), newIterArgs);
906 newForOp->
setAttrs(forOp->getAttrs());
907 Block &newBlock = newForOp.getRegion().
front();
911 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
913 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
914 Value &newResultVal = newResultValues[idx];
915 assert((blockTransferArg && newResultVal) ||
916 (!blockTransferArg && !newResultVal));
917 if (!blockTransferArg) {
918 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
919 newResultVal = newForOp.getResult(collapsedIdx++);
925 "unexpected argument size mismatch");
930 if (newIterArgs.empty()) {
931 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
934 rewriter.
replaceOp(forOp, newResultValues);
939 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
943 filteredOperands.reserve(newResultValues.size());
944 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
946 filteredOperands.push_back(mergedTerminator.getOperand(idx));
947 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
951 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
952 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
953 cloneFilteredTerminator(mergedYieldOp);
954 rewriter.
eraseOp(mergedYieldOp);
955 rewriter.
replaceOp(forOp, newResultValues);
963 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
964 IntegerAttr clb, cub;
966 llvm::APInt lbValue = clb.getValue();
967 llvm::APInt ubValue = cub.getValue();
968 return (ubValue - lbValue).getSExtValue();
977 return diff.getSExtValue();
987 LogicalResult matchAndRewrite(ForOp op,
991 if (op.getLowerBound() == op.getUpperBound()) {
992 rewriter.
replaceOp(op, op.getInitArgs());
996 std::optional<int64_t> diff =
997 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1003 rewriter.
replaceOp(op, op.getInitArgs());
1007 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1008 if (!maybeStepValue)
1013 llvm::APInt stepValue = *maybeStepValue;
1014 if (stepValue.sge(*diff)) {
1016 blockArgs.reserve(op.getInitArgs().size() + 1);
1017 blockArgs.push_back(op.getLowerBound());
1018 llvm::append_range(blockArgs, op.getInitArgs());
1025 if (!llvm::hasSingleElement(block))
1029 if (llvm::any_of(op.getYieldedValues(),
1030 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1032 rewriter.
replaceOp(op, op.getYieldedValues());
1066 LogicalResult matchAndRewrite(ForOp op,
1068 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1069 OpOperand &iterOpOperand = std::get<0>(it);
1071 if (!incomingCast ||
1072 incomingCast.getSource().getType() == incomingCast.getType())
1077 incomingCast.getDest().getType(),
1078 incomingCast.getSource().getType()))
1080 if (!std::get<1>(it).hasOneUse())
1086 rewriter, op, iterOpOperand, incomingCast.getSource(),
1088 return b.create<tensor::CastOp>(loc, type, source);
1100 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1104 std::optional<APInt> ForOp::getConstantStep() {
1107 return step.getValue();
1111 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1112 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1118 if (
auto constantStep = getConstantStep())
1119 if (*constantStep == 1)
1132 unsigned numLoops = getRank();
1134 if (getNumResults() != getOutputs().size())
1135 return emitOpError(
"produces ")
1136 << getNumResults() <<
" results, but has only "
1137 << getOutputs().size() <<
" outputs";
1140 auto *body = getBody();
1142 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1143 for (int64_t i = 0; i < numLoops; ++i)
1145 return emitOpError(
"expects ")
1146 << i <<
"-th block argument to be an index";
1147 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1149 return emitOpError(
"type mismatch between ")
1150 << i <<
"-th output and corresponding block argument";
1151 if (getMapping().has_value() && !getMapping()->empty()) {
1152 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1153 return emitOpError() <<
"mapping attribute size must match op rank";
1154 for (
auto map : getMapping()->getValue()) {
1155 if (!isa<DeviceMappingAttrInterface>(map))
1156 return emitOpError()
1164 getStaticLowerBound(),
1165 getDynamicLowerBound())))
1168 getStaticUpperBound(),
1169 getDynamicUpperBound())))
1172 getStaticStep(), getDynamicStep())))
1180 p <<
" (" << getInductionVars();
1181 if (isNormalized()) {
1202 if (!getRegionOutArgs().empty())
1203 p <<
"-> (" << getResultTypes() <<
") ";
1204 p.printRegion(getRegion(),
1206 getNumResults() > 0);
1207 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1208 getStaticLowerBoundAttrName(),
1209 getStaticUpperBoundAttrName(),
1210 getStaticStepAttrName()});
1215 auto indexType = b.getIndexType();
1235 unsigned numLoops = ivs.size();
1270 if (outOperands.size() != result.
types.size())
1272 "mismatch between out operands and types");
1282 std::unique_ptr<Region> region = std::make_unique<Region>();
1283 for (
auto &iv : ivs) {
1284 iv.type = b.getIndexType();
1285 regionArgs.push_back(iv);
1288 auto &out = it.value();
1289 out.type = result.
types[it.index()];
1290 regionArgs.push_back(out);
1296 ForallOp::ensureTerminator(*region, b, result.
location);
1308 {static_cast<int32_t>(dynamicLbs.size()),
1309 static_cast<int32_t>(dynamicUbs.size()),
1310 static_cast<int32_t>(dynamicSteps.size()),
1311 static_cast<int32_t>(outOperands.size())}));
1316 void ForallOp::build(
1320 std::optional<ArrayAttr> mapping,
1341 "operandSegmentSizes",
1343 static_cast<int32_t>(dynamicUbs.size()),
1344 static_cast<int32_t>(dynamicSteps.size()),
1345 static_cast<int32_t>(outputs.size())}));
1346 if (mapping.has_value()) {
1365 if (!bodyBuilderFn) {
1366 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1373 void ForallOp::build(
1376 std::optional<ArrayAttr> mapping,
1378 unsigned numLoops = ubs.size();
1381 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1385 bool ForallOp::isNormalized() {
1389 return intValue.has_value() && intValue == val;
1392 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1401 ForallOp>::ensureTerminator(region, builder, loc);
1408 InParallelOp ForallOp::getTerminator() {
1409 return cast<InParallelOp>(getBody()->getTerminator());
1414 InParallelOp inParallelOp = getTerminator();
1415 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1416 if (
auto parallelInsertSliceOp =
1417 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1418 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1419 storeOps.push_back(parallelInsertSliceOp);
1425 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1430 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1432 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1436 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1438 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1442 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1448 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1451 assert(tidxArg.getOwner() &&
"unlinked block argument");
1452 auto *containingOp = tidxArg.getOwner()->getParentOp();
1453 return dyn_cast<ForallOp>(containingOp);
1461 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1463 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1467 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1470 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1479 LogicalResult matchAndRewrite(ForallOp op,
1494 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1495 op.setStaticLowerBound(staticLowerBound);
1499 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1500 op.setStaticUpperBound(staticUpperBound);
1503 op.getDynamicStepMutable().assign(dynamicStep);
1504 op.setStaticStep(staticStep);
1506 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1508 {static_cast<int32_t>(dynamicLowerBound.size()),
1509 static_cast<int32_t>(dynamicUpperBound.size()),
1510 static_cast<int32_t>(dynamicStep.size()),
1511 static_cast<int32_t>(op.getNumResults())}));
1593 LogicalResult matchAndRewrite(ForallOp forallOp,
1612 for (
OpResult result : forallOp.getResults()) {
1613 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1614 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1615 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1616 resultToDelete.insert(result);
1618 resultToReplace.push_back(result);
1619 newOuts.push_back(opOperand->
get());
1625 if (resultToDelete.empty())
1633 for (
OpResult result : resultToDelete) {
1634 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1635 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1637 forallOp.getCombiningOps(blockArg);
1638 for (
Operation *combiningOp : combiningOps)
1639 rewriter.
eraseOp(combiningOp);
1644 auto newForallOp = rewriter.
create<scf::ForallOp>(
1645 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1646 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1647 forallOp.getMapping(),
1652 Block *loopBody = forallOp.getBody();
1653 Block *newLoopBody = newForallOp.getBody();
1658 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1665 for (
OpResult result : forallOp.getResults()) {
1666 if (resultToDelete.count(result)) {
1667 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1669 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1672 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1676 for (
auto &&[oldResult, newResult] :
1677 llvm::zip(resultToReplace, newForallOp->getResults()))
1683 for (
OpResult oldResult : resultToDelete)
1685 forallOp.getTiedOpOperand(oldResult)->get());
1690 struct ForallOpSingleOrZeroIterationDimsFolder
1694 LogicalResult matchAndRewrite(ForallOp op,
1697 if (op.getMapping().has_value() && !op.getMapping()->empty())
1705 for (
auto [lb, ub, step, iv] :
1706 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1707 op.getMixedStep(), op.getInductionVars())) {
1709 if (numIterations.has_value()) {
1711 if (*numIterations == 0) {
1712 rewriter.
replaceOp(op, op.getOutputs());
1717 if (*numIterations == 1) {
1722 newMixedLowerBounds.push_back(lb);
1723 newMixedUpperBounds.push_back(ub);
1724 newMixedSteps.push_back(step);
1728 if (newMixedLowerBounds.empty()) {
1734 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1736 op,
"no dimensions have 0 or 1 iterations");
1741 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1742 newMixedUpperBounds, newMixedSteps,
1743 op.getOutputs(), std::nullopt,
nullptr);
1744 newOp.getBodyRegion().getBlocks().clear();
1749 newOp.getStaticLowerBoundAttrName(),
1750 newOp.getStaticUpperBoundAttrName(),
1751 newOp.getStaticStepAttrName()};
1752 for (
const auto &namedAttr : op->getAttrs()) {
1753 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1756 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1760 newOp.getRegion().begin(), mapping);
1761 rewriter.
replaceOp(op, newOp.getResults());
1767 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1770 LogicalResult matchAndRewrite(ForallOp op,
1774 for (
auto [lb, ub, step, iv] :
1775 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1776 op.getMixedStep(), op.getInductionVars())) {
1777 if (iv.getUses().begin() == iv.getUses().end())
1780 if (!numIterations.has_value() || numIterations.value() != 1) {
1791 struct FoldTensorCastOfOutputIntoForallOp
1800 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1802 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1805 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1812 castOp.getSource().getType())) {
1816 tensorCastProducers[en.index()] =
1817 TypeCast{castOp.getSource().getType(), castOp.getType()};
1818 newOutputTensors[en.index()] = castOp.getSource();
1821 if (tensorCastProducers.empty())
1826 auto newForallOp = rewriter.
create<ForallOp>(
1827 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1828 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1830 auto castBlockArgs =
1831 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1832 for (
auto [index, cast] : tensorCastProducers) {
1833 Value &oldTypeBBArg = castBlockArgs[index];
1834 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1835 nestedLoc, cast.dstType, oldTypeBBArg);
1840 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1841 ivsBlockArgs.append(castBlockArgs);
1843 bbArgs.front().getParentBlock(), ivsBlockArgs);
1849 auto terminator = newForallOp.getTerminator();
1850 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1851 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1852 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1853 insertSliceOp.getDestMutable().assign(outputBlockArg);
1859 for (
auto &item : tensorCastProducers) {
1860 Value &oldTypeResult = castResults[item.first];
1861 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1864 rewriter.
replaceOp(forallOp, castResults);
1873 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1874 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1875 ForallOpSingleOrZeroIterationDimsFolder,
1876 ForallOpReplaceConstantInductionVar>(context);
1905 scf::ForallOp forallOp =
1906 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1908 return this->emitOpError(
"expected forall op parent");
1911 for (
Operation &op : getRegion().front().getOperations()) {
1912 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1913 return this->emitOpError(
"expected only ")
1914 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1918 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1920 if (!llvm::is_contained(regionOutArgs, dest))
1921 return op.emitOpError(
"may only insert into an output block argument");
1938 std::unique_ptr<Region> region = std::make_unique<Region>();
1942 if (region->empty())
1952 OpResult InParallelOp::getParentResult(int64_t idx) {
1953 return getOperation()->getParentOp()->getResult(idx);
1957 return llvm::to_vector<4>(
1958 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1960 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1961 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1966 return getRegion().front().getOperations();
1974 assert(a &&
"expected non-empty operation");
1975 assert(b &&
"expected non-empty operation");
1980 if (ifOp->isProperAncestor(b))
1983 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1984 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1986 ifOp = ifOp->getParentOfType<IfOp>();
1994 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1995 IfOp::Adaptor adaptor,
1997 if (adaptor.getRegions().empty())
1999 Region *r = &adaptor.getThenRegion();
2002 Block &b = r->front();
2005 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2008 TypeRange types = yieldOp.getOperandTypes();
2009 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2016 return build(builder, result, resultTypes, cond,
false,
2022 bool addElseBlock) {
2023 assert((!addElseBlock || addThenBlock) &&
2024 "must not create else block w/o then block");
2039 bool withElseRegion) {
2040 build(builder, result,
TypeRange{}, cond, withElseRegion);
2052 if (resultTypes.empty())
2053 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2057 if (withElseRegion) {
2059 if (resultTypes.empty())
2060 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2067 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2074 thenBuilder(builder, result.
location);
2080 elseBuilder(builder, result.
location);
2087 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2089 inferredReturnTypes))) {
2090 result.
addTypes(inferredReturnTypes);
2095 if (getNumResults() != 0 && getElseRegion().empty())
2096 return emitOpError(
"must have an else block if defining values");
2134 bool printBlockTerminators =
false;
2136 p <<
" " << getCondition();
2137 if (!getResults().empty()) {
2138 p <<
" -> (" << getResultTypes() <<
")";
2140 printBlockTerminators =
true;
2145 printBlockTerminators);
2148 auto &elseRegion = getElseRegion();
2149 if (!elseRegion.
empty()) {
2153 printBlockTerminators);
2170 Region *elseRegion = &this->getElseRegion();
2171 if (elseRegion->
empty())
2179 FoldAdaptor adaptor(operands, *
this);
2180 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2181 if (!boolAttr || boolAttr.getValue())
2182 regions.emplace_back(&getThenRegion());
2185 if (!boolAttr || !boolAttr.getValue()) {
2186 if (!getElseRegion().empty())
2187 regions.emplace_back(&getElseRegion());
2189 regions.emplace_back(getResults());
2193 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2196 if (getElseRegion().empty())
2199 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2206 getConditionMutable().assign(xorStmt.getLhs());
2210 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2211 getElseRegion().getBlocks());
2212 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2213 getThenRegion().getBlocks(), thenBlock);
2217 void IfOp::getRegionInvocationBounds(
2220 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2223 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2224 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2227 invocationBounds.assign(2, {0, 1});
2243 llvm::transform(usedResults, std::back_inserter(usedOperands),
2248 [&]() { yieldOp->setOperands(usedOperands); });
2251 LogicalResult matchAndRewrite(IfOp op,
2255 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2256 [](
OpResult result) { return !result.use_empty(); });
2259 if (usedResults.size() == op.getNumResults())
2264 llvm::transform(usedResults, std::back_inserter(newTypes),
2269 rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2275 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2276 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2281 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2290 LogicalResult matchAndRewrite(IfOp op,
2298 else if (!op.getElseRegion().empty())
2312 LogicalResult matchAndRewrite(IfOp op,
2314 if (op->getNumResults() == 0)
2317 auto cond = op.getCondition();
2318 auto thenYieldArgs = op.thenYield().getOperands();
2319 auto elseYieldArgs = op.elseYield().getOperands();
2322 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2323 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2324 &op.getElseRegion() == falseVal.getParentRegion())
2325 nonHoistable.push_back(trueVal.getType());
2329 if (nonHoistable.size() == op->getNumResults())
2332 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond,
2334 if (replacement.thenBlock())
2335 rewriter.
eraseBlock(replacement.thenBlock());
2336 replacement.getThenRegion().takeBody(op.getThenRegion());
2337 replacement.getElseRegion().takeBody(op.getElseRegion());
2340 assert(thenYieldArgs.size() == results.size());
2341 assert(elseYieldArgs.size() == results.size());
2346 for (
const auto &it :
2348 Value trueVal = std::get<0>(it.value());
2349 Value falseVal = std::get<1>(it.value());
2352 results[it.index()] = replacement.getResult(trueYields.size());
2353 trueYields.push_back(trueVal);
2354 falseYields.push_back(falseVal);
2355 }
else if (trueVal == falseVal)
2356 results[it.index()] = trueVal;
2358 results[it.index()] = rewriter.
create<arith::SelectOp>(
2359 op.getLoc(), cond, trueVal, falseVal);
2389 LogicalResult matchAndRewrite(IfOp op,
2401 Value constantTrue =
nullptr;
2402 Value constantFalse =
nullptr;
2405 llvm::make_early_inc_range(op.getCondition().getUses())) {
2410 constantTrue = rewriter.
create<arith::ConstantOp>(
2414 [&]() { use.
set(constantTrue); });
2415 }
else if (op.getElseRegion().isAncestor(
2420 constantFalse = rewriter.
create<arith::ConstantOp>(
2424 [&]() { use.
set(constantFalse); });
2468 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2471 LogicalResult matchAndRewrite(IfOp op,
2474 if (op.getNumResults() == 0)
2478 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2480 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2483 op.getOperation()->getIterator());
2486 for (
auto [trueResult, falseResult, opResult] :
2487 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2489 if (trueResult == falseResult) {
2490 if (!opResult.use_empty()) {
2491 opResult.replaceAllUsesWith(trueResult);
2502 bool trueVal = trueYield.
getValue();
2503 bool falseVal = falseYield.
getValue();
2504 if (!trueVal && falseVal) {
2505 if (!opResult.use_empty()) {
2506 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2508 op.getLoc(), op.getCondition(),
2518 if (trueVal && !falseVal) {
2519 if (!opResult.use_empty()) {
2520 opResult.replaceAllUsesWith(op.getCondition());
2553 LogicalResult matchAndRewrite(IfOp nextIf,
2555 Block *parent = nextIf->getBlock();
2556 if (nextIf == &parent->
front())
2559 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2567 Block *nextThen =
nullptr;
2568 Block *nextElse =
nullptr;
2569 if (nextIf.getCondition() == prevIf.getCondition()) {
2570 nextThen = nextIf.thenBlock();
2571 if (!nextIf.getElseRegion().empty())
2572 nextElse = nextIf.elseBlock();
2574 if (arith::XOrIOp notv =
2575 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2576 if (notv.getLhs() == prevIf.getCondition() &&
2578 nextElse = nextIf.thenBlock();
2579 if (!nextIf.getElseRegion().empty())
2580 nextThen = nextIf.elseBlock();
2583 if (arith::XOrIOp notv =
2584 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2585 if (notv.getLhs() == nextIf.getCondition() &&
2587 nextElse = nextIf.thenBlock();
2588 if (!nextIf.getElseRegion().empty())
2589 nextThen = nextIf.elseBlock();
2593 if (!nextThen && !nextElse)
2597 if (!prevIf.getElseRegion().empty())
2598 prevElseYielded = prevIf.elseYield().getOperands();
2601 for (
auto it : llvm::zip(prevIf.getResults(),
2602 prevIf.thenYield().getOperands(), prevElseYielded))
2604 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2608 use.
set(std::get<1>(it));
2613 use.
set(std::get<2>(it));
2619 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2621 IfOp combinedIf = rewriter.
create<IfOp>(
2622 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2623 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2626 combinedIf.getThenRegion(),
2627 combinedIf.getThenRegion().begin());
2630 YieldOp thenYield = combinedIf.thenYield();
2631 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2632 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2636 llvm::append_range(mergedYields, thenYield2.getOperands());
2637 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2643 combinedIf.getElseRegion(),
2644 combinedIf.getElseRegion().begin());
2647 if (combinedIf.getElseRegion().empty()) {
2649 combinedIf.getElseRegion(),
2650 combinedIf.getElseRegion().
begin());
2652 YieldOp elseYield = combinedIf.elseYield();
2653 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2654 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2659 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2661 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2670 if (pair.index() < prevIf.getNumResults())
2671 prevValues.push_back(pair.value());
2673 nextValues.push_back(pair.value());
2685 LogicalResult matchAndRewrite(IfOp ifOp,
2688 if (ifOp.getNumResults())
2690 Block *elseBlock = ifOp.elseBlock();
2691 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2695 newIfOp.getThenRegion().begin());
2720 LogicalResult matchAndRewrite(IfOp op,
2722 auto nestedOps = op.thenBlock()->without_terminator();
2724 if (!llvm::hasSingleElement(nestedOps))
2728 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2731 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2735 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2741 llvm::append_range(elseYield, op.elseYield().getOperands());
2755 if (tup.value().getDefiningOp() == nestedIf) {
2756 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2757 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2758 elseYield[tup.index()]) {
2763 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2776 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2779 elseYieldsToUpgradeToSelect.push_back(tup.index());
2783 Value newCondition = rewriter.
create<arith::AndIOp>(
2784 loc, op.getCondition(), nestedIf.getCondition());
2785 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2789 llvm::append_range(results, newIf.getResults());
2792 for (
auto idx : elseYieldsToUpgradeToSelect)
2793 results[idx] = rewriter.
create<arith::SelectOp>(
2794 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2796 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2799 if (!elseYield.empty()) {
2802 rewriter.
create<YieldOp>(loc, elseYield);
2813 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2814 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2815 RemoveStaticCondition, RemoveUnusedResults,
2816 ReplaceIfYieldWithConditionOrValue>(context);
2819 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2820 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2821 Block *IfOp::elseBlock() {
2822 Region &r = getElseRegion();
2827 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2833 void ParallelOp::build(
2843 ParallelOp::getOperandSegmentSizeAttr(),
2845 static_cast<int32_t>(upperBounds.size()),
2846 static_cast<int32_t>(steps.size()),
2847 static_cast<int32_t>(initVals.size())}));
2851 unsigned numIVs = steps.size();
2857 if (bodyBuilderFn) {
2859 bodyBuilderFn(builder, result.
location,
2864 if (initVals.empty())
2865 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2868 void ParallelOp::build(
2875 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2878 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2882 wrapper = wrappedBuilderFn;
2884 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2893 if (stepValues.empty())
2895 "needs at least one tuple element for lowerBound, upperBound and step");
2898 for (
Value stepValue : stepValues)
2901 return emitOpError(
"constant step operand must be positive");
2905 Block *body = getBody();
2907 return emitOpError() <<
"expects the same number of induction variables: "
2909 <<
" as bound and step values: " << stepValues.size();
2911 if (!arg.getType().isIndex())
2913 "expects arguments for the induction variable to be of index type");
2916 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2917 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2922 auto resultsSize = getResults().size();
2923 auto reductionsSize = reduceOp.getReductions().size();
2924 auto initValsSize = getInitVals().size();
2925 if (resultsSize != reductionsSize)
2926 return emitOpError() <<
"expects number of results: " << resultsSize
2927 <<
" to be the same as number of reductions: "
2929 if (resultsSize != initValsSize)
2930 return emitOpError() <<
"expects number of results: " << resultsSize
2931 <<
" to be the same as number of initial values: "
2935 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2936 auto resultType = getOperation()->getResult(i).getType();
2937 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2938 if (resultType != reductionOperandType)
2939 return reduceOp.emitOpError()
2940 <<
"expects type of " << i
2941 <<
"-th reduction operand: " << reductionOperandType
2942 <<
" to be the same as the " << i
2943 <<
"-th result type: " << resultType;
2991 for (
auto &iv : ivs)
2998 ParallelOp::getOperandSegmentSizeAttr(),
3000 static_cast<int32_t>(upper.size()),
3001 static_cast<int32_t>(steps.size()),
3002 static_cast<int32_t>(initVals.size())}));
3011 ParallelOp::ensureTerminator(*body, builder, result.
location);
3016 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3017 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3018 if (!getInitVals().empty())
3019 p <<
" init (" << getInitVals() <<
")";
3024 (*this)->getAttrs(),
3025 ParallelOp::getOperandSegmentSizeAttr());
3030 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3034 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3038 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3042 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3047 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3049 return ParallelOp();
3050 assert(ivArg.getOwner() &&
"unlinked block argument");
3051 auto *containingOp = ivArg.getOwner()->getParentOp();
3052 return dyn_cast<ParallelOp>(containingOp);
3057 struct ParallelOpSingleOrZeroIterationDimsFolder
3061 LogicalResult matchAndRewrite(ParallelOp op,
3068 for (
auto [lb, ub, step, iv] :
3069 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3070 op.getInductionVars())) {
3072 if (numIterations.has_value()) {
3074 if (*numIterations == 0) {
3075 rewriter.
replaceOp(op, op.getInitVals());
3080 if (*numIterations == 1) {
3085 newLowerBounds.push_back(lb);
3086 newUpperBounds.push_back(ub);
3087 newSteps.push_back(step);
3090 if (newLowerBounds.size() == op.getLowerBound().size())
3093 if (newLowerBounds.empty()) {
3097 results.reserve(op.getInitVals().size());
3098 for (
auto &bodyOp : op.getBody()->without_terminator())
3099 rewriter.
clone(bodyOp, mapping);
3100 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3101 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3102 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3103 auto initValIndex = results.size();
3104 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3108 rewriter.
clone(reduceBodyOp, mapping);
3111 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3112 results.push_back(result);
3120 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3121 newSteps, op.getInitVals(),
nullptr);
3127 newOp.getRegion().begin(), mapping);
3128 rewriter.
replaceOp(op, newOp.getResults());
3136 LogicalResult matchAndRewrite(ParallelOp op,
3138 Block &outerBody = *op.getBody();
3142 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3147 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3148 llvm::is_contained(innerOp.getUpperBound(), val) ||
3149 llvm::is_contained(innerOp.getStep(), val))
3153 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3158 Block &innerBody = *innerOp.getBody();
3159 assert(iterVals.size() ==
3167 builder.
clone(op, mapping);
3170 auto concatValues = [](
const auto &first,
const auto &second) {
3172 ret.reserve(first.size() + second.size());
3173 ret.assign(first.begin(), first.end());
3174 ret.append(second.begin(), second.end());
3178 auto newLowerBounds =
3179 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3180 auto newUpperBounds =
3181 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3182 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3185 newSteps, std::nullopt,
3196 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3205 void ParallelOp::getSuccessorRegions(
3223 for (
Value v : operands) {
3232 LogicalResult ReduceOp::verifyRegions() {
3235 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3236 auto type = getOperands()[i].getType();
3239 return emitOpError() << i <<
"-th reduction has an empty body";
3242 return arg.getType() != type;
3244 return emitOpError() <<
"expected two block arguments with type " << type
3245 <<
" in the " << i <<
"-th reduction region";
3249 return emitOpError(
"reduction bodies must be terminated with an "
3250 "'scf.reduce.return' op");
3269 Block *reductionBody = getOperation()->getBlock();
3271 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3273 if (expectedResultType != getResult().
getType())
3274 return emitOpError() <<
"must have type " << expectedResultType
3275 <<
" (the type of the reduction inputs)";
3285 ValueRange inits, BodyBuilderFn beforeBuilder,
3286 BodyBuilderFn afterBuilder) {
3294 beforeArgLocs.reserve(inits.size());
3295 for (
Value operand : inits) {
3296 beforeArgLocs.push_back(operand.getLoc());
3301 inits.getTypes(), beforeArgLocs);
3310 resultTypes, afterArgLocs);
3316 ConditionOp WhileOp::getConditionOp() {
3317 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3320 YieldOp WhileOp::getYieldOp() {
3321 return cast<YieldOp>(getAfterBody()->getTerminator());
3324 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3325 return getYieldOp().getResultsMutable();
3329 return getBeforeBody()->getArguments();
3333 return getAfterBody()->getArguments();
3337 return getBeforeArguments();
3341 assert(point == getBefore() &&
3342 "WhileOp is expected to branch only to the first region");
3350 regions.emplace_back(&getBefore(), getBefore().getArguments());
3354 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3355 "there are only two regions in a WhileOp");
3357 if (point == getAfter()) {
3358 regions.emplace_back(&getBefore(), getBefore().getArguments());
3362 regions.emplace_back(getResults());
3363 regions.emplace_back(&getAfter(), getAfter().getArguments());
3367 return {&getBefore(), &getAfter()};
3388 FunctionType functionType;
3393 result.
addTypes(functionType.getResults());
3395 if (functionType.getNumInputs() != operands.size()) {
3397 <<
"expected as many input types as operands "
3398 <<
"(expected " << operands.size() <<
" got "
3399 << functionType.getNumInputs() <<
")";
3409 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3410 regionArgs[i].type = functionType.getInput(i);
3412 return failure(parser.
parseRegion(*before, regionArgs) ||
3432 template <
typename OpTy>
3435 if (left.size() != right.size())
3436 return op.emitOpError(
"expects the same number of ") << message;
3438 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3439 if (left[i] != right[i]) {
3442 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3443 <<
" and " << right[i];
3452 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3454 "expects the 'before' region to terminate with 'scf.condition'");
3455 if (!beforeTerminator)
3458 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3460 "expects the 'after' region to terminate with 'scf.yield'");
3461 return success(afterTerminator !=
nullptr);
3487 LogicalResult matchAndRewrite(WhileOp op,
3489 auto term = op.getConditionOp();
3493 Value constantTrue =
nullptr;
3495 bool replaced =
false;
3496 for (
auto yieldedAndBlockArgs :
3497 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3498 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3499 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3501 constantTrue = rewriter.
create<arith::ConstantOp>(
3502 op.getLoc(), term.getCondition().getType(),
3511 return success(replaced);
3563 struct RemoveLoopInvariantArgsFromBeforeBlock
3567 LogicalResult matchAndRewrite(WhileOp op,
3569 Block &afterBlock = *op.getAfterBody();
3571 ConditionOp condOp = op.getConditionOp();
3576 bool canSimplify =
false;
3577 for (
const auto &it :
3579 auto index =
static_cast<unsigned>(it.index());
3580 auto [initVal, yieldOpArg] = it.value();
3583 if (yieldOpArg == initVal) {
3592 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3593 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3594 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3595 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3608 for (
const auto &it :
3610 auto index =
static_cast<unsigned>(it.index());
3611 auto [initVal, yieldOpArg] = it.value();
3615 if (yieldOpArg == initVal) {
3616 beforeBlockInitValMap.insert({index, initVal});
3624 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3625 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3626 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3627 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3628 beforeBlockInitValMap.insert({index, initVal});
3633 newInitArgs.emplace_back(initVal);
3634 newYieldOpArgs.emplace_back(yieldOpArg);
3635 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3645 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3648 &newWhile.getBefore(), {},
3651 Block &beforeBlock = *op.getBeforeBody();
3658 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3661 if (beforeBlockInitValMap.count(i) != 0)
3662 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3664 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3667 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3669 newWhile.getAfter().begin());
3671 rewriter.
replaceOp(op, newWhile.getResults());
3716 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3719 LogicalResult matchAndRewrite(WhileOp op,
3721 Block &beforeBlock = *op.getBeforeBody();
3722 ConditionOp condOp = op.getConditionOp();
3725 bool canSimplify =
false;
3726 for (
Value condOpArg : condOpArgs) {
3746 auto index =
static_cast<unsigned>(it.index());
3747 Value condOpArg = it.value();
3752 condOpInitValMap.insert({index, condOpArg});
3754 newCondOpArgs.emplace_back(condOpArg);
3755 newAfterBlockType.emplace_back(condOpArg.
getType());
3756 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3767 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3770 Block &newAfterBlock =
3772 newAfterBlockType, newAfterBlockArgLocs);
3774 Block &afterBlock = *op.getAfterBody();
3781 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3782 Value afterBlockArg, result;
3785 if (condOpInitValMap.count(i) != 0) {
3786 afterBlockArg = condOpInitValMap[i];
3787 result = afterBlockArg;
3789 afterBlockArg = newAfterBlock.getArgument(
j);
3790 result = newWhile.getResult(
j);
3793 newAfterBlockArgs[i] = afterBlockArg;
3794 newWhileResults[i] = result;
3797 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3799 newWhile.getBefore().begin());
3801 rewriter.
replaceOp(op, newWhileResults);
3835 LogicalResult matchAndRewrite(WhileOp op,
3837 auto term = op.getConditionOp();
3838 auto afterArgs = op.getAfterArguments();
3839 auto termArgs = term.getArgs();
3846 bool needUpdate =
false;
3847 for (
const auto &it :
3849 auto i =
static_cast<unsigned>(it.index());
3850 Value result = std::get<0>(it.value());
3851 Value afterArg = std::get<1>(it.value());
3852 Value termArg = std::get<2>(it.value());
3856 newResultsIndices.emplace_back(i);
3857 newTermArgs.emplace_back(termArg);
3858 newResultTypes.emplace_back(result.
getType());
3859 newArgLocs.emplace_back(result.
getLoc());
3874 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3877 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3884 newResults[it.value()] = newWhile.getResult(it.index());
3885 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3889 newWhile.getBefore().begin());
3891 Block &afterBlock = *op.getAfterBody();
3892 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3924 LogicalResult matchAndRewrite(scf::WhileOp op,
3926 using namespace scf;
3927 auto cond = op.getConditionOp();
3928 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3932 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3933 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3934 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3937 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3938 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3942 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3945 if (cmp2.getPredicate() == cmp.getPredicate())
3946 samePredicate =
true;
3947 else if (cmp2.getPredicate() ==
3949 samePredicate =
false;
3967 LogicalResult matchAndRewrite(WhileOp op,
3970 if (!llvm::any_of(op.getBeforeArguments(),
3971 [](
Value arg) { return arg.use_empty(); }))
3974 YieldOp yield = op.getYieldOp();
3979 llvm::BitVector argsToErase;
3981 size_t argsCount = op.getBeforeArguments().size();
3982 newYields.reserve(argsCount);
3983 newInits.reserve(argsCount);
3984 argsToErase.reserve(argsCount);
3985 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3986 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3987 if (beforeArg.use_empty()) {
3988 argsToErase.push_back(
true);
3990 argsToErase.push_back(
false);
3991 newYields.emplace_back(yieldValue);
3992 newInits.emplace_back(initValue);
3996 Block &beforeBlock = *op.getBeforeBody();
3997 Block &afterBlock = *op.getAfterBody();
4003 rewriter.
create<WhileOp>(loc, op.getResultTypes(), newInits,
4005 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4006 Block &newAfterBlock = *newWhileOp.getAfterBody();
4012 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4013 newBeforeBlock.getArguments());
4017 rewriter.
replaceOp(op, newWhileOp.getResults());
4026 LogicalResult matchAndRewrite(WhileOp op,
4028 ConditionOp condOp = op.getConditionOp();
4032 for (
Value arg : condOpArgs)
4033 argsSet.insert(arg);
4035 if (argsSet.size() == condOpArgs.size())
4038 llvm::SmallDenseMap<Value, unsigned> argsMap;
4040 argsMap.reserve(condOpArgs.size());
4041 newArgs.reserve(condOpArgs.size());
4042 for (
Value arg : condOpArgs) {
4043 if (!argsMap.count(arg)) {
4044 auto pos =
static_cast<unsigned>(argsMap.size());
4045 argsMap.insert({arg, pos});
4046 newArgs.emplace_back(arg);
4053 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4054 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4056 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4057 Block &newAfterBlock = *newWhileOp.getAfterBody();
4062 auto it = argsMap.find(arg);
4063 assert(it != argsMap.end());
4064 auto pos = it->second;
4065 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4066 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4074 Block &beforeBlock = *op.getBeforeBody();
4075 Block &afterBlock = *op.getAfterBody();
4077 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4078 newBeforeBlock.getArguments());
4079 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4087 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4089 if (args1.size() != args2.size())
4090 return std::nullopt;
4094 auto it = llvm::find(args2, arg1);
4095 if (it == args2.end())
4096 return std::nullopt;
4098 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4105 llvm::SmallDenseSet<Value> set;
4106 for (
Value arg : args) {
4107 if (!set.insert(arg).second)
4120 LogicalResult matchAndRewrite(WhileOp loop,
4122 auto oldBefore = loop.getBeforeBody();
4123 ConditionOp oldTerm = loop.getConditionOp();
4124 ValueRange beforeArgs = oldBefore->getArguments();
4126 if (beforeArgs == termArgs)
4129 if (hasDuplicates(termArgs))
4132 auto mapping = getArgsMapping(beforeArgs, termArgs);
4143 auto oldAfter = loop.getAfterBody();
4147 newResultTypes[
j] = loop.getResult(i).getType();
4149 auto newLoop = rewriter.
create<WhileOp>(
4150 loop.getLoc(), newResultTypes, loop.getInits(),
4152 auto newBefore = newLoop.getBeforeBody();
4153 auto newAfter = newLoop.getAfterBody();
4158 newResults[i] = newLoop.getResult(
j);
4159 newAfterArgs[i] = newAfter->getArgument(
j);
4163 newBefore->getArguments());
4175 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4176 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4177 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4178 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4192 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4195 caseValues.push_back(value);
4204 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4206 p <<
"case " << value <<
' ';
4212 if (getCases().size() != getCaseRegions().size()) {
4213 return emitOpError(
"has ")
4214 << getCaseRegions().size() <<
" case regions but "
4215 << getCases().size() <<
" case values";
4219 for (int64_t value : getCases())
4220 if (!valueSet.insert(value).second)
4221 return emitOpError(
"has duplicate case value: ") << value;
4223 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4225 return emitOpError(
"expected region to end with scf.yield, but got ")
4228 if (yield.getNumOperands() != getNumResults()) {
4229 return (emitOpError(
"expected each region to return ")
4230 << getNumResults() <<
" values, but " << name <<
" returns "
4231 << yield.getNumOperands())
4232 .attachNote(yield.getLoc())
4233 <<
"see yield operation here";
4235 for (
auto [idx, result, operand] :
4236 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4237 yield.getOperandTypes())) {
4238 if (result == operand)
4240 return (emitOpError(
"expected result #")
4241 << idx <<
" of each region to be " << result)
4242 .attachNote(yield.getLoc())
4243 << name <<
" returns " << operand <<
" here";
4248 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4251 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4257 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4259 Block &scf::IndexSwitchOp::getDefaultBlock() {
4260 return getDefaultRegion().
front();
4263 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4264 assert(idx < getNumCases() &&
"case index out-of-bounds");
4265 return getCaseRegions()[idx].front();
4268 void IndexSwitchOp::getSuccessorRegions(
4272 successors.emplace_back(getResults());
4276 llvm::copy(getRegions(), std::back_inserter(successors));
4279 void IndexSwitchOp::getEntrySuccessorRegions(
4282 FoldAdaptor adaptor(operands, *
this);
4285 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4287 llvm::copy(getRegions(), std::back_inserter(successors));
4293 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4294 if (caseValue == arg.getInt()) {
4295 successors.emplace_back(&caseRegion);
4299 successors.emplace_back(&getDefaultRegion());
4302 void IndexSwitchOp::getRegionInvocationBounds(
4304 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4305 if (!operandValue) {
4311 unsigned liveIndex = getNumRegions() - 1;
4312 const auto *it = llvm::find(getCases(), operandValue.getInt());
4313 if (it != getCases().end())
4314 liveIndex = std::distance(getCases().begin(), it);
4315 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4316 bounds.emplace_back(0, i == liveIndex);
4327 if (!maybeCst.has_value())
4329 int64_t cst = *maybeCst;
4330 int64_t caseIdx, e = op.getNumCases();
4331 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4332 if (cst == op.getCases()[caseIdx])
4336 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4337 : op.getDefaultRegion();
4338 Block &source = r.front();
4361 #define GET_OP_CLASSES
4362 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.