25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.
getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->
getParentOp()))
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : initArgs)
327 for (
Value v : initArgs)
333 if (initArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
// Loop-results accessor for ForOp: forwards `getResults()` as a
// `std::optional<ResultRange>`. This definition returns a value
// unconditionally, so the optional it produces is never empty.
396 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
509 for (
auto argOperandType :
510 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
511 Type type = std::get<2>(argOperandType);
512 std::get<0>(argOperandType).type = type;
524 ForOp::ensureTerminator(*body, builder, result.
location);
536 return getBody()->getArguments().drop_front(getNumInductionVars());
540 return getInitArgsMutable();
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
546 bool replaceInitOperandUsesInLoop,
551 auto inits = llvm::to_vector(getInitArgs());
552 inits.append(newInitOperands.begin(), newInitOperands.end());
553 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
561 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
566 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567 assert(newInitOperands.size() == newYieldedValues.size() &&
568 "expected as many new yield values as new iter operands");
570 yieldOp.getResultsMutable().append(newYieldedValues);
576 newLoop.getBody()->getArguments().take_front(
577 getBody()->getNumArguments()));
579 if (replaceInitOperandUsesInLoop) {
582 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
593 newLoop->getResults().take_front(getNumResults()));
594 return cast<LoopLikeOpInterface>(newLoop.getOperation());
598 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
601 assert(ivArg.getOwner() &&
"unlinked block argument");
602 auto *containingOp = ivArg.getOwner()->getParentOp();
603 return dyn_cast_or_null<ForOp>(containingOp);
607 return getInitArgs();
624 for (
auto [lb, ub, step] :
625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
627 if (!tripCount.has_value() || *tripCount != 1)
636 return getBody()->getArguments().drop_front(getRank());
640 return getOutputsMutable();
646 scf::InParallelOp terminator = forallOp.getTerminator();
651 bbArgReplacements.append(forallOp.getOutputs().begin(),
652 forallOp.getOutputs().end());
656 forallOp->getIterator(), bbArgReplacements);
661 results.reserve(forallOp.getResults().size());
662 for (
auto &yieldingOp : terminator.getYieldingOps()) {
663 auto parallelInsertSliceOp =
664 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
666 Value dst = parallelInsertSliceOp.getDest();
667 Value src = parallelInsertSliceOp.getSource();
668 if (llvm::isa<TensorType>(src.
getType())) {
669 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
670 forallOp.getLoc(), dst.
getType(), src, dst,
671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672 parallelInsertSliceOp.getStrides(),
673 parallelInsertSliceOp.getStaticOffsets(),
674 parallelInsertSliceOp.getStaticSizes(),
675 parallelInsertSliceOp.getStaticStrides()));
677 llvm_unreachable(
"unsupported terminator");
692 assert(lbs.size() == ubs.size() &&
693 "expected the same number of lower and upper bounds");
694 assert(lbs.size() == steps.size() &&
695 "expected the same number of lower bounds and steps");
700 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
702 assert(results.size() == iterArgs.size() &&
703 "loop nest body must return as many values as loop has iteration "
705 return LoopNest{{}, std::move(results)};
713 loops.reserve(lbs.size());
714 ivs.reserve(lbs.size());
717 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
718 auto loop = builder.
create<scf::ForOp>(
719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725 currentIterArgs = args;
726 currentLoc = nestedLoc;
732 loops.push_back(loop);
736 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
738 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
745 ? bodyBuilder(builder, currentLoc, ivs,
746 loops.back().getRegionIterArgs())
748 assert(results.size() == iterArgs.size() &&
749 "loop nest body must return as many values as loop has iteration "
752 builder.
create<scf::YieldOp>(loc, results);
756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757 return LoopNest{std::move(loops), std::move(nestResults)};
765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
770 bodyBuilder(nestedBuilder, nestedLoc, ivs);
779 assert(operand.
getOwner() == forOp);
784 "expected an iter OpOperand");
786 "Expected a different type");
788 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
790 newIterOperands.push_back(replacement);
793 newIterOperands.push_back(opOperand.get());
797 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
798 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799 forOp.getStep(), newIterOperands);
800 newForOp->
setAttrs(forOp->getAttrs());
801 Block &newBlock = newForOp.getRegion().
front();
809 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
811 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
816 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
819 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
822 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
823 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824 clonedYieldOp.getOperand(yieldIdx));
826 newYieldOperands[yieldIdx] = castOut;
827 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828 rewriter.
eraseOp(clonedYieldOp);
833 newResults[yieldIdx] =
834 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
856 LogicalResult matchAndRewrite(scf::ForOp forOp,
858 bool canonicalize =
false;
865 int64_t numResults = forOp.getNumResults();
867 keepMask.reserve(numResults);
870 newBlockTransferArgs.reserve(1 + numResults);
871 newBlockTransferArgs.push_back(
Value());
872 newIterArgs.reserve(forOp.getInitArgs().size());
873 newYieldValues.reserve(numResults);
874 newResultValues.reserve(numResults);
875 for (
auto it : llvm::zip(forOp.getInitArgs(),
876 forOp.getRegionIterArgs(),
878 forOp.getYieldedValues()
886 bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
887 (std::get<1>(it).use_empty() &&
888 (std::get<0>(it) == std::get<3>(it) ||
889 std::get<2>(it).use_empty())));
890 keepMask.push_back(!forwarded);
891 canonicalize |= forwarded;
893 newBlockTransferArgs.push_back(std::get<0>(it));
894 newResultValues.push_back(std::get<0>(it));
897 newIterArgs.push_back(std::get<0>(it));
898 newYieldValues.push_back(std::get<3>(it));
899 newBlockTransferArgs.push_back(
Value());
900 newResultValues.push_back(
Value());
906 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
907 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
908 forOp.getStep(), newIterArgs);
909 newForOp->
setAttrs(forOp->getAttrs());
910 Block &newBlock = newForOp.getRegion().
front();
914 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
916 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
917 Value &newResultVal = newResultValues[idx];
918 assert((blockTransferArg && newResultVal) ||
919 (!blockTransferArg && !newResultVal));
920 if (!blockTransferArg) {
921 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
922 newResultVal = newForOp.getResult(collapsedIdx++);
928 "unexpected argument size mismatch");
933 if (newIterArgs.empty()) {
934 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
937 rewriter.
replaceOp(forOp, newResultValues);
942 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
946 filteredOperands.reserve(newResultValues.size());
947 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
949 filteredOperands.push_back(mergedTerminator.getOperand(idx));
950 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
954 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
955 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
956 cloneFilteredTerminator(mergedYieldOp);
957 rewriter.
eraseOp(mergedYieldOp);
958 rewriter.
replaceOp(forOp, newResultValues);
966 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
967 IntegerAttr clb, cub;
969 llvm::APInt lbValue = clb.getValue();
970 llvm::APInt ubValue = cub.getValue();
971 return (ubValue - lbValue).getSExtValue();
980 return diff.getSExtValue();
990 LogicalResult matchAndRewrite(ForOp op,
994 if (op.getLowerBound() == op.getUpperBound()) {
995 rewriter.
replaceOp(op, op.getInitArgs());
999 std::optional<int64_t> diff =
1000 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1006 rewriter.
replaceOp(op, op.getInitArgs());
1010 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1011 if (!maybeStepValue)
1016 llvm::APInt stepValue = *maybeStepValue;
1017 if (stepValue.sge(*diff)) {
1019 blockArgs.reserve(op.getInitArgs().size() + 1);
1020 blockArgs.push_back(op.getLowerBound());
1021 llvm::append_range(blockArgs, op.getInitArgs());
1028 if (!llvm::hasSingleElement(block))
1032 if (llvm::any_of(op.getYieldedValues(),
1033 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1035 rewriter.
replaceOp(op, op.getYieldedValues());
1069 LogicalResult matchAndRewrite(ForOp op,
1071 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.
getResults())) {
1072 OpOperand &iterOpOperand = std::get<0>(it);
1074 if (!incomingCast ||
1075 incomingCast.getSource().getType() == incomingCast.getType())
1080 incomingCast.getDest().getType(),
1081 incomingCast.getSource().getType()))
1083 if (!std::get<1>(it).hasOneUse())
1089 return b.
create<tensor::CastOp>(loc, type, source);
1093 incomingCast.getSource(), castFn));
1104 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1108 std::optional<APInt> ForOp::getConstantStep() {
1111 return step.getValue();
1115 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1116 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1122 if (
auto constantStep = getConstantStep())
1123 if (*constantStep == 1)
1136 unsigned numLoops = getRank();
1138 if (getNumResults() != getOutputs().size())
1139 return emitOpError(
"produces ")
1140 << getNumResults() <<
" results, but has only "
1141 << getOutputs().size() <<
" outputs";
1144 auto *body = getBody();
1146 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1147 for (int64_t i = 0; i < numLoops; ++i)
1149 return emitOpError(
"expects ")
1150 << i <<
"-th block argument to be an index";
1151 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1153 return emitOpError(
"type mismatch between ")
1154 << i <<
"-th output and corresponding block argument";
1155 if (getMapping().has_value() && !getMapping()->empty()) {
1156 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1157 return emitOpError() <<
"mapping attribute size must match op rank";
1158 for (
auto map : getMapping()->getValue()) {
1159 if (!isa<DeviceMappingAttrInterface>(map))
1160 return emitOpError()
1168 getStaticLowerBound(),
1169 getDynamicLowerBound())))
1172 getStaticUpperBound(),
1173 getDynamicUpperBound())))
1176 getStaticStep(), getDynamicStep())))
1184 p <<
" (" << getInductionVars();
1185 if (isNormalized()) {
1206 if (!getRegionOutArgs().empty())
1207 p <<
"-> (" << getResultTypes() <<
") ";
1208 p.printRegion(getRegion(),
1210 getNumResults() > 0);
1211 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1212 getStaticLowerBoundAttrName(),
1213 getStaticUpperBoundAttrName(),
1214 getStaticStepAttrName()});
1219 auto indexType = b.getIndexType();
1239 unsigned numLoops = ivs.size();
1274 if (outOperands.size() != result.
types.size())
1276 "mismatch between out operands and types");
1286 std::unique_ptr<Region> region = std::make_unique<Region>();
1287 for (
auto &iv : ivs) {
1288 iv.type = b.getIndexType();
1289 regionArgs.push_back(iv);
1292 auto &out = it.value();
1293 out.type = result.
types[it.index()];
1294 regionArgs.push_back(out);
1300 ForallOp::ensureTerminator(*region, b, result.
location);
1312 {static_cast<int32_t>(dynamicLbs.size()),
1313 static_cast<int32_t>(dynamicUbs.size()),
1314 static_cast<int32_t>(dynamicSteps.size()),
1315 static_cast<int32_t>(outOperands.size())}));
1320 void ForallOp::build(
1324 std::optional<ArrayAttr> mapping,
1345 "operandSegmentSizes",
1347 static_cast<int32_t>(dynamicUbs.size()),
1348 static_cast<int32_t>(dynamicSteps.size()),
1349 static_cast<int32_t>(outputs.size())}));
1350 if (mapping.has_value()) {
1369 if (!bodyBuilderFn) {
1370 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1377 void ForallOp::build(
1380 std::optional<ArrayAttr> mapping,
1382 unsigned numLoops = ubs.size();
1385 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1389 bool ForallOp::isNormalized() {
1393 return intValue.has_value() && intValue == val;
1396 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1405 ForallOp>::ensureTerminator(region, builder, loc);
1412 InParallelOp ForallOp::getTerminator() {
1413 return cast<InParallelOp>(getBody()->getTerminator());
1418 InParallelOp inParallelOp = getTerminator();
1419 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1420 if (
auto parallelInsertSliceOp =
1421 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1422 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1423 storeOps.push_back(parallelInsertSliceOp);
1429 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1434 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1436 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1440 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1442 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1446 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1452 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1455 assert(tidxArg.getOwner() &&
"unlinked block argument");
1456 auto *containingOp = tidxArg.getOwner()->getParentOp();
1457 return dyn_cast<ForallOp>(containingOp);
1465 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1467 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1471 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1474 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1483 LogicalResult matchAndRewrite(ForallOp op,
1498 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1499 op.setStaticLowerBound(staticLowerBound);
1503 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1504 op.setStaticUpperBound(staticUpperBound);
1507 op.getDynamicStepMutable().assign(dynamicStep);
1508 op.setStaticStep(staticStep);
1510 op->
setAttr(ForallOp::getOperandSegmentSizeAttr(),
1512 {static_cast<int32_t>(dynamicLowerBound.size()),
1513 static_cast<int32_t>(dynamicUpperBound.size()),
1514 static_cast<int32_t>(dynamicStep.size()),
1515 static_cast<int32_t>(op.getNumResults())}));
1597 LogicalResult matchAndRewrite(ForallOp forallOp,
1616 for (
OpResult result : forallOp.getResults()) {
1617 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1618 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1619 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1620 resultToDelete.insert(result);
1622 resultToReplace.push_back(result);
1623 newOuts.push_back(opOperand->
get());
1629 if (resultToDelete.empty())
1637 for (
OpResult result : resultToDelete) {
1638 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1639 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1641 forallOp.getCombiningOps(blockArg);
1642 for (
Operation *combiningOp : combiningOps)
1643 rewriter.
eraseOp(combiningOp);
1648 auto newForallOp = rewriter.
create<scf::ForallOp>(
1649 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1650 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1651 forallOp.getMapping(),
1656 Block *loopBody = forallOp.getBody();
1657 Block *newLoopBody = newForallOp.getBody();
1662 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1669 for (
OpResult result : forallOp.getResults()) {
1670 if (resultToDelete.count(result)) {
1671 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1673 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1676 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1680 for (
auto &&[oldResult, newResult] :
1681 llvm::zip(resultToReplace, newForallOp->getResults()))
1687 for (
OpResult oldResult : resultToDelete)
1689 forallOp.getTiedOpOperand(oldResult)->get());
1694 struct ForallOpSingleOrZeroIterationDimsFolder
1698 LogicalResult matchAndRewrite(ForallOp op,
1701 if (op.getMapping().has_value() && !op.getMapping()->empty())
1709 for (
auto [lb, ub, step, iv] :
1710 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1711 op.getMixedStep(), op.getInductionVars())) {
1713 if (numIterations.has_value()) {
1715 if (*numIterations == 0) {
1716 rewriter.
replaceOp(op, op.getOutputs());
1721 if (*numIterations == 1) {
1726 newMixedLowerBounds.push_back(lb);
1727 newMixedUpperBounds.push_back(ub);
1728 newMixedSteps.push_back(step);
1732 if (newMixedLowerBounds.empty()) {
1738 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1740 op,
"no dimensions have 0 or 1 iterations");
1745 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1746 newMixedUpperBounds, newMixedSteps,
1747 op.getOutputs(), std::nullopt,
nullptr);
1748 newOp.getBodyRegion().getBlocks().clear();
1753 newOp.getStaticLowerBoundAttrName(),
1754 newOp.getStaticUpperBoundAttrName(),
1755 newOp.getStaticStepAttrName()};
1756 for (
const auto &namedAttr : op->
getAttrs()) {
1757 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1760 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1764 newOp.getRegion().
begin(), mapping);
1765 rewriter.
replaceOp(op, newOp.getResults());
1770 struct FoldTensorCastOfOutputIntoForallOp
1779 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1781 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1784 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1791 castOp.getSource().getType())) {
1795 tensorCastProducers[en.index()] =
1796 TypeCast{castOp.getSource().getType(), castOp.getType()};
1797 newOutputTensors[en.index()] = castOp.getSource();
1800 if (tensorCastProducers.empty())
1805 auto newForallOp = rewriter.
create<ForallOp>(
1806 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1807 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1809 auto castBlockArgs =
1810 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1811 for (
auto [index, cast] : tensorCastProducers) {
1812 Value &oldTypeBBArg = castBlockArgs[index];
1813 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1814 nestedLoc, cast.dstType, oldTypeBBArg);
1819 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1820 ivsBlockArgs.append(castBlockArgs);
1822 bbArgs.front().getParentBlock(), ivsBlockArgs);
1828 auto terminator = newForallOp.getTerminator();
1829 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1830 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1831 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1832 insertSliceOp.getDestMutable().assign(outputBlockArg);
1838 for (
auto &item : tensorCastProducers) {
1839 Value &oldTypeResult = castResults[item.first];
1840 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1843 rewriter.
replaceOp(forallOp, castResults);
1852 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1853 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1854 ForallOpSingleOrZeroIterationDimsFolder>(context);
1883 scf::ForallOp forallOp =
1884 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1886 return this->emitOpError(
"expected forall op parent");
1889 for (
Operation &op : getRegion().front().getOperations()) {
1890 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1891 return this->emitOpError(
"expected only ")
1892 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1896 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1898 if (!llvm::is_contained(regionOutArgs, dest))
1899 return op.
emitOpError(
"may only insert into an output block argument");
1916 std::unique_ptr<Region> region = std::make_unique<Region>();
1920 if (region->empty())
1930 OpResult InParallelOp::getParentResult(int64_t idx) {
1931 return getOperation()->getParentOp()->getResult(idx);
1935 return llvm::to_vector<4>(
1936 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1938 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1939 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1944 return getRegion().front().getOperations();
1952 assert(a &&
"expected non-empty operation");
1953 assert(b &&
"expected non-empty operation");
1958 if (ifOp->isProperAncestor(b))
1961 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1962 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1964 ifOp = ifOp->getParentOfType<IfOp>();
1972 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1973 IfOp::Adaptor adaptor,
1975 if (adaptor.getRegions().empty())
1977 Region *r = &adaptor.getThenRegion();
1980 Block &b = r->front();
1983 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
1986 TypeRange types = yieldOp.getOperandTypes();
1987 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1994 return build(builder, result, resultTypes, cond,
false,
2000 bool addElseBlock) {
2001 assert((!addElseBlock || addThenBlock) &&
2002 "must not create else block w/o then block");
2017 bool withElseRegion) {
2018 build(builder, result,
TypeRange{}, cond, withElseRegion);
2030 if (resultTypes.empty())
2031 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2035 if (withElseRegion) {
2037 if (resultTypes.empty())
2038 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2045 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2052 thenBuilder(builder, result.
location);
2058 elseBuilder(builder, result.
location);
2065 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2067 inferredReturnTypes))) {
2068 result.
addTypes(inferredReturnTypes);
2073 if (getNumResults() != 0 && getElseRegion().empty())
2074 return emitOpError(
"must have an else block if defining values");
2112 bool printBlockTerminators =
false;
2114 p <<
" " << getCondition();
2115 if (!getResults().empty()) {
2116 p <<
" -> (" << getResultTypes() <<
")";
2118 printBlockTerminators =
true;
2123 printBlockTerminators);
2126 auto &elseRegion = getElseRegion();
2127 if (!elseRegion.
empty()) {
2131 printBlockTerminators);
2148 Region *elseRegion = &this->getElseRegion();
2149 if (elseRegion->
empty())
2157 FoldAdaptor adaptor(operands, *
this);
2158 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2159 if (!boolAttr || boolAttr.getValue())
2160 regions.emplace_back(&getThenRegion());
2163 if (!boolAttr || !boolAttr.getValue()) {
2164 if (!getElseRegion().empty())
2165 regions.emplace_back(&getElseRegion());
2167 regions.emplace_back(getResults());
2171 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2174 if (getElseRegion().empty())
2177 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2184 getConditionMutable().assign(xorStmt.getLhs());
2188 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2189 getElseRegion().getBlocks());
2190 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2191 getThenRegion().getBlocks(), thenBlock);
2195 void IfOp::getRegionInvocationBounds(
2198 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2201 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2202 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2205 invocationBounds.assign(2, {0, 1});
2221 llvm::transform(usedResults, std::back_inserter(usedOperands),
2226 [&]() { yieldOp->setOperands(usedOperands); });
2229 LogicalResult matchAndRewrite(IfOp op,
2233 llvm::copy_if(op.
getResults(), std::back_inserter(usedResults),
2234 [](
OpResult result) { return !result.use_empty(); });
2242 llvm::transform(usedResults, std::back_inserter(newTypes),
2247 rewriter.
create<IfOp>(op.
getLoc(), newTypes, op.getCondition());
2253 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2254 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2259 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2268 LogicalResult matchAndRewrite(IfOp op,
2276 else if (!op.getElseRegion().empty())
2290 LogicalResult matchAndRewrite(IfOp op,
2295 auto cond = op.getCondition();
2296 auto thenYieldArgs = op.thenYield().
getOperands();
2297 auto elseYieldArgs = op.elseYield().
getOperands();
2300 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2303 nonHoistable.push_back(trueVal.getType());
2310 IfOp replacement = rewriter.
create<IfOp>(op.
getLoc(), nonHoistable, cond,
2312 if (replacement.thenBlock())
2313 rewriter.
eraseBlock(replacement.thenBlock());
2314 replacement.getThenRegion().takeBody(op.getThenRegion());
2315 replacement.getElseRegion().takeBody(op.getElseRegion());
2318 assert(thenYieldArgs.size() == results.size());
2319 assert(elseYieldArgs.size() == results.size());
2324 for (
const auto &it :
2326 Value trueVal = std::get<0>(it.value());
2327 Value falseVal = std::get<1>(it.value());
2330 results[it.index()] = replacement.getResult(trueYields.size());
2331 trueYields.push_back(trueVal);
2332 falseYields.push_back(falseVal);
2333 }
else if (trueVal == falseVal)
2334 results[it.index()] = trueVal;
2336 results[it.index()] = rewriter.
create<arith::SelectOp>(
2337 op.
getLoc(), cond, trueVal, falseVal);
2367 LogicalResult matchAndRewrite(IfOp op,
2374 bool changed =
false;
2379 Value constantTrue =
nullptr;
2380 Value constantFalse =
nullptr;
2383 llvm::make_early_inc_range(op.getCondition().
getUses())) {
2388 constantTrue = rewriter.
create<arith::ConstantOp>(
2392 [&]() { use.
set(constantTrue); });
2398 constantFalse = rewriter.
create<arith::ConstantOp>(
2402 [&]() { use.
set(constantFalse); });
2406 return success(changed);
2446 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2449 LogicalResult matchAndRewrite(IfOp op,
2456 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2458 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2461 op.getOperation()->getIterator());
2462 bool changed =
false;
2464 for (
auto [trueResult, falseResult, opResult] :
2465 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2467 if (trueResult == falseResult) {
2468 if (!opResult.use_empty()) {
2469 opResult.replaceAllUsesWith(trueResult);
2480 bool trueVal = trueYield.
getValue();
2481 bool falseVal = falseYield.
getValue();
2482 if (!trueVal && falseVal) {
2483 if (!opResult.use_empty()) {
2484 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2486 op.
getLoc(), op.getCondition(),
2496 if (trueVal && !falseVal) {
2497 if (!opResult.use_empty()) {
2498 opResult.replaceAllUsesWith(op.getCondition());
2503 return success(changed);
2531 LogicalResult matchAndRewrite(IfOp nextIf,
2533 Block *parent = nextIf->getBlock();
2534 if (nextIf == &parent->
front())
2537 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2545 Block *nextThen =
nullptr;
2546 Block *nextElse =
nullptr;
2547 if (nextIf.getCondition() == prevIf.getCondition()) {
2548 nextThen = nextIf.thenBlock();
2549 if (!nextIf.getElseRegion().empty())
2550 nextElse = nextIf.elseBlock();
2552 if (arith::XOrIOp notv =
2553 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2554 if (notv.getLhs() == prevIf.getCondition() &&
2556 nextElse = nextIf.thenBlock();
2557 if (!nextIf.getElseRegion().empty())
2558 nextThen = nextIf.elseBlock();
2561 if (arith::XOrIOp notv =
2562 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2563 if (notv.getLhs() == nextIf.getCondition() &&
2565 nextElse = nextIf.thenBlock();
2566 if (!nextIf.getElseRegion().empty())
2567 nextThen = nextIf.elseBlock();
2571 if (!nextThen && !nextElse)
2575 if (!prevIf.getElseRegion().empty())
2576 prevElseYielded = prevIf.elseYield().getOperands();
2579 for (
auto it : llvm::zip(prevIf.getResults(),
2580 prevIf.thenYield().getOperands(), prevElseYielded))
2582 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2586 use.
set(std::get<1>(it));
2591 use.
set(std::get<2>(it));
2597 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2599 IfOp combinedIf = rewriter.
create<IfOp>(
2600 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2601 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2604 combinedIf.getThenRegion(),
2605 combinedIf.getThenRegion().begin());
2608 YieldOp thenYield = combinedIf.thenYield();
2609 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2610 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2614 llvm::append_range(mergedYields, thenYield2.getOperands());
2615 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2621 combinedIf.getElseRegion(),
2622 combinedIf.getElseRegion().begin());
2625 if (combinedIf.getElseRegion().empty()) {
2627 combinedIf.getElseRegion(),
2628 combinedIf.getElseRegion().
begin());
2630 YieldOp elseYield = combinedIf.elseYield();
2631 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2632 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2637 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2639 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2648 if (pair.index() < prevIf.getNumResults())
2649 prevValues.push_back(pair.value());
2651 nextValues.push_back(pair.value());
2663 LogicalResult matchAndRewrite(IfOp ifOp,
2666 if (ifOp.getNumResults())
2668 Block *elseBlock = ifOp.elseBlock();
2669 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2673 newIfOp.getThenRegion().begin());
2698 LogicalResult matchAndRewrite(IfOp op,
2700 auto nestedOps = op.thenBlock()->without_terminator();
2702 if (!llvm::hasSingleElement(nestedOps))
2706 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2709 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2713 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2719 llvm::append_range(elseYield, op.elseYield().
getOperands());
2733 if (tup.value().getDefiningOp() == nestedIf) {
2734 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2735 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2736 elseYield[tup.index()]) {
2741 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2754 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2757 elseYieldsToUpgradeToSelect.push_back(tup.index());
2761 Value newCondition = rewriter.
create<arith::AndIOp>(
2762 loc, op.getCondition(), nestedIf.getCondition());
2767 llvm::append_range(results, newIf.getResults());
2770 for (
auto idx : elseYieldsToUpgradeToSelect)
2771 results[idx] = rewriter.
create<arith::SelectOp>(
2772 op.
getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2774 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2777 if (!elseYield.empty()) {
2780 rewriter.
create<YieldOp>(loc, elseYield);
2791 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2792 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2793 RemoveStaticCondition, RemoveUnusedResults,
2794 ReplaceIfYieldWithConditionOrValue>(context);
2797 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2798 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2799 Block *IfOp::elseBlock() {
2800 Region &r = getElseRegion();
2805 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2811 void ParallelOp::build(
2821 ParallelOp::getOperandSegmentSizeAttr(),
2823 static_cast<int32_t>(upperBounds.size()),
2824 static_cast<int32_t>(steps.size()),
2825 static_cast<int32_t>(initVals.size())}));
2829 unsigned numIVs = steps.size();
2835 if (bodyBuilderFn) {
2837 bodyBuilderFn(builder, result.
location,
2842 if (initVals.empty())
2843 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2846 void ParallelOp::build(
2853 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2856 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2860 wrapper = wrappedBuilderFn;
2862 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2871 if (stepValues.empty())
2873 "needs at least one tuple element for lowerBound, upperBound and step");
2876 for (
Value stepValue : stepValues)
2879 return emitOpError(
"constant step operand must be positive");
2883 Block *body = getBody();
2885 return emitOpError() <<
"expects the same number of induction variables: "
2887 <<
" as bound and step values: " << stepValues.size();
2889 if (!arg.getType().isIndex())
2891 "expects arguments for the induction variable to be of index type");
2894 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2895 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2900 auto resultsSize = getResults().size();
2901 auto reductionsSize = reduceOp.getReductions().size();
2902 auto initValsSize = getInitVals().size();
2903 if (resultsSize != reductionsSize)
2904 return emitOpError() <<
"expects number of results: " << resultsSize
2905 <<
" to be the same as number of reductions: "
2907 if (resultsSize != initValsSize)
2908 return emitOpError() <<
"expects number of results: " << resultsSize
2909 <<
" to be the same as number of initial values: "
2913 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2914 auto resultType = getOperation()->getResult(i).getType();
2915 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2916 if (resultType != reductionOperandType)
2917 return reduceOp.emitOpError()
2918 <<
"expects type of " << i
2919 <<
"-th reduction operand: " << reductionOperandType
2920 <<
" to be the same as the " << i
2921 <<
"-th result type: " << resultType;
2969 for (
auto &iv : ivs)
2976 ParallelOp::getOperandSegmentSizeAttr(),
2978 static_cast<int32_t>(upper.size()),
2979 static_cast<int32_t>(steps.size()),
2980 static_cast<int32_t>(initVals.size())}));
2989 ParallelOp::ensureTerminator(*body, builder, result.
location);
2994 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2995 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2996 if (!getInitVals().empty())
2997 p <<
" init (" << getInitVals() <<
")";
3002 (*this)->getAttrs(),
3003 ParallelOp::getOperandSegmentSizeAttr());
3008 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3012 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3016 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3020 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3025 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3027 return ParallelOp();
3028 assert(ivArg.getOwner() &&
"unlinked block argument");
3029 auto *containingOp = ivArg.getOwner()->getParentOp();
3030 return dyn_cast<ParallelOp>(containingOp);
3035 struct ParallelOpSingleOrZeroIterationDimsFolder
3039 LogicalResult matchAndRewrite(ParallelOp op,
3046 for (
auto [lb, ub, step, iv] :
3047 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3048 op.getInductionVars())) {
3050 if (numIterations.has_value()) {
3052 if (*numIterations == 0) {
3053 rewriter.
replaceOp(op, op.getInitVals());
3058 if (*numIterations == 1) {
3063 newLowerBounds.push_back(lb);
3064 newUpperBounds.push_back(ub);
3065 newSteps.push_back(step);
3068 if (newLowerBounds.size() == op.getLowerBound().size())
3071 if (newLowerBounds.empty()) {
3075 results.reserve(op.getInitVals().size());
3076 for (
auto &bodyOp : op.getBody()->without_terminator())
3077 rewriter.
clone(bodyOp, mapping);
3078 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3079 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3080 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3081 auto initValIndex = results.size();
3082 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3086 rewriter.
clone(reduceBodyOp, mapping);
3089 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3090 results.push_back(result);
3098 rewriter.
create<ParallelOp>(op.
getLoc(), newLowerBounds, newUpperBounds,
3099 newSteps, op.getInitVals(),
nullptr);
3105 newOp.getRegion().
begin(), mapping);
3106 rewriter.
replaceOp(op, newOp.getResults());
3114 LogicalResult matchAndRewrite(ParallelOp op,
3116 Block &outerBody = *op.getBody();
3120 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3125 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3126 llvm::is_contained(innerOp.getUpperBound(), val) ||
3127 llvm::is_contained(innerOp.getStep(), val))
3131 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3136 Block &innerBody = *innerOp.getBody();
3137 assert(iterVals.size() ==
3145 builder.
clone(op, mapping);
3148 auto concatValues = [](
const auto &first,
const auto &second) {
3150 ret.reserve(first.size() + second.size());
3151 ret.assign(first.begin(), first.end());
3152 ret.append(second.begin(), second.end());
3156 auto newLowerBounds =
3157 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3158 auto newUpperBounds =
3159 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3160 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3163 newSteps, std::nullopt,
3174 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3183 void ParallelOp::getSuccessorRegions(
3201 for (
Value v : operands) {
3210 LogicalResult ReduceOp::verifyRegions() {
3213 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3214 auto type = getOperands()[i].getType();
3217 return emitOpError() << i <<
"-th reduction has an empty body";
3220 return arg.getType() != type;
3222 return emitOpError() <<
"expected two block arguments with type " << type
3223 <<
" in the " << i <<
"-th reduction region";
3227 return emitOpError(
"reduction bodies must be terminated with an "
3228 "'scf.reduce.return' op");
3247 Block *reductionBody = getOperation()->getBlock();
3249 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3251 if (expectedResultType != getResult().
getType())
3252 return emitOpError() <<
"must have type " << expectedResultType
3253 <<
" (the type of the reduction inputs)";
3263 ValueRange inits, BodyBuilderFn beforeBuilder,
3264 BodyBuilderFn afterBuilder) {
3272 beforeArgLocs.reserve(inits.size());
3273 for (
Value operand : inits) {
3274 beforeArgLocs.push_back(operand.getLoc());
3279 inits.getTypes(), beforeArgLocs);
3288 resultTypes, afterArgLocs);
3294 ConditionOp WhileOp::getConditionOp() {
3295 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3298 YieldOp WhileOp::getYieldOp() {
3299 return cast<YieldOp>(getAfterBody()->getTerminator());
3302 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3303 return getYieldOp().getResultsMutable();
3307 return getBeforeBody()->getArguments();
3311 return getAfterBody()->getArguments();
3315 return getBeforeArguments();
3319 assert(point == getBefore() &&
3320 "WhileOp is expected to branch only to the first region");
3328 regions.emplace_back(&getBefore(), getBefore().getArguments());
3332 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3333 "there are only two regions in a WhileOp");
3335 if (point == getAfter()) {
3336 regions.emplace_back(&getBefore(), getBefore().getArguments());
3340 regions.emplace_back(getResults());
3341 regions.emplace_back(&getAfter(), getAfter().getArguments());
3345 return {&getBefore(), &getAfter()};
3366 FunctionType functionType;
3371 result.
addTypes(functionType.getResults());
3373 if (functionType.getNumInputs() != operands.size()) {
3375 <<
"expected as many input types as operands "
3376 <<
"(expected " << operands.size() <<
" got "
3377 << functionType.getNumInputs() <<
")";
3387 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3388 regionArgs[i].type = functionType.getInput(i);
3390 return failure(parser.
parseRegion(*before, regionArgs) ||
3410 template <
typename OpTy>
3413 if (left.size() != right.size())
3414 return op.
emitOpError(
"expects the same number of ") << message;
3416 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3417 if (left[i] != right[i]) {
3420 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3421 <<
" and " << right[i];
3430 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3432 "expects the 'before' region to terminate with 'scf.condition'");
3433 if (!beforeTerminator)
3436 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3438 "expects the 'after' region to terminate with 'scf.yield'");
3439 return success(afterTerminator !=
nullptr);
3465 LogicalResult matchAndRewrite(WhileOp op,
3467 auto term = op.getConditionOp();
3471 Value constantTrue =
nullptr;
3473 bool replaced =
false;
3474 for (
auto yieldedAndBlockArgs :
3475 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3476 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3477 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3479 constantTrue = rewriter.
create<arith::ConstantOp>(
3480 op.
getLoc(), term.getCondition().getType(),
3489 return success(replaced);
3541 struct RemoveLoopInvariantArgsFromBeforeBlock
3545 LogicalResult matchAndRewrite(WhileOp op,
3547 Block &afterBlock = *op.getAfterBody();
3549 ConditionOp condOp = op.getConditionOp();
3554 bool canSimplify =
false;
3555 for (
const auto &it :
3557 auto index =
static_cast<unsigned>(it.index());
3558 auto [initVal, yieldOpArg] = it.value();
3561 if (yieldOpArg == initVal) {
3570 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3571 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3572 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3573 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3586 for (
const auto &it :
3588 auto index =
static_cast<unsigned>(it.index());
3589 auto [initVal, yieldOpArg] = it.value();
3593 if (yieldOpArg == initVal) {
3594 beforeBlockInitValMap.insert({index, initVal});
3602 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3603 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3604 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3605 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3606 beforeBlockInitValMap.insert({index, initVal});
3611 newInitArgs.emplace_back(initVal);
3612 newYieldOpArgs.emplace_back(yieldOpArg);
3613 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3626 &newWhile.getBefore(), {},
3629 Block &beforeBlock = *op.getBeforeBody();
3636 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3639 if (beforeBlockInitValMap.count(i) != 0)
3640 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3642 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3645 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3647 newWhile.getAfter().begin());
3649 rewriter.
replaceOp(op, newWhile.getResults());
3694 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3697 LogicalResult matchAndRewrite(WhileOp op,
3699 Block &beforeBlock = *op.getBeforeBody();
3700 ConditionOp condOp = op.getConditionOp();
3703 bool canSimplify =
false;
3704 for (
Value condOpArg : condOpArgs) {
3724 auto index =
static_cast<unsigned>(it.index());
3725 Value condOpArg = it.value();
3730 condOpInitValMap.insert({index, condOpArg});
3732 newCondOpArgs.emplace_back(condOpArg);
3733 newAfterBlockType.emplace_back(condOpArg.
getType());
3734 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3745 auto newWhile = rewriter.
create<WhileOp>(op.
getLoc(), newAfterBlockType,
3748 Block &newAfterBlock =
3750 newAfterBlockType, newAfterBlockArgLocs);
3752 Block &afterBlock = *op.getAfterBody();
3759 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3760 Value afterBlockArg, result;
3763 if (condOpInitValMap.count(i) != 0) {
3764 afterBlockArg = condOpInitValMap[i];
3765 result = afterBlockArg;
3767 afterBlockArg = newAfterBlock.getArgument(
j);
3768 result = newWhile.getResult(
j);
3771 newAfterBlockArgs[i] = afterBlockArg;
3772 newWhileResults[i] = result;
3775 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3777 newWhile.getBefore().begin());
3779 rewriter.
replaceOp(op, newWhileResults);
3813 LogicalResult matchAndRewrite(WhileOp op,
3815 auto term = op.getConditionOp();
3816 auto afterArgs = op.getAfterArguments();
3817 auto termArgs = term.getArgs();
3824 bool needUpdate =
false;
3825 for (
const auto &it :
3827 auto i =
static_cast<unsigned>(it.index());
3828 Value result = std::get<0>(it.value());
3829 Value afterArg = std::get<1>(it.value());
3830 Value termArg = std::get<2>(it.value());
3834 newResultsIndices.emplace_back(i);
3835 newTermArgs.emplace_back(termArg);
3836 newResultTypes.emplace_back(result.
getType());
3837 newArgLocs.emplace_back(result.
getLoc());
3852 rewriter.
create<WhileOp>(op.
getLoc(), newResultTypes, op.getInits());
3855 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3862 newResults[it.value()] = newWhile.getResult(it.index());
3863 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3867 newWhile.getBefore().begin());
3869 Block &afterBlock = *op.getAfterBody();
3870 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3902 LogicalResult matchAndRewrite(scf::WhileOp op,
3904 using namespace scf;
3905 auto cond = op.getConditionOp();
3906 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3909 bool changed =
false;
3910 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3911 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3912 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3915 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3916 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3920 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3923 if (cmp2.getPredicate() == cmp.getPredicate())
3924 samePredicate =
true;
3925 else if (cmp2.getPredicate() ==
3927 samePredicate =
false;
3937 return success(changed);
3945 LogicalResult matchAndRewrite(WhileOp op,
3948 if (!llvm::any_of(op.getBeforeArguments(),
3949 [](
Value arg) { return arg.use_empty(); }))
3952 YieldOp yield = op.getYieldOp();
3957 llvm::BitVector argsToErase;
3959 size_t argsCount = op.getBeforeArguments().size();
3960 newYields.reserve(argsCount);
3961 newInits.reserve(argsCount);
3962 argsToErase.reserve(argsCount);
3963 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3964 op.getBeforeArguments(), yield.
getOperands(), op.getInits())) {
3965 if (beforeArg.use_empty()) {
3966 argsToErase.push_back(
true);
3968 argsToErase.push_back(
false);
3969 newYields.emplace_back(yieldValue);
3970 newInits.emplace_back(initValue);
3974 Block &beforeBlock = *op.getBeforeBody();
3975 Block &afterBlock = *op.getAfterBody();
3983 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3984 Block &newAfterBlock = *newWhileOp.getAfterBody();
3990 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
3991 newBeforeBlock.getArguments());
3995 rewriter.
replaceOp(op, newWhileOp.getResults());
4004 LogicalResult matchAndRewrite(WhileOp op,
4006 ConditionOp condOp = op.getConditionOp();
4010 for (
Value arg : condOpArgs)
4011 argsSet.insert(arg);
4013 if (argsSet.size() == condOpArgs.size())
4016 llvm::SmallDenseMap<Value, unsigned> argsMap;
4018 argsMap.reserve(condOpArgs.size());
4019 newArgs.reserve(condOpArgs.size());
4020 for (
Value arg : condOpArgs) {
4021 if (!argsMap.count(arg)) {
4022 auto pos =
static_cast<unsigned>(argsMap.size());
4023 argsMap.insert({arg, pos});
4024 newArgs.emplace_back(arg);
4031 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4032 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4034 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4035 Block &newAfterBlock = *newWhileOp.getAfterBody();
4040 auto it = argsMap.find(arg);
4041 assert(it != argsMap.end());
4042 auto pos = it->second;
4043 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4044 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4052 Block &beforeBlock = *op.getBeforeBody();
4053 Block &afterBlock = *op.getAfterBody();
4055 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4056 newBeforeBlock.getArguments());
4057 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4065 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4067 if (args1.size() != args2.size())
4068 return std::nullopt;
4072 auto it = llvm::find(args2, arg1);
4073 if (it == args2.end())
4074 return std::nullopt;
4076 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4083 llvm::SmallDenseSet<Value> set;
4084 for (
Value arg : args) {
4085 if (!set.insert(arg).second)
4098 LogicalResult matchAndRewrite(WhileOp loop,
4100 auto oldBefore = loop.getBeforeBody();
4101 ConditionOp oldTerm = loop.getConditionOp();
4102 ValueRange beforeArgs = oldBefore->getArguments();
4104 if (beforeArgs == termArgs)
4107 if (hasDuplicates(termArgs))
4110 auto mapping = getArgsMapping(beforeArgs, termArgs);
4121 auto oldAfter = loop.getAfterBody();
4125 newResultTypes[
j] = loop.getResult(i).getType();
4127 auto newLoop = rewriter.
create<WhileOp>(
4128 loop.getLoc(), newResultTypes, loop.getInits(),
4130 auto newBefore = newLoop.getBeforeBody();
4131 auto newAfter = newLoop.getAfterBody();
4136 newResults[i] = newLoop.getResult(
j);
4137 newAfterArgs[i] = newAfter->getArgument(
j);
4141 newBefore->getArguments());
4153 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4154 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4155 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4156 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4170 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4173 caseValues.push_back(value);
4182 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4184 p <<
"case " << value <<
' ';
4190 if (getCases().size() != getCaseRegions().size()) {
4191 return emitOpError(
"has ")
4192 << getCaseRegions().size() <<
" case regions but "
4193 << getCases().size() <<
" case values";
4197 for (int64_t value : getCases())
4198 if (!valueSet.insert(value).second)
4199 return emitOpError(
"has duplicate case value: ") << value;
4201 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4203 return emitOpError(
"expected region to end with scf.yield, but got ")
4206 if (yield.getNumOperands() != getNumResults()) {
4207 return (emitOpError(
"expected each region to return ")
4208 << getNumResults() <<
" values, but " << name <<
" returns "
4209 << yield.getNumOperands())
4210 .attachNote(yield.getLoc())
4211 <<
"see yield operation here";
4213 for (
auto [idx, result, operand] :
4214 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4215 yield.getOperandTypes())) {
4216 if (result == operand)
4218 return (emitOpError(
"expected result #")
4219 << idx <<
" of each region to be " << result)
4220 .attachNote(yield.getLoc())
4221 << name <<
" returns " << operand <<
" here";
4226 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4229 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4235 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4237 Block &scf::IndexSwitchOp::getDefaultBlock() {
4238 return getDefaultRegion().
front();
4241 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4242 assert(idx < getNumCases() &&
"case index out-of-bounds");
4243 return getCaseRegions()[idx].front();
4246 void IndexSwitchOp::getSuccessorRegions(
4250 successors.emplace_back(getResults());
4254 llvm::copy(getRegions(), std::back_inserter(successors));
4257 void IndexSwitchOp::getEntrySuccessorRegions(
4260 FoldAdaptor adaptor(operands, *
this);
4263 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4265 llvm::copy(getRegions(), std::back_inserter(successors));
4271 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4272 if (caseValue == arg.getInt()) {
4273 successors.emplace_back(&caseRegion);
4277 successors.emplace_back(&getDefaultRegion());
4280 void IndexSwitchOp::getRegionInvocationBounds(
4282 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4283 if (!operandValue) {
4289 unsigned liveIndex = getNumRegions() - 1;
4290 const auto *it = llvm::find(getCases(), operandValue.getInt());
4291 if (it != getCases().end())
4292 liveIndex = std::distance(getCases().begin(), it);
4293 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4294 bounds.emplace_back(0, i == liveIndex);
4305 if (!maybeCst.has_value())
4307 int64_t cst = *maybeCst;
4308 int64_t caseIdx, e = op.getNumCases();
4309 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4310 if (cst == op.getCases()[caseIdx])
4314 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4315 : op.getDefaultRegion();
4316 Block &source = r.front();
4339 #define GET_OP_CLASSES
4340 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the given index.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the block terminator to replace operation results.
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e. have the same number of entries and the types are pairwise equal.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2, <...>).
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.