25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 Block *prevBlock = op->getBlock();
242 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 for (
Block &blk : op.getRegion()) {
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
256 for (
auto res : op.getResults())
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : initArgs)
327 for (
Value v : initArgs)
333 if (initArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
/// Returns the results of this ForOp, exactly as produced by getResults(),
/// wrapped in a std::optional. This implementation always returns an engaged
/// optional; it never yields std::nullopt.
396 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
509 for (
auto argOperandType :
510 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
511 Type type = std::get<2>(argOperandType);
512 std::get<0>(argOperandType).type = type;
524 ForOp::ensureTerminator(*body, builder, result.
location);
536 return getBody()->getArguments().drop_front(getNumInductionVars());
540 return getInitArgsMutable();
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
546 bool replaceInitOperandUsesInLoop,
551 auto inits = llvm::to_vector(getInitArgs());
552 inits.append(newInitOperands.begin(), newInitOperands.end());
553 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
561 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
566 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567 assert(newInitOperands.size() == newYieldedValues.size() &&
568 "expected as many new yield values as new iter operands");
570 yieldOp.getResultsMutable().append(newYieldedValues);
576 newLoop.getBody()->getArguments().take_front(
577 getBody()->getNumArguments()));
579 if (replaceInitOperandUsesInLoop) {
582 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
593 newLoop->getResults().take_front(getNumResults()));
594 return cast<LoopLikeOpInterface>(newLoop.getOperation());
598 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
601 assert(ivArg.getOwner() &&
"unlinked block argument");
602 auto *containingOp = ivArg.getOwner()->getParentOp();
603 return dyn_cast_or_null<ForOp>(containingOp);
607 return getInitArgs();
624 for (
auto [lb, ub, step] :
625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
627 if (!tripCount.has_value() || *tripCount != 1)
636 return getBody()->getArguments().drop_front(getRank());
640 return getOutputsMutable();
646 scf::InParallelOp terminator = forallOp.getTerminator();
651 bbArgReplacements.append(forallOp.getOutputs().begin(),
652 forallOp.getOutputs().end());
656 forallOp->getIterator(), bbArgReplacements);
661 results.reserve(forallOp.getResults().size());
662 for (
auto &yieldingOp : terminator.getYieldingOps()) {
663 auto parallelInsertSliceOp =
664 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
666 Value dst = parallelInsertSliceOp.getDest();
667 Value src = parallelInsertSliceOp.getSource();
668 if (llvm::isa<TensorType>(src.
getType())) {
669 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
670 forallOp.getLoc(), dst.
getType(), src, dst,
671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672 parallelInsertSliceOp.getStrides(),
673 parallelInsertSliceOp.getStaticOffsets(),
674 parallelInsertSliceOp.getStaticSizes(),
675 parallelInsertSliceOp.getStaticStrides()));
677 llvm_unreachable(
"unsupported terminator");
692 assert(lbs.size() == ubs.size() &&
693 "expected the same number of lower and upper bounds");
694 assert(lbs.size() == steps.size() &&
695 "expected the same number of lower bounds and steps");
700 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
702 assert(results.size() == iterArgs.size() &&
703 "loop nest body must return as many values as loop has iteration "
705 return LoopNest{{}, std::move(results)};
713 loops.reserve(lbs.size());
714 ivs.reserve(lbs.size());
717 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
718 auto loop = builder.
create<scf::ForOp>(
719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725 currentIterArgs = args;
726 currentLoc = nestedLoc;
732 loops.push_back(loop);
736 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
738 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
745 ? bodyBuilder(builder, currentLoc, ivs,
746 loops.back().getRegionIterArgs())
748 assert(results.size() == iterArgs.size() &&
749 "loop nest body must return as many values as loop has iteration "
752 builder.
create<scf::YieldOp>(loc, results);
756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757 return LoopNest{std::move(loops), std::move(nestResults)};
765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
770 bodyBuilder(nestedBuilder, nestedLoc, ivs);
779 assert(operand.
getOwner() == forOp);
784 "expected an iter OpOperand");
786 "Expected a different type");
788 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
790 newIterOperands.push_back(replacement);
793 newIterOperands.push_back(opOperand.get());
797 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
798 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799 forOp.getStep(), newIterOperands);
800 newForOp->
setAttrs(forOp->getAttrs());
801 Block &newBlock = newForOp.getRegion().
front();
809 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
811 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
816 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
819 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
822 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
823 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824 clonedYieldOp.getOperand(yieldIdx));
826 newYieldOperands[yieldIdx] = castOut;
827 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828 rewriter.
eraseOp(clonedYieldOp);
833 newResults[yieldIdx] =
834 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
854 LogicalResult matchAndRewrite(scf::ForOp forOp,
856 bool canonicalize =
false;
863 int64_t numResults = forOp.getNumResults();
865 keepMask.reserve(numResults);
868 newBlockTransferArgs.reserve(1 + numResults);
869 newBlockTransferArgs.push_back(
Value());
870 newIterArgs.reserve(forOp.getInitArgs().size());
871 newYieldValues.reserve(numResults);
872 newResultValues.reserve(numResults);
874 for (
auto [init, arg, result, yielded] :
875 llvm::zip(forOp.getInitArgs(),
876 forOp.getRegionIterArgs(),
878 forOp.getYieldedValues()
885 bool forwarded = (arg == yielded) || (init == yielded) ||
886 (arg.use_empty() && result.use_empty());
889 keepMask.push_back(
false);
890 newBlockTransferArgs.push_back(init);
891 newResultValues.push_back(init);
897 if (
auto it = initYieldToArg.find({init, yielded});
898 it != initYieldToArg.end()) {
900 keepMask.push_back(
false);
901 auto [sameArg, sameResult] = it->second;
905 newBlockTransferArgs.push_back(init);
906 newResultValues.push_back(init);
911 initYieldToArg.insert({{init, yielded}, {arg, result}});
912 keepMask.push_back(
true);
913 newIterArgs.push_back(init);
914 newYieldValues.push_back(yielded);
915 newBlockTransferArgs.push_back(
Value());
916 newResultValues.push_back(
Value());
922 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
923 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
924 forOp.getStep(), newIterArgs);
925 newForOp->
setAttrs(forOp->getAttrs());
926 Block &newBlock = newForOp.getRegion().
front();
930 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
932 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
933 Value &newResultVal = newResultValues[idx];
934 assert((blockTransferArg && newResultVal) ||
935 (!blockTransferArg && !newResultVal));
936 if (!blockTransferArg) {
937 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
938 newResultVal = newForOp.getResult(collapsedIdx++);
944 "unexpected argument size mismatch");
949 if (newIterArgs.empty()) {
950 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
953 rewriter.
replaceOp(forOp, newResultValues);
958 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
962 filteredOperands.reserve(newResultValues.size());
963 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
965 filteredOperands.push_back(mergedTerminator.getOperand(idx));
966 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
970 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
971 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
972 cloneFilteredTerminator(mergedYieldOp);
973 rewriter.
eraseOp(mergedYieldOp);
974 rewriter.
replaceOp(forOp, newResultValues);
982 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
983 IntegerAttr clb, cub;
985 llvm::APInt lbValue = clb.getValue();
986 llvm::APInt ubValue = cub.getValue();
987 return (ubValue - lbValue).getSExtValue();
996 return diff.getSExtValue();
1006 LogicalResult matchAndRewrite(ForOp op,
1010 if (op.getLowerBound() == op.getUpperBound()) {
1011 rewriter.
replaceOp(op, op.getInitArgs());
1015 std::optional<int64_t> diff =
1016 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1022 rewriter.
replaceOp(op, op.getInitArgs());
1026 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1027 if (!maybeStepValue)
1032 llvm::APInt stepValue = *maybeStepValue;
1033 if (stepValue.sge(*diff)) {
1035 blockArgs.reserve(op.getInitArgs().size() + 1);
1036 blockArgs.push_back(op.getLowerBound());
1037 llvm::append_range(blockArgs, op.getInitArgs());
1044 if (!llvm::hasSingleElement(block))
1048 if (llvm::any_of(op.getYieldedValues(),
1049 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1051 rewriter.
replaceOp(op, op.getYieldedValues());
1085 LogicalResult matchAndRewrite(ForOp op,
1087 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1088 OpOperand &iterOpOperand = std::get<0>(it);
1090 if (!incomingCast ||
1091 incomingCast.getSource().getType() == incomingCast.getType())
1096 incomingCast.getDest().getType(),
1097 incomingCast.getSource().getType()))
1099 if (!std::get<1>(it).hasOneUse())
1105 rewriter, op, iterOpOperand, incomingCast.getSource(),
1107 return b.create<tensor::CastOp>(loc, type, source);
1119 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1123 std::optional<APInt> ForOp::getConstantStep() {
1126 return step.getValue();
1130 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1131 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1137 if (
auto constantStep = getConstantStep())
1138 if (*constantStep == 1)
1151 unsigned numLoops = getRank();
1153 if (getNumResults() != getOutputs().size())
1154 return emitOpError(
"produces ")
1155 << getNumResults() <<
" results, but has only "
1156 << getOutputs().size() <<
" outputs";
1159 auto *body = getBody();
1161 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1162 for (int64_t i = 0; i < numLoops; ++i)
1164 return emitOpError(
"expects ")
1165 << i <<
"-th block argument to be an index";
1166 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1168 return emitOpError(
"type mismatch between ")
1169 << i <<
"-th output and corresponding block argument";
1170 if (getMapping().has_value() && !getMapping()->empty()) {
1171 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1172 return emitOpError() <<
"mapping attribute size must match op rank";
1173 for (
auto map : getMapping()->getValue()) {
1174 if (!isa<DeviceMappingAttrInterface>(map))
1175 return emitOpError()
1183 getStaticLowerBound(),
1184 getDynamicLowerBound())))
1187 getStaticUpperBound(),
1188 getDynamicUpperBound())))
1191 getStaticStep(), getDynamicStep())))
1199 p <<
" (" << getInductionVars();
1200 if (isNormalized()) {
1221 if (!getRegionOutArgs().empty())
1222 p <<
"-> (" << getResultTypes() <<
") ";
1223 p.printRegion(getRegion(),
1225 getNumResults() > 0);
1226 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1227 getStaticLowerBoundAttrName(),
1228 getStaticUpperBoundAttrName(),
1229 getStaticStepAttrName()});
1234 auto indexType = b.getIndexType();
1254 unsigned numLoops = ivs.size();
1289 if (outOperands.size() != result.
types.size())
1291 "mismatch between out operands and types");
1301 std::unique_ptr<Region> region = std::make_unique<Region>();
1302 for (
auto &iv : ivs) {
1303 iv.type = b.getIndexType();
1304 regionArgs.push_back(iv);
1307 auto &out = it.value();
1308 out.type = result.
types[it.index()];
1309 regionArgs.push_back(out);
1315 ForallOp::ensureTerminator(*region, b, result.
location);
1327 {static_cast<int32_t>(dynamicLbs.size()),
1328 static_cast<int32_t>(dynamicUbs.size()),
1329 static_cast<int32_t>(dynamicSteps.size()),
1330 static_cast<int32_t>(outOperands.size())}));
1335 void ForallOp::build(
1339 std::optional<ArrayAttr> mapping,
1360 "operandSegmentSizes",
1362 static_cast<int32_t>(dynamicUbs.size()),
1363 static_cast<int32_t>(dynamicSteps.size()),
1364 static_cast<int32_t>(outputs.size())}));
1365 if (mapping.has_value()) {
1384 if (!bodyBuilderFn) {
1385 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1392 void ForallOp::build(
1395 std::optional<ArrayAttr> mapping,
1397 unsigned numLoops = ubs.size();
1400 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1404 bool ForallOp::isNormalized() {
1408 return intValue.has_value() && intValue == val;
1411 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1420 ForallOp>::ensureTerminator(region, builder, loc);
1427 InParallelOp ForallOp::getTerminator() {
1428 return cast<InParallelOp>(getBody()->getTerminator());
1433 InParallelOp inParallelOp = getTerminator();
1434 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1435 if (
auto parallelInsertSliceOp =
1436 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1437 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1438 storeOps.push_back(parallelInsertSliceOp);
1444 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1449 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1451 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1455 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1457 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1461 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1467 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1470 assert(tidxArg.getOwner() &&
"unlinked block argument");
1471 auto *containingOp = tidxArg.getOwner()->getParentOp();
1472 return dyn_cast<ForallOp>(containingOp);
1480 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1482 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1486 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1489 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1498 LogicalResult matchAndRewrite(ForallOp op,
1513 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1514 op.setStaticLowerBound(staticLowerBound);
1518 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1519 op.setStaticUpperBound(staticUpperBound);
1522 op.getDynamicStepMutable().assign(dynamicStep);
1523 op.setStaticStep(staticStep);
1525 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1527 {static_cast<int32_t>(dynamicLowerBound.size()),
1528 static_cast<int32_t>(dynamicUpperBound.size()),
1529 static_cast<int32_t>(dynamicStep.size()),
1530 static_cast<int32_t>(op.getNumResults())}));
1612 LogicalResult matchAndRewrite(ForallOp forallOp,
1631 for (
OpResult result : forallOp.getResults()) {
1632 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1633 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1634 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1635 resultToDelete.insert(result);
1637 resultToReplace.push_back(result);
1638 newOuts.push_back(opOperand->
get());
1644 if (resultToDelete.empty())
1652 for (
OpResult result : resultToDelete) {
1653 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1654 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1656 forallOp.getCombiningOps(blockArg);
1657 for (
Operation *combiningOp : combiningOps)
1658 rewriter.
eraseOp(combiningOp);
1663 auto newForallOp = rewriter.
create<scf::ForallOp>(
1664 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1665 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1666 forallOp.getMapping(),
1671 Block *loopBody = forallOp.getBody();
1672 Block *newLoopBody = newForallOp.getBody();
1677 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1684 for (
OpResult result : forallOp.getResults()) {
1685 if (resultToDelete.count(result)) {
1686 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1688 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1691 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1695 for (
auto &&[oldResult, newResult] :
1696 llvm::zip(resultToReplace, newForallOp->getResults()))
1702 for (
OpResult oldResult : resultToDelete)
1704 forallOp.getTiedOpOperand(oldResult)->get());
1709 struct ForallOpSingleOrZeroIterationDimsFolder
1713 LogicalResult matchAndRewrite(ForallOp op,
1716 if (op.getMapping().has_value() && !op.getMapping()->empty())
1724 for (
auto [lb, ub, step, iv] :
1725 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1726 op.getMixedStep(), op.getInductionVars())) {
1728 if (numIterations.has_value()) {
1730 if (*numIterations == 0) {
1731 rewriter.
replaceOp(op, op.getOutputs());
1736 if (*numIterations == 1) {
1741 newMixedLowerBounds.push_back(lb);
1742 newMixedUpperBounds.push_back(ub);
1743 newMixedSteps.push_back(step);
1747 if (newMixedLowerBounds.empty()) {
1753 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1755 op,
"no dimensions have 0 or 1 iterations");
1760 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1761 newMixedUpperBounds, newMixedSteps,
1762 op.getOutputs(), std::nullopt,
nullptr);
1763 newOp.getBodyRegion().getBlocks().clear();
1768 newOp.getStaticLowerBoundAttrName(),
1769 newOp.getStaticUpperBoundAttrName(),
1770 newOp.getStaticStepAttrName()};
1771 for (
const auto &namedAttr : op->getAttrs()) {
1772 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1775 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1779 newOp.getRegion().begin(), mapping);
1780 rewriter.
replaceOp(op, newOp.getResults());
1786 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1789 LogicalResult matchAndRewrite(ForallOp op,
1793 for (
auto [lb, ub, step, iv] :
1794 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1795 op.getMixedStep(), op.getInductionVars())) {
1796 if (iv.getUses().begin() == iv.getUses().end())
1799 if (!numIterations.has_value() || numIterations.value() != 1) {
1810 struct FoldTensorCastOfOutputIntoForallOp
1819 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1821 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1824 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1831 castOp.getSource().getType())) {
1835 tensorCastProducers[en.index()] =
1836 TypeCast{castOp.getSource().getType(), castOp.getType()};
1837 newOutputTensors[en.index()] = castOp.getSource();
1840 if (tensorCastProducers.empty())
1845 auto newForallOp = rewriter.
create<ForallOp>(
1846 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1847 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1849 auto castBlockArgs =
1850 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1851 for (
auto [index, cast] : tensorCastProducers) {
1852 Value &oldTypeBBArg = castBlockArgs[index];
1853 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1854 nestedLoc, cast.dstType, oldTypeBBArg);
1859 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1860 ivsBlockArgs.append(castBlockArgs);
1862 bbArgs.front().getParentBlock(), ivsBlockArgs);
1868 auto terminator = newForallOp.getTerminator();
1869 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1870 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1871 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1872 insertSliceOp.getDestMutable().assign(outputBlockArg);
1878 for (
auto &item : tensorCastProducers) {
1879 Value &oldTypeResult = castResults[item.first];
1880 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1883 rewriter.
replaceOp(forallOp, castResults);
1892 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1893 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1894 ForallOpSingleOrZeroIterationDimsFolder,
1895 ForallOpReplaceConstantInductionVar>(context);
1924 scf::ForallOp forallOp =
1925 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1927 return this->emitOpError(
"expected forall op parent");
1930 for (
Operation &op : getRegion().front().getOperations()) {
1931 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1932 return this->emitOpError(
"expected only ")
1933 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1937 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1939 if (!llvm::is_contained(regionOutArgs, dest))
1940 return op.emitOpError(
"may only insert into an output block argument");
1957 std::unique_ptr<Region> region = std::make_unique<Region>();
1961 if (region->empty())
1971 OpResult InParallelOp::getParentResult(int64_t idx) {
1972 return getOperation()->getParentOp()->getResult(idx);
1976 return llvm::to_vector<4>(
1977 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1979 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1980 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1985 return getRegion().front().getOperations();
1993 assert(a &&
"expected non-empty operation");
1994 assert(b &&
"expected non-empty operation");
1999 if (ifOp->isProperAncestor(b))
2002 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2003 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2005 ifOp = ifOp->getParentOfType<IfOp>();
2013 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2014 IfOp::Adaptor adaptor,
2016 if (adaptor.getRegions().empty())
2018 Region *r = &adaptor.getThenRegion();
2021 Block &b = r->front();
2024 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2027 TypeRange types = yieldOp.getOperandTypes();
2028 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2035 return build(builder, result, resultTypes, cond,
false,
2041 bool addElseBlock) {
2042 assert((!addElseBlock || addThenBlock) &&
2043 "must not create else block w/o then block");
2058 bool withElseRegion) {
2059 build(builder, result,
TypeRange{}, cond, withElseRegion);
2071 if (resultTypes.empty())
2072 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2076 if (withElseRegion) {
2078 if (resultTypes.empty())
2079 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2086 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2093 thenBuilder(builder, result.
location);
2099 elseBuilder(builder, result.
location);
2106 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2108 inferredReturnTypes))) {
2109 result.
addTypes(inferredReturnTypes);
2114 if (getNumResults() != 0 && getElseRegion().empty())
2115 return emitOpError(
"must have an else block if defining values");
2153 bool printBlockTerminators =
false;
2155 p <<
" " << getCondition();
2156 if (!getResults().empty()) {
2157 p <<
" -> (" << getResultTypes() <<
")";
2159 printBlockTerminators =
true;
2164 printBlockTerminators);
2167 auto &elseRegion = getElseRegion();
2168 if (!elseRegion.
empty()) {
2172 printBlockTerminators);
2189 Region *elseRegion = &this->getElseRegion();
2190 if (elseRegion->
empty())
2198 FoldAdaptor adaptor(operands, *
this);
2199 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2200 if (!boolAttr || boolAttr.getValue())
2201 regions.emplace_back(&getThenRegion());
2204 if (!boolAttr || !boolAttr.getValue()) {
2205 if (!getElseRegion().empty())
2206 regions.emplace_back(&getElseRegion());
2208 regions.emplace_back(getResults());
2212 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2215 if (getElseRegion().empty())
2218 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2225 getConditionMutable().assign(xorStmt.getLhs());
2229 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2230 getElseRegion().getBlocks());
2231 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2232 getThenRegion().getBlocks(), thenBlock);
2236 void IfOp::getRegionInvocationBounds(
2239 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2242 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2243 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2246 invocationBounds.assign(2, {0, 1});
2262 llvm::transform(usedResults, std::back_inserter(usedOperands),
2267 [&]() { yieldOp->setOperands(usedOperands); });
2270 LogicalResult matchAndRewrite(IfOp op,
2274 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2275 [](
OpResult result) { return !result.use_empty(); });
2278 if (usedResults.size() == op.getNumResults())
2283 llvm::transform(usedResults, std::back_inserter(newTypes),
2288 rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2294 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2295 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2300 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2309 LogicalResult matchAndRewrite(IfOp op,
2317 else if (!op.getElseRegion().empty())
2331 LogicalResult matchAndRewrite(IfOp op,
2333 if (op->getNumResults() == 0)
2336 auto cond = op.getCondition();
2337 auto thenYieldArgs = op.thenYield().getOperands();
2338 auto elseYieldArgs = op.elseYield().getOperands();
2341 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2342 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2343 &op.getElseRegion() == falseVal.getParentRegion())
2344 nonHoistable.push_back(trueVal.getType());
2348 if (nonHoistable.size() == op->getNumResults())
2351 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond,
2353 if (replacement.thenBlock())
2354 rewriter.
eraseBlock(replacement.thenBlock());
2355 replacement.getThenRegion().takeBody(op.getThenRegion());
2356 replacement.getElseRegion().takeBody(op.getElseRegion());
2359 assert(thenYieldArgs.size() == results.size());
2360 assert(elseYieldArgs.size() == results.size());
2365 for (
const auto &it :
2367 Value trueVal = std::get<0>(it.value());
2368 Value falseVal = std::get<1>(it.value());
2371 results[it.index()] = replacement.getResult(trueYields.size());
2372 trueYields.push_back(trueVal);
2373 falseYields.push_back(falseVal);
2374 }
else if (trueVal == falseVal)
2375 results[it.index()] = trueVal;
2377 results[it.index()] = rewriter.
create<arith::SelectOp>(
2378 op.getLoc(), cond, trueVal, falseVal);
2408 LogicalResult matchAndRewrite(IfOp op,
2420 Value constantTrue =
nullptr;
2421 Value constantFalse =
nullptr;
2424 llvm::make_early_inc_range(op.getCondition().getUses())) {
2429 constantTrue = rewriter.
create<arith::ConstantOp>(
2433 [&]() { use.
set(constantTrue); });
2434 }
else if (op.getElseRegion().isAncestor(
2439 constantFalse = rewriter.
create<arith::ConstantOp>(
2443 [&]() { use.
set(constantFalse); });
2487 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2490 LogicalResult matchAndRewrite(IfOp op,
2493 if (op.getNumResults() == 0)
2497 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2499 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2502 op.getOperation()->getIterator());
2505 for (
auto [trueResult, falseResult, opResult] :
2506 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2508 if (trueResult == falseResult) {
2509 if (!opResult.use_empty()) {
2510 opResult.replaceAllUsesWith(trueResult);
2521 bool trueVal = trueYield.
getValue();
2522 bool falseVal = falseYield.
getValue();
2523 if (!trueVal && falseVal) {
2524 if (!opResult.use_empty()) {
2525 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2527 op.getLoc(), op.getCondition(),
2537 if (trueVal && !falseVal) {
2538 if (!opResult.use_empty()) {
2539 opResult.replaceAllUsesWith(op.getCondition());
2572 LogicalResult matchAndRewrite(IfOp nextIf,
2574 Block *parent = nextIf->getBlock();
2575 if (nextIf == &parent->
front())
2578 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2586 Block *nextThen =
nullptr;
2587 Block *nextElse =
nullptr;
2588 if (nextIf.getCondition() == prevIf.getCondition()) {
2589 nextThen = nextIf.thenBlock();
2590 if (!nextIf.getElseRegion().empty())
2591 nextElse = nextIf.elseBlock();
2593 if (arith::XOrIOp notv =
2594 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2595 if (notv.getLhs() == prevIf.getCondition() &&
2597 nextElse = nextIf.thenBlock();
2598 if (!nextIf.getElseRegion().empty())
2599 nextThen = nextIf.elseBlock();
2602 if (arith::XOrIOp notv =
2603 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2604 if (notv.getLhs() == nextIf.getCondition() &&
2606 nextElse = nextIf.thenBlock();
2607 if (!nextIf.getElseRegion().empty())
2608 nextThen = nextIf.elseBlock();
2612 if (!nextThen && !nextElse)
2616 if (!prevIf.getElseRegion().empty())
2617 prevElseYielded = prevIf.elseYield().getOperands();
2620 for (
auto it : llvm::zip(prevIf.getResults(),
2621 prevIf.thenYield().getOperands(), prevElseYielded))
2623 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2627 use.
set(std::get<1>(it));
2632 use.
set(std::get<2>(it));
2638 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2640 IfOp combinedIf = rewriter.
create<IfOp>(
2641 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2642 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2645 combinedIf.getThenRegion(),
2646 combinedIf.getThenRegion().begin());
2649 YieldOp thenYield = combinedIf.thenYield();
2650 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2651 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2655 llvm::append_range(mergedYields, thenYield2.getOperands());
2656 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2662 combinedIf.getElseRegion(),
2663 combinedIf.getElseRegion().begin());
2666 if (combinedIf.getElseRegion().empty()) {
2668 combinedIf.getElseRegion(),
2669 combinedIf.getElseRegion().
begin());
2671 YieldOp elseYield = combinedIf.elseYield();
2672 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2673 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2678 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2680 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2689 if (pair.index() < prevIf.getNumResults())
2690 prevValues.push_back(pair.value());
2692 nextValues.push_back(pair.value());
2704 LogicalResult matchAndRewrite(IfOp ifOp,
2707 if (ifOp.getNumResults())
2709 Block *elseBlock = ifOp.elseBlock();
2710 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2714 newIfOp.getThenRegion().begin());
2739 LogicalResult matchAndRewrite(IfOp op,
2741 auto nestedOps = op.thenBlock()->without_terminator();
2743 if (!llvm::hasSingleElement(nestedOps))
2747 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2750 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2754 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2760 llvm::append_range(elseYield, op.elseYield().getOperands());
2774 if (tup.value().getDefiningOp() == nestedIf) {
2775 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2776 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2777 elseYield[tup.index()]) {
2782 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2795 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2798 elseYieldsToUpgradeToSelect.push_back(tup.index());
2802 Value newCondition = rewriter.
create<arith::AndIOp>(
2803 loc, op.getCondition(), nestedIf.getCondition());
2804 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2808 llvm::append_range(results, newIf.getResults());
2811 for (
auto idx : elseYieldsToUpgradeToSelect)
2812 results[idx] = rewriter.
create<arith::SelectOp>(
2813 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2815 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2818 if (!elseYield.empty()) {
2821 rewriter.
create<YieldOp>(loc, elseYield);
2832 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2833 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2834 RemoveStaticCondition, RemoveUnusedResults,
2835 ReplaceIfYieldWithConditionOrValue>(context);
2838 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2839 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2840 Block *IfOp::elseBlock() {
2841 Region &r = getElseRegion();
2846 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2852 void ParallelOp::build(
2862 ParallelOp::getOperandSegmentSizeAttr(),
2864 static_cast<int32_t>(upperBounds.size()),
2865 static_cast<int32_t>(steps.size()),
2866 static_cast<int32_t>(initVals.size())}));
2870 unsigned numIVs = steps.size();
2876 if (bodyBuilderFn) {
2878 bodyBuilderFn(builder, result.
location,
2883 if (initVals.empty())
2884 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2887 void ParallelOp::build(
2894 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2897 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2901 wrapper = wrappedBuilderFn;
2903 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2912 if (stepValues.empty())
2914 "needs at least one tuple element for lowerBound, upperBound and step");
2917 for (
Value stepValue : stepValues)
2920 return emitOpError(
"constant step operand must be positive");
2924 Block *body = getBody();
2926 return emitOpError() <<
"expects the same number of induction variables: "
2928 <<
" as bound and step values: " << stepValues.size();
2930 if (!arg.getType().isIndex())
2932 "expects arguments for the induction variable to be of index type");
2935 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2936 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2941 auto resultsSize = getResults().size();
2942 auto reductionsSize = reduceOp.getReductions().size();
2943 auto initValsSize = getInitVals().size();
2944 if (resultsSize != reductionsSize)
2945 return emitOpError() <<
"expects number of results: " << resultsSize
2946 <<
" to be the same as number of reductions: "
2948 if (resultsSize != initValsSize)
2949 return emitOpError() <<
"expects number of results: " << resultsSize
2950 <<
" to be the same as number of initial values: "
2954 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2955 auto resultType = getOperation()->getResult(i).getType();
2956 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2957 if (resultType != reductionOperandType)
2958 return reduceOp.emitOpError()
2959 <<
"expects type of " << i
2960 <<
"-th reduction operand: " << reductionOperandType
2961 <<
" to be the same as the " << i
2962 <<
"-th result type: " << resultType;
3010 for (
auto &iv : ivs)
3017 ParallelOp::getOperandSegmentSizeAttr(),
3019 static_cast<int32_t>(upper.size()),
3020 static_cast<int32_t>(steps.size()),
3021 static_cast<int32_t>(initVals.size())}));
3030 ParallelOp::ensureTerminator(*body, builder, result.
location);
3035 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3036 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3037 if (!getInitVals().empty())
3038 p <<
" init (" << getInitVals() <<
")";
3043 (*this)->getAttrs(),
3044 ParallelOp::getOperandSegmentSizeAttr());
3049 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3053 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3057 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3061 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3066 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3068 return ParallelOp();
3069 assert(ivArg.getOwner() &&
"unlinked block argument");
3070 auto *containingOp = ivArg.getOwner()->getParentOp();
3071 return dyn_cast<ParallelOp>(containingOp);
3076 struct ParallelOpSingleOrZeroIterationDimsFolder
3080 LogicalResult matchAndRewrite(ParallelOp op,
3087 for (
auto [lb, ub, step, iv] :
3088 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3089 op.getInductionVars())) {
3091 if (numIterations.has_value()) {
3093 if (*numIterations == 0) {
3094 rewriter.
replaceOp(op, op.getInitVals());
3099 if (*numIterations == 1) {
3104 newLowerBounds.push_back(lb);
3105 newUpperBounds.push_back(ub);
3106 newSteps.push_back(step);
3109 if (newLowerBounds.size() == op.getLowerBound().size())
3112 if (newLowerBounds.empty()) {
3116 results.reserve(op.getInitVals().size());
3117 for (
auto &bodyOp : op.getBody()->without_terminator())
3118 rewriter.
clone(bodyOp, mapping);
3119 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3120 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3121 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3122 auto initValIndex = results.size();
3123 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3127 rewriter.
clone(reduceBodyOp, mapping);
3130 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3131 results.push_back(result);
3139 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3140 newSteps, op.getInitVals(),
nullptr);
3146 newOp.getRegion().begin(), mapping);
3147 rewriter.
replaceOp(op, newOp.getResults());
3155 LogicalResult matchAndRewrite(ParallelOp op,
3157 Block &outerBody = *op.getBody();
3161 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3166 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3167 llvm::is_contained(innerOp.getUpperBound(), val) ||
3168 llvm::is_contained(innerOp.getStep(), val))
3172 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3177 Block &innerBody = *innerOp.getBody();
3178 assert(iterVals.size() ==
3186 builder.
clone(op, mapping);
3189 auto concatValues = [](
const auto &first,
const auto &second) {
3191 ret.reserve(first.size() + second.size());
3192 ret.assign(first.begin(), first.end());
3193 ret.append(second.begin(), second.end());
3197 auto newLowerBounds =
3198 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3199 auto newUpperBounds =
3200 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3201 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3204 newSteps, std::nullopt,
3215 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3224 void ParallelOp::getSuccessorRegions(
3242 for (
Value v : operands) {
3251 LogicalResult ReduceOp::verifyRegions() {
3254 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3255 auto type = getOperands()[i].getType();
3258 return emitOpError() << i <<
"-th reduction has an empty body";
3261 return arg.getType() != type;
3263 return emitOpError() <<
"expected two block arguments with type " << type
3264 <<
" in the " << i <<
"-th reduction region";
3268 return emitOpError(
"reduction bodies must be terminated with an "
3269 "'scf.reduce.return' op");
3288 Block *reductionBody = getOperation()->getBlock();
3290 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3292 if (expectedResultType != getResult().
getType())
3293 return emitOpError() <<
"must have type " << expectedResultType
3294 <<
" (the type of the reduction inputs)";
3304 ValueRange inits, BodyBuilderFn beforeBuilder,
3305 BodyBuilderFn afterBuilder) {
3313 beforeArgLocs.reserve(inits.size());
3314 for (
Value operand : inits) {
3315 beforeArgLocs.push_back(operand.getLoc());
3320 inits.getTypes(), beforeArgLocs);
3329 resultTypes, afterArgLocs);
3335 ConditionOp WhileOp::getConditionOp() {
3336 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3339 YieldOp WhileOp::getYieldOp() {
3340 return cast<YieldOp>(getAfterBody()->getTerminator());
3343 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3344 return getYieldOp().getResultsMutable();
3348 return getBeforeBody()->getArguments();
3352 return getAfterBody()->getArguments();
3356 return getBeforeArguments();
3360 assert(point == getBefore() &&
3361 "WhileOp is expected to branch only to the first region");
3369 regions.emplace_back(&getBefore(), getBefore().getArguments());
3373 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3374 "there are only two regions in a WhileOp");
3376 if (point == getAfter()) {
3377 regions.emplace_back(&getBefore(), getBefore().getArguments());
3381 regions.emplace_back(getResults());
3382 regions.emplace_back(&getAfter(), getAfter().getArguments());
3386 return {&getBefore(), &getAfter()};
3407 FunctionType functionType;
3412 result.
addTypes(functionType.getResults());
3414 if (functionType.getNumInputs() != operands.size()) {
3416 <<
"expected as many input types as operands "
3417 <<
"(expected " << operands.size() <<
" got "
3418 << functionType.getNumInputs() <<
")";
3428 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3429 regionArgs[i].type = functionType.getInput(i);
3431 return failure(parser.
parseRegion(*before, regionArgs) ||
3451 template <
typename OpTy>
3454 if (left.size() != right.size())
3455 return op.emitOpError(
"expects the same number of ") << message;
3457 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3458 if (left[i] != right[i]) {
3461 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3462 <<
" and " << right[i];
3471 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3473 "expects the 'before' region to terminate with 'scf.condition'");
3474 if (!beforeTerminator)
3477 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3479 "expects the 'after' region to terminate with 'scf.yield'");
3480 return success(afterTerminator !=
nullptr);
3506 LogicalResult matchAndRewrite(WhileOp op,
3508 auto term = op.getConditionOp();
3512 Value constantTrue =
nullptr;
3514 bool replaced =
false;
3515 for (
auto yieldedAndBlockArgs :
3516 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3517 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3518 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3520 constantTrue = rewriter.
create<arith::ConstantOp>(
3521 op.getLoc(), term.getCondition().getType(),
3530 return success(replaced);
3582 struct RemoveLoopInvariantArgsFromBeforeBlock
3586 LogicalResult matchAndRewrite(WhileOp op,
3588 Block &afterBlock = *op.getAfterBody();
3590 ConditionOp condOp = op.getConditionOp();
3595 bool canSimplify =
false;
3596 for (
const auto &it :
3598 auto index =
static_cast<unsigned>(it.index());
3599 auto [initVal, yieldOpArg] = it.value();
3602 if (yieldOpArg == initVal) {
3611 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3612 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3613 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3614 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3627 for (
const auto &it :
3629 auto index =
static_cast<unsigned>(it.index());
3630 auto [initVal, yieldOpArg] = it.value();
3634 if (yieldOpArg == initVal) {
3635 beforeBlockInitValMap.insert({index, initVal});
3643 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3644 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3645 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3646 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3647 beforeBlockInitValMap.insert({index, initVal});
3652 newInitArgs.emplace_back(initVal);
3653 newYieldOpArgs.emplace_back(yieldOpArg);
3654 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3664 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3667 &newWhile.getBefore(), {},
3670 Block &beforeBlock = *op.getBeforeBody();
3677 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3680 if (beforeBlockInitValMap.count(i) != 0)
3681 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3683 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3686 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3688 newWhile.getAfter().begin());
3690 rewriter.
replaceOp(op, newWhile.getResults());
3735 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3738 LogicalResult matchAndRewrite(WhileOp op,
3740 Block &beforeBlock = *op.getBeforeBody();
3741 ConditionOp condOp = op.getConditionOp();
3744 bool canSimplify =
false;
3745 for (
Value condOpArg : condOpArgs) {
3765 auto index =
static_cast<unsigned>(it.index());
3766 Value condOpArg = it.value();
3771 condOpInitValMap.insert({index, condOpArg});
3773 newCondOpArgs.emplace_back(condOpArg);
3774 newAfterBlockType.emplace_back(condOpArg.
getType());
3775 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3786 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3789 Block &newAfterBlock =
3791 newAfterBlockType, newAfterBlockArgLocs);
3793 Block &afterBlock = *op.getAfterBody();
3800 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3801 Value afterBlockArg, result;
3804 if (condOpInitValMap.count(i) != 0) {
3805 afterBlockArg = condOpInitValMap[i];
3806 result = afterBlockArg;
3808 afterBlockArg = newAfterBlock.getArgument(
j);
3809 result = newWhile.getResult(
j);
3812 newAfterBlockArgs[i] = afterBlockArg;
3813 newWhileResults[i] = result;
3816 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3818 newWhile.getBefore().begin());
3820 rewriter.
replaceOp(op, newWhileResults);
3854 LogicalResult matchAndRewrite(WhileOp op,
3856 auto term = op.getConditionOp();
3857 auto afterArgs = op.getAfterArguments();
3858 auto termArgs = term.getArgs();
3865 bool needUpdate =
false;
3866 for (
const auto &it :
3868 auto i =
static_cast<unsigned>(it.index());
3869 Value result = std::get<0>(it.value());
3870 Value afterArg = std::get<1>(it.value());
3871 Value termArg = std::get<2>(it.value());
3875 newResultsIndices.emplace_back(i);
3876 newTermArgs.emplace_back(termArg);
3877 newResultTypes.emplace_back(result.
getType());
3878 newArgLocs.emplace_back(result.
getLoc());
3893 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3896 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3903 newResults[it.value()] = newWhile.getResult(it.index());
3904 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3908 newWhile.getBefore().begin());
3910 Block &afterBlock = *op.getAfterBody();
3911 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3943 LogicalResult matchAndRewrite(scf::WhileOp op,
3945 using namespace scf;
3946 auto cond = op.getConditionOp();
3947 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3951 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3952 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3953 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3956 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3957 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3961 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3964 if (cmp2.getPredicate() == cmp.getPredicate())
3965 samePredicate =
true;
3966 else if (cmp2.getPredicate() ==
3968 samePredicate =
false;
3986 LogicalResult matchAndRewrite(WhileOp op,
3989 if (!llvm::any_of(op.getBeforeArguments(),
3990 [](
Value arg) { return arg.use_empty(); }))
3993 YieldOp yield = op.getYieldOp();
3998 llvm::BitVector argsToErase;
4000 size_t argsCount = op.getBeforeArguments().size();
4001 newYields.reserve(argsCount);
4002 newInits.reserve(argsCount);
4003 argsToErase.reserve(argsCount);
4004 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4005 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4006 if (beforeArg.use_empty()) {
4007 argsToErase.push_back(
true);
4009 argsToErase.push_back(
false);
4010 newYields.emplace_back(yieldValue);
4011 newInits.emplace_back(initValue);
4015 Block &beforeBlock = *op.getBeforeBody();
4016 Block &afterBlock = *op.getAfterBody();
4022 rewriter.
create<WhileOp>(loc, op.getResultTypes(), newInits,
4024 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4025 Block &newAfterBlock = *newWhileOp.getAfterBody();
4031 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4032 newBeforeBlock.getArguments());
4036 rewriter.
replaceOp(op, newWhileOp.getResults());
4045 LogicalResult matchAndRewrite(WhileOp op,
4047 ConditionOp condOp = op.getConditionOp();
4051 for (
Value arg : condOpArgs)
4052 argsSet.insert(arg);
4054 if (argsSet.size() == condOpArgs.size())
4057 llvm::SmallDenseMap<Value, unsigned> argsMap;
4059 argsMap.reserve(condOpArgs.size());
4060 newArgs.reserve(condOpArgs.size());
4061 for (
Value arg : condOpArgs) {
4062 if (!argsMap.count(arg)) {
4063 auto pos =
static_cast<unsigned>(argsMap.size());
4064 argsMap.insert({arg, pos});
4065 newArgs.emplace_back(arg);
4072 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4073 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4075 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4076 Block &newAfterBlock = *newWhileOp.getAfterBody();
4081 auto it = argsMap.find(arg);
4082 assert(it != argsMap.end());
4083 auto pos = it->second;
4084 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4085 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4093 Block &beforeBlock = *op.getBeforeBody();
4094 Block &afterBlock = *op.getAfterBody();
4096 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4097 newBeforeBlock.getArguments());
4098 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4106 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4108 if (args1.size() != args2.size())
4109 return std::nullopt;
4113 auto it = llvm::find(args2, arg1);
4114 if (it == args2.end())
4115 return std::nullopt;
4117 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4124 llvm::SmallDenseSet<Value> set;
4125 for (
Value arg : args) {
4126 if (!set.insert(arg).second)
4139 LogicalResult matchAndRewrite(WhileOp loop,
4141 auto oldBefore = loop.getBeforeBody();
4142 ConditionOp oldTerm = loop.getConditionOp();
4143 ValueRange beforeArgs = oldBefore->getArguments();
4145 if (beforeArgs == termArgs)
4148 if (hasDuplicates(termArgs))
4151 auto mapping = getArgsMapping(beforeArgs, termArgs);
4162 auto oldAfter = loop.getAfterBody();
4166 newResultTypes[
j] = loop.getResult(i).getType();
4168 auto newLoop = rewriter.
create<WhileOp>(
4169 loop.getLoc(), newResultTypes, loop.getInits(),
4171 auto newBefore = newLoop.getBeforeBody();
4172 auto newAfter = newLoop.getAfterBody();
4177 newResults[i] = newLoop.getResult(
j);
4178 newAfterArgs[i] = newAfter->getArgument(
j);
4182 newBefore->getArguments());
4194 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4195 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4196 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4197 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4211 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4214 caseValues.push_back(value);
4223 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4225 p <<
"case " << value <<
' ';
4231 if (getCases().size() != getCaseRegions().size()) {
4232 return emitOpError(
"has ")
4233 << getCaseRegions().size() <<
" case regions but "
4234 << getCases().size() <<
" case values";
4238 for (int64_t value : getCases())
4239 if (!valueSet.insert(value).second)
4240 return emitOpError(
"has duplicate case value: ") << value;
4242 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4244 return emitOpError(
"expected region to end with scf.yield, but got ")
4247 if (yield.getNumOperands() != getNumResults()) {
4248 return (emitOpError(
"expected each region to return ")
4249 << getNumResults() <<
" values, but " << name <<
" returns "
4250 << yield.getNumOperands())
4251 .attachNote(yield.getLoc())
4252 <<
"see yield operation here";
4254 for (
auto [idx, result, operand] :
4255 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4256 yield.getOperandTypes())) {
4257 if (result == operand)
4259 return (emitOpError(
"expected result #")
4260 << idx <<
" of each region to be " << result)
4261 .attachNote(yield.getLoc())
4262 << name <<
" returns " << operand <<
" here";
4267 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4270 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4276 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4278 Block &scf::IndexSwitchOp::getDefaultBlock() {
4279 return getDefaultRegion().
front();
4282 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4283 assert(idx < getNumCases() &&
"case index out-of-bounds");
4284 return getCaseRegions()[idx].front();
4287 void IndexSwitchOp::getSuccessorRegions(
4291 successors.emplace_back(getResults());
4295 llvm::copy(getRegions(), std::back_inserter(successors));
4298 void IndexSwitchOp::getEntrySuccessorRegions(
4301 FoldAdaptor adaptor(operands, *
this);
4304 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4306 llvm::copy(getRegions(), std::back_inserter(successors));
4312 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4313 if (caseValue == arg.getInt()) {
4314 successors.emplace_back(&caseRegion);
4318 successors.emplace_back(&getDefaultRegion());
4321 void IndexSwitchOp::getRegionInvocationBounds(
4323 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4324 if (!operandValue) {
4330 unsigned liveIndex = getNumRegions() - 1;
4331 const auto *it = llvm::find(getCases(), operandValue.getInt());
4332 if (it != getCases().end())
4333 liveIndex = std::distance(getCases().begin(), it);
4334 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4335 bounds.emplace_back(0, i == liveIndex);
4346 if (!maybeCst.has_value())
4348 int64_t cst = *maybeCst;
4349 int64_t caseIdx, e = op.getNumCases();
4350 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4351 if (cst == op.getCases()[caseIdx])
4355 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4356 : op.getDefaultRegion();
4357 Block &source = r.front();
4380 #define GET_OP_CLASSES
4381 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.