25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77 InParallelOp, ReduceReturnOp>();
78 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80 ForallOp, InParallelOp, WhileOp, YieldOp>();
81 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
86 builder.
create<scf::YieldOp>(loc);
91 template <
typename TerminatorTy>
93 StringRef errorMessage) {
96 terminatorOperation = ®ion.
front().
back();
97 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
101 if (terminatorOperation)
102 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
114 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
160 if (getRegion().empty())
161 return emitOpError(
"region needs to have at least one block");
162 if (getRegion().front().getNumArguments() > 0)
163 return emitOpError(
"region cannot have any arguments");
186 if (!llvm::hasSingleElement(op.getRegion()))
235 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 Block *prevBlock = op->getBlock();
242 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 for (
Block &blk : op.getRegion()) {
245 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248 yieldOp.getResults());
256 for (
auto res : op.getResults())
257 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
269 void ExecuteRegionOp::getSuccessorRegions(
287 assert((point.
isParent() || point == getParentOp().getAfter()) &&
288 "condition op can only exit the loop or branch to the after"
291 return getArgsMutable();
294 void ConditionOp::getSuccessorRegions(
296 FoldAdaptor adaptor(operands, *
this);
298 WhileOp whileOp = getParentOp();
302 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303 if (!boolAttr || boolAttr.getValue())
304 regions.emplace_back(&whileOp.getAfter(),
305 whileOp.getAfter().getArguments());
306 if (!boolAttr || !boolAttr.getValue())
307 regions.emplace_back(whileOp.getResults());
316 BodyBuilderFn bodyBuilder) {
321 for (
Value v : initArgs)
327 for (
Value v : initArgs)
333 if (initArgs.empty() && !bodyBuilder) {
334 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
335 }
else if (bodyBuilder) {
345 if (getInitArgs().size() != getNumResults())
347 "mismatch in number of loop-carried values and defined values");
352 LogicalResult ForOp::verifyRegions() {
357 "expected induction variable to be same type as bounds and step");
359 if (getNumRegionIterArgs() != getNumResults())
361 "mismatch in number of basic block args and defined values");
363 auto initArgs = getInitArgs();
364 auto iterArgs = getRegionIterArgs();
365 auto opResults = getResults();
367 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 return emitOpError() <<
"types mismatch between " << i
370 <<
"th iter operand and defined value";
372 return emitOpError() <<
"types mismatch between " << i
373 <<
"th iter region arg and defined value";
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
396 std::optional<ResultRange> ForOp::getLoopResults() {
  // An scf.for exposes its op results directly as the loop results. The
  // verifier requires them to match the loop-carried init args and region
  // iter args in count, so no filtering or remapping is needed here.
  ResultRange loopResults = getResults();
  return loopResults;
}
401 std::optional<int64_t> tripCount =
403 if (!tripCount.has_value() || tripCount != 1)
407 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
414 llvm::append_range(bbArgReplacements, getInitArgs());
418 getOperation()->getIterator(), bbArgReplacements);
434 StringRef prefix =
"") {
435 assert(blocksArgs.size() == initializers.size() &&
436 "expected same length of arguments and initializers");
437 if (initializers.empty())
441 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
442 p << std::get<0>(it) <<
" = " << std::get<1>(it);
448 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
452 if (!getInitArgs().empty())
453 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 p <<
" : " << t <<
' ';
459 !getInitArgs().empty());
481 regionArgs.push_back(inductionVariable);
491 if (regionArgs.size() != result.
types.size() + 1)
494 "mismatch in number of loop-carried values and defined values");
503 regionArgs.front().type = type;
509 for (
auto argOperandType :
510 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
511 Type type = std::get<2>(argOperandType);
512 std::get<0>(argOperandType).type = type;
524 ForOp::ensureTerminator(*body, builder, result.
location);
536 return getBody()->getArguments().drop_front(getNumInductionVars());
540 return getInitArgsMutable();
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
546 bool replaceInitOperandUsesInLoop,
551 auto inits = llvm::to_vector(getInitArgs());
552 inits.append(newInitOperands.begin(), newInitOperands.end());
553 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
561 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
566 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567 assert(newInitOperands.size() == newYieldedValues.size() &&
568 "expected as many new yield values as new iter operands");
570 yieldOp.getResultsMutable().append(newYieldedValues);
576 newLoop.getBody()->getArguments().take_front(
577 getBody()->getNumArguments()));
579 if (replaceInitOperandUsesInLoop) {
582 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
593 newLoop->getResults().take_front(getNumResults()));
594 return cast<LoopLikeOpInterface>(newLoop.getOperation());
598 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
601 assert(ivArg.getOwner() &&
"unlinked block argument");
602 auto *containingOp = ivArg.getOwner()->getParentOp();
603 return dyn_cast_or_null<ForOp>(containingOp);
607 return getInitArgs();
624 for (
auto [lb, ub, step] :
625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
627 if (!tripCount.has_value() || *tripCount != 1)
636 return getBody()->getArguments().drop_front(getRank());
640 return getOutputsMutable();
646 scf::InParallelOp terminator = forallOp.getTerminator();
651 bbArgReplacements.append(forallOp.getOutputs().begin(),
652 forallOp.getOutputs().end());
656 forallOp->getIterator(), bbArgReplacements);
661 results.reserve(forallOp.getResults().size());
662 for (
auto &yieldingOp : terminator.getYieldingOps()) {
663 auto parallelInsertSliceOp =
664 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
666 Value dst = parallelInsertSliceOp.getDest();
667 Value src = parallelInsertSliceOp.getSource();
668 if (llvm::isa<TensorType>(src.
getType())) {
669 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
670 forallOp.getLoc(), dst.
getType(), src, dst,
671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672 parallelInsertSliceOp.getStrides(),
673 parallelInsertSliceOp.getStaticOffsets(),
674 parallelInsertSliceOp.getStaticSizes(),
675 parallelInsertSliceOp.getStaticStrides()));
677 llvm_unreachable(
"unsupported terminator");
692 assert(lbs.size() == ubs.size() &&
693 "expected the same number of lower and upper bounds");
694 assert(lbs.size() == steps.size() &&
695 "expected the same number of lower bounds and steps");
700 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
702 assert(results.size() == iterArgs.size() &&
703 "loop nest body must return as many values as loop has iteration "
705 return LoopNest{{}, std::move(results)};
713 loops.reserve(lbs.size());
714 ivs.reserve(lbs.size());
717 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
718 auto loop = builder.
create<scf::ForOp>(
719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725 currentIterArgs = args;
726 currentLoc = nestedLoc;
732 loops.push_back(loop);
736 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
738 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
745 ? bodyBuilder(builder, currentLoc, ivs,
746 loops.back().getRegionIterArgs())
748 assert(results.size() == iterArgs.size() &&
749 "loop nest body must return as many values as loop has iteration "
752 builder.
create<scf::YieldOp>(loc, results);
756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757 return LoopNest{std::move(loops), std::move(nestResults)};
765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
770 bodyBuilder(nestedBuilder, nestedLoc, ivs);
779 assert(operand.
getOwner() == forOp);
784 "expected an iter OpOperand");
786 "Expected a different type");
788 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
790 newIterOperands.push_back(replacement);
793 newIterOperands.push_back(opOperand.get());
797 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
798 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799 forOp.getStep(), newIterOperands);
800 newForOp->
setAttrs(forOp->getAttrs());
801 Block &newBlock = newForOp.getRegion().
front();
809 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
811 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
816 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
819 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
822 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
823 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824 clonedYieldOp.getOperand(yieldIdx));
826 newYieldOperands[yieldIdx] = castOut;
827 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828 rewriter.
eraseOp(clonedYieldOp);
833 newResults[yieldIdx] =
834 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
856 LogicalResult matchAndRewrite(scf::ForOp forOp,
858 bool canonicalize =
false;
865 int64_t numResults = forOp.getNumResults();
867 keepMask.reserve(numResults);
870 newBlockTransferArgs.reserve(1 + numResults);
871 newBlockTransferArgs.push_back(
Value());
872 newIterArgs.reserve(forOp.getInitArgs().size());
873 newYieldValues.reserve(numResults);
874 newResultValues.reserve(numResults);
875 for (
auto it : llvm::zip(forOp.getInitArgs(),
876 forOp.getRegionIterArgs(),
878 forOp.getYieldedValues()
886 bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
887 (std::get<1>(it).use_empty() &&
888 (std::get<0>(it) == std::get<3>(it) ||
889 std::get<2>(it).use_empty())));
890 keepMask.push_back(!forwarded);
891 canonicalize |= forwarded;
893 newBlockTransferArgs.push_back(std::get<0>(it));
894 newResultValues.push_back(std::get<0>(it));
897 newIterArgs.push_back(std::get<0>(it));
898 newYieldValues.push_back(std::get<3>(it));
899 newBlockTransferArgs.push_back(
Value());
900 newResultValues.push_back(
Value());
906 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
907 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
908 forOp.getStep(), newIterArgs);
909 newForOp->
setAttrs(forOp->getAttrs());
910 Block &newBlock = newForOp.getRegion().
front();
914 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
916 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
917 Value &newResultVal = newResultValues[idx];
918 assert((blockTransferArg && newResultVal) ||
919 (!blockTransferArg && !newResultVal));
920 if (!blockTransferArg) {
921 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
922 newResultVal = newForOp.getResult(collapsedIdx++);
928 "unexpected argument size mismatch");
933 if (newIterArgs.empty()) {
934 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
937 rewriter.
replaceOp(forOp, newResultValues);
942 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
946 filteredOperands.reserve(newResultValues.size());
947 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
949 filteredOperands.push_back(mergedTerminator.getOperand(idx));
950 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
954 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
955 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
956 cloneFilteredTerminator(mergedYieldOp);
957 rewriter.
eraseOp(mergedYieldOp);
958 rewriter.
replaceOp(forOp, newResultValues);
966 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
967 IntegerAttr clb, cub;
969 llvm::APInt lbValue = clb.getValue();
970 llvm::APInt ubValue = cub.getValue();
971 return (ubValue - lbValue).getSExtValue();
980 return diff.getSExtValue();
990 LogicalResult matchAndRewrite(ForOp op,
994 if (op.getLowerBound() == op.getUpperBound()) {
995 rewriter.
replaceOp(op, op.getInitArgs());
999 std::optional<int64_t> diff =
1000 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1006 rewriter.
replaceOp(op, op.getInitArgs());
1010 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1011 if (!maybeStepValue)
1016 llvm::APInt stepValue = *maybeStepValue;
1017 if (stepValue.sge(*diff)) {
1019 blockArgs.reserve(op.getInitArgs().size() + 1);
1020 blockArgs.push_back(op.getLowerBound());
1021 llvm::append_range(blockArgs, op.getInitArgs());
1028 if (!llvm::hasSingleElement(block))
1032 if (llvm::any_of(op.getYieldedValues(),
1033 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1035 rewriter.
replaceOp(op, op.getYieldedValues());
1069 LogicalResult matchAndRewrite(ForOp op,
1071 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1072 OpOperand &iterOpOperand = std::get<0>(it);
1074 if (!incomingCast ||
1075 incomingCast.getSource().getType() == incomingCast.getType())
1080 incomingCast.getDest().getType(),
1081 incomingCast.getSource().getType()))
1083 if (!std::get<1>(it).hasOneUse())
1089 rewriter, op, iterOpOperand, incomingCast.getSource(),
1091 return b.create<tensor::CastOp>(loc, type, source);
1103 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1107 std::optional<APInt> ForOp::getConstantStep() {
1110 return step.getValue();
1114 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1115 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1121 if (
auto constantStep = getConstantStep())
1122 if (*constantStep == 1)
1135 unsigned numLoops = getRank();
1137 if (getNumResults() != getOutputs().size())
1138 return emitOpError(
"produces ")
1139 << getNumResults() <<
" results, but has only "
1140 << getOutputs().size() <<
" outputs";
1143 auto *body = getBody();
1145 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1146 for (int64_t i = 0; i < numLoops; ++i)
1148 return emitOpError(
"expects ")
1149 << i <<
"-th block argument to be an index";
1150 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1152 return emitOpError(
"type mismatch between ")
1153 << i <<
"-th output and corresponding block argument";
1154 if (getMapping().has_value() && !getMapping()->empty()) {
1155 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1156 return emitOpError() <<
"mapping attribute size must match op rank";
1157 for (
auto map : getMapping()->getValue()) {
1158 if (!isa<DeviceMappingAttrInterface>(map))
1159 return emitOpError()
1167 getStaticLowerBound(),
1168 getDynamicLowerBound())))
1171 getStaticUpperBound(),
1172 getDynamicUpperBound())))
1175 getStaticStep(), getDynamicStep())))
1183 p <<
" (" << getInductionVars();
1184 if (isNormalized()) {
1205 if (!getRegionOutArgs().empty())
1206 p <<
"-> (" << getResultTypes() <<
") ";
1207 p.printRegion(getRegion(),
1209 getNumResults() > 0);
1210 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1211 getStaticLowerBoundAttrName(),
1212 getStaticUpperBoundAttrName(),
1213 getStaticStepAttrName()});
1218 auto indexType = b.getIndexType();
1238 unsigned numLoops = ivs.size();
1273 if (outOperands.size() != result.
types.size())
1275 "mismatch between out operands and types");
1285 std::unique_ptr<Region> region = std::make_unique<Region>();
1286 for (
auto &iv : ivs) {
1287 iv.type = b.getIndexType();
1288 regionArgs.push_back(iv);
1291 auto &out = it.value();
1292 out.type = result.
types[it.index()];
1293 regionArgs.push_back(out);
1299 ForallOp::ensureTerminator(*region, b, result.
location);
1311 {static_cast<int32_t>(dynamicLbs.size()),
1312 static_cast<int32_t>(dynamicUbs.size()),
1313 static_cast<int32_t>(dynamicSteps.size()),
1314 static_cast<int32_t>(outOperands.size())}));
1319 void ForallOp::build(
1323 std::optional<ArrayAttr> mapping,
1344 "operandSegmentSizes",
1346 static_cast<int32_t>(dynamicUbs.size()),
1347 static_cast<int32_t>(dynamicSteps.size()),
1348 static_cast<int32_t>(outputs.size())}));
1349 if (mapping.has_value()) {
1368 if (!bodyBuilderFn) {
1369 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1376 void ForallOp::build(
1379 std::optional<ArrayAttr> mapping,
1381 unsigned numLoops = ubs.size();
1384 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1388 bool ForallOp::isNormalized() {
1392 return intValue.has_value() && intValue == val;
1395 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1404 ForallOp>::ensureTerminator(region, builder, loc);
1411 InParallelOp ForallOp::getTerminator() {
1412 return cast<InParallelOp>(getBody()->getTerminator());
1417 InParallelOp inParallelOp = getTerminator();
1418 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1419 if (
auto parallelInsertSliceOp =
1420 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1421 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1422 storeOps.push_back(parallelInsertSliceOp);
1428 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1433 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1435 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1439 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1441 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1445 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1451 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1454 assert(tidxArg.getOwner() &&
"unlinked block argument");
1455 auto *containingOp = tidxArg.getOwner()->getParentOp();
1456 return dyn_cast<ForallOp>(containingOp);
1464 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1466 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1470 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1473 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1482 LogicalResult matchAndRewrite(ForallOp op,
1497 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1498 op.setStaticLowerBound(staticLowerBound);
1502 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1503 op.setStaticUpperBound(staticUpperBound);
1506 op.getDynamicStepMutable().assign(dynamicStep);
1507 op.setStaticStep(staticStep);
1509 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1511 {static_cast<int32_t>(dynamicLowerBound.size()),
1512 static_cast<int32_t>(dynamicUpperBound.size()),
1513 static_cast<int32_t>(dynamicStep.size()),
1514 static_cast<int32_t>(op.getNumResults())}));
1596 LogicalResult matchAndRewrite(ForallOp forallOp,
1615 for (
OpResult result : forallOp.getResults()) {
1616 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1617 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1618 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1619 resultToDelete.insert(result);
1621 resultToReplace.push_back(result);
1622 newOuts.push_back(opOperand->
get());
1628 if (resultToDelete.empty())
1636 for (
OpResult result : resultToDelete) {
1637 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1638 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1640 forallOp.getCombiningOps(blockArg);
1641 for (
Operation *combiningOp : combiningOps)
1642 rewriter.
eraseOp(combiningOp);
1647 auto newForallOp = rewriter.
create<scf::ForallOp>(
1648 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1649 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1650 forallOp.getMapping(),
1655 Block *loopBody = forallOp.getBody();
1656 Block *newLoopBody = newForallOp.getBody();
1661 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1668 for (
OpResult result : forallOp.getResults()) {
1669 if (resultToDelete.count(result)) {
1670 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1672 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1675 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1679 for (
auto &&[oldResult, newResult] :
1680 llvm::zip(resultToReplace, newForallOp->getResults()))
1686 for (
OpResult oldResult : resultToDelete)
1688 forallOp.getTiedOpOperand(oldResult)->get());
1693 struct ForallOpSingleOrZeroIterationDimsFolder
1697 LogicalResult matchAndRewrite(ForallOp op,
1700 if (op.getMapping().has_value() && !op.getMapping()->empty())
1708 for (
auto [lb, ub, step, iv] :
1709 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1710 op.getMixedStep(), op.getInductionVars())) {
1712 if (numIterations.has_value()) {
1714 if (*numIterations == 0) {
1715 rewriter.
replaceOp(op, op.getOutputs());
1720 if (*numIterations == 1) {
1725 newMixedLowerBounds.push_back(lb);
1726 newMixedUpperBounds.push_back(ub);
1727 newMixedSteps.push_back(step);
1731 if (newMixedLowerBounds.empty()) {
1737 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1739 op,
"no dimensions have 0 or 1 iterations");
1744 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1745 newMixedUpperBounds, newMixedSteps,
1746 op.getOutputs(), std::nullopt,
nullptr);
1747 newOp.getBodyRegion().getBlocks().clear();
1752 newOp.getStaticLowerBoundAttrName(),
1753 newOp.getStaticUpperBoundAttrName(),
1754 newOp.getStaticStepAttrName()};
1755 for (
const auto &namedAttr : op->getAttrs()) {
1756 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1759 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1763 newOp.getRegion().begin(), mapping);
1764 rewriter.
replaceOp(op, newOp.getResults());
1770 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1773 LogicalResult matchAndRewrite(ForallOp op,
1777 for (
auto [lb, ub, step, iv] :
1778 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1779 op.getMixedStep(), op.getInductionVars())) {
1780 if (iv.getUses().begin() == iv.getUses().end())
1783 if (!numIterations.has_value() || numIterations.value() != 1) {
1794 struct FoldTensorCastOfOutputIntoForallOp
1803 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1805 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1808 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1815 castOp.getSource().getType())) {
1819 tensorCastProducers[en.index()] =
1820 TypeCast{castOp.getSource().getType(), castOp.getType()};
1821 newOutputTensors[en.index()] = castOp.getSource();
1824 if (tensorCastProducers.empty())
1829 auto newForallOp = rewriter.
create<ForallOp>(
1830 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1831 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1833 auto castBlockArgs =
1834 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1835 for (
auto [index, cast] : tensorCastProducers) {
1836 Value &oldTypeBBArg = castBlockArgs[index];
1837 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1838 nestedLoc, cast.dstType, oldTypeBBArg);
1843 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1844 ivsBlockArgs.append(castBlockArgs);
1846 bbArgs.front().getParentBlock(), ivsBlockArgs);
1852 auto terminator = newForallOp.getTerminator();
1853 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1854 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1855 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1856 insertSliceOp.getDestMutable().assign(outputBlockArg);
1862 for (
auto &item : tensorCastProducers) {
1863 Value &oldTypeResult = castResults[item.first];
1864 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1867 rewriter.
replaceOp(forallOp, castResults);
1876 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1877 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1878 ForallOpSingleOrZeroIterationDimsFolder,
1879 ForallOpReplaceConstantInductionVar>(context);
1908 scf::ForallOp forallOp =
1909 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1911 return this->emitOpError(
"expected forall op parent");
1914 for (
Operation &op : getRegion().front().getOperations()) {
1915 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1916 return this->emitOpError(
"expected only ")
1917 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1921 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1923 if (!llvm::is_contained(regionOutArgs, dest))
1924 return op.emitOpError(
"may only insert into an output block argument");
1941 std::unique_ptr<Region> region = std::make_unique<Region>();
1945 if (region->empty())
1955 OpResult InParallelOp::getParentResult(int64_t idx) {
1956 return getOperation()->getParentOp()->getResult(idx);
1960 return llvm::to_vector<4>(
1961 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1963 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1964 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1969 return getRegion().front().getOperations();
1977 assert(a &&
"expected non-empty operation");
1978 assert(b &&
"expected non-empty operation");
1983 if (ifOp->isProperAncestor(b))
1986 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1987 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1989 ifOp = ifOp->getParentOfType<IfOp>();
1997 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1998 IfOp::Adaptor adaptor,
2000 if (adaptor.getRegions().empty())
2002 Region *r = &adaptor.getThenRegion();
2005 Block &b = r->front();
2008 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2011 TypeRange types = yieldOp.getOperandTypes();
2012 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2019 return build(builder, result, resultTypes, cond,
false,
2025 bool addElseBlock) {
2026 assert((!addElseBlock || addThenBlock) &&
2027 "must not create else block w/o then block");
2042 bool withElseRegion) {
2043 build(builder, result,
TypeRange{}, cond, withElseRegion);
2055 if (resultTypes.empty())
2056 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2060 if (withElseRegion) {
2062 if (resultTypes.empty())
2063 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2070 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2077 thenBuilder(builder, result.
location);
2083 elseBuilder(builder, result.
location);
2090 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2092 inferredReturnTypes))) {
2093 result.
addTypes(inferredReturnTypes);
2098 if (getNumResults() != 0 && getElseRegion().empty())
2099 return emitOpError(
"must have an else block if defining values");
2137 bool printBlockTerminators =
false;
2139 p <<
" " << getCondition();
2140 if (!getResults().empty()) {
2141 p <<
" -> (" << getResultTypes() <<
")";
2143 printBlockTerminators =
true;
2148 printBlockTerminators);
2151 auto &elseRegion = getElseRegion();
2152 if (!elseRegion.
empty()) {
2156 printBlockTerminators);
2173 Region *elseRegion = &this->getElseRegion();
2174 if (elseRegion->
empty())
2182 FoldAdaptor adaptor(operands, *
this);
2183 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2184 if (!boolAttr || boolAttr.getValue())
2185 regions.emplace_back(&getThenRegion());
2188 if (!boolAttr || !boolAttr.getValue()) {
2189 if (!getElseRegion().empty())
2190 regions.emplace_back(&getElseRegion());
2192 regions.emplace_back(getResults());
2196 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2199 if (getElseRegion().empty())
2202 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2209 getConditionMutable().assign(xorStmt.getLhs());
2213 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2214 getElseRegion().getBlocks());
2215 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2216 getThenRegion().getBlocks(), thenBlock);
2220 void IfOp::getRegionInvocationBounds(
2223 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2226 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2227 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2230 invocationBounds.assign(2, {0, 1});
2246 llvm::transform(usedResults, std::back_inserter(usedOperands),
2251 [&]() { yieldOp->setOperands(usedOperands); });
2254 LogicalResult matchAndRewrite(IfOp op,
2258 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2259 [](
OpResult result) { return !result.use_empty(); });
2262 if (usedResults.size() == op.getNumResults())
2267 llvm::transform(usedResults, std::back_inserter(newTypes),
2272 rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2278 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2279 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2284 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2293 LogicalResult matchAndRewrite(IfOp op,
2301 else if (!op.getElseRegion().empty())
2315 LogicalResult matchAndRewrite(IfOp op,
2317 if (op->getNumResults() == 0)
2320 auto cond = op.getCondition();
2321 auto thenYieldArgs = op.thenYield().getOperands();
2322 auto elseYieldArgs = op.elseYield().getOperands();
2325 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2326 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2327 &op.getElseRegion() == falseVal.getParentRegion())
2328 nonHoistable.push_back(trueVal.getType());
2332 if (nonHoistable.size() == op->getNumResults())
2335 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond,
2337 if (replacement.thenBlock())
2338 rewriter.
eraseBlock(replacement.thenBlock());
2339 replacement.getThenRegion().takeBody(op.getThenRegion());
2340 replacement.getElseRegion().takeBody(op.getElseRegion());
2343 assert(thenYieldArgs.size() == results.size());
2344 assert(elseYieldArgs.size() == results.size());
2349 for (
const auto &it :
2351 Value trueVal = std::get<0>(it.value());
2352 Value falseVal = std::get<1>(it.value());
2355 results[it.index()] = replacement.getResult(trueYields.size());
2356 trueYields.push_back(trueVal);
2357 falseYields.push_back(falseVal);
2358 }
else if (trueVal == falseVal)
2359 results[it.index()] = trueVal;
2361 results[it.index()] = rewriter.
create<arith::SelectOp>(
2362 op.getLoc(), cond, trueVal, falseVal);
2392 LogicalResult matchAndRewrite(IfOp op,
2404 Value constantTrue =
nullptr;
2405 Value constantFalse =
nullptr;
2408 llvm::make_early_inc_range(op.getCondition().getUses())) {
2413 constantTrue = rewriter.
create<arith::ConstantOp>(
2417 [&]() { use.
set(constantTrue); });
2418 }
else if (op.getElseRegion().isAncestor(
2423 constantFalse = rewriter.
create<arith::ConstantOp>(
2427 [&]() { use.
set(constantFalse); });
2471 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2474 LogicalResult matchAndRewrite(IfOp op,
2477 if (op.getNumResults() == 0)
2481 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2483 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2486 op.getOperation()->getIterator());
2489 for (
auto [trueResult, falseResult, opResult] :
2490 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2492 if (trueResult == falseResult) {
2493 if (!opResult.use_empty()) {
2494 opResult.replaceAllUsesWith(trueResult);
2505 bool trueVal = trueYield.
getValue();
2506 bool falseVal = falseYield.
getValue();
2507 if (!trueVal && falseVal) {
2508 if (!opResult.use_empty()) {
2509 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2511 op.getLoc(), op.getCondition(),
2521 if (trueVal && !falseVal) {
2522 if (!opResult.use_empty()) {
2523 opResult.replaceAllUsesWith(op.getCondition());
2556 LogicalResult matchAndRewrite(IfOp nextIf,
2558 Block *parent = nextIf->getBlock();
2559 if (nextIf == &parent->
front())
2562 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2570 Block *nextThen =
nullptr;
2571 Block *nextElse =
nullptr;
2572 if (nextIf.getCondition() == prevIf.getCondition()) {
2573 nextThen = nextIf.thenBlock();
2574 if (!nextIf.getElseRegion().empty())
2575 nextElse = nextIf.elseBlock();
2577 if (arith::XOrIOp notv =
2578 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2579 if (notv.getLhs() == prevIf.getCondition() &&
2581 nextElse = nextIf.thenBlock();
2582 if (!nextIf.getElseRegion().empty())
2583 nextThen = nextIf.elseBlock();
2586 if (arith::XOrIOp notv =
2587 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2588 if (notv.getLhs() == nextIf.getCondition() &&
2590 nextElse = nextIf.thenBlock();
2591 if (!nextIf.getElseRegion().empty())
2592 nextThen = nextIf.elseBlock();
2596 if (!nextThen && !nextElse)
2600 if (!prevIf.getElseRegion().empty())
2601 prevElseYielded = prevIf.elseYield().getOperands();
2604 for (
auto it : llvm::zip(prevIf.getResults(),
2605 prevIf.thenYield().getOperands(), prevElseYielded))
2607 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2611 use.
set(std::get<1>(it));
2616 use.
set(std::get<2>(it));
2622 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2624 IfOp combinedIf = rewriter.
create<IfOp>(
2625 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2626 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2629 combinedIf.getThenRegion(),
2630 combinedIf.getThenRegion().begin());
2633 YieldOp thenYield = combinedIf.thenYield();
2634 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2635 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2639 llvm::append_range(mergedYields, thenYield2.getOperands());
2640 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2646 combinedIf.getElseRegion(),
2647 combinedIf.getElseRegion().begin());
2650 if (combinedIf.getElseRegion().empty()) {
2652 combinedIf.getElseRegion(),
2653 combinedIf.getElseRegion().
begin());
2655 YieldOp elseYield = combinedIf.elseYield();
2656 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2657 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2662 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2664 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2673 if (pair.index() < prevIf.getNumResults())
2674 prevValues.push_back(pair.value());
2676 nextValues.push_back(pair.value());
2688 LogicalResult matchAndRewrite(IfOp ifOp,
2691 if (ifOp.getNumResults())
2693 Block *elseBlock = ifOp.elseBlock();
2694 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2698 newIfOp.getThenRegion().begin());
2723 LogicalResult matchAndRewrite(IfOp op,
2725 auto nestedOps = op.thenBlock()->without_terminator();
2727 if (!llvm::hasSingleElement(nestedOps))
2731 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2734 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2738 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2744 llvm::append_range(elseYield, op.elseYield().getOperands());
2758 if (tup.value().getDefiningOp() == nestedIf) {
2759 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2760 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2761 elseYield[tup.index()]) {
2766 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2779 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2782 elseYieldsToUpgradeToSelect.push_back(tup.index());
2786 Value newCondition = rewriter.
create<arith::AndIOp>(
2787 loc, op.getCondition(), nestedIf.getCondition());
2788 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2792 llvm::append_range(results, newIf.getResults());
2795 for (
auto idx : elseYieldsToUpgradeToSelect)
2796 results[idx] = rewriter.
create<arith::SelectOp>(
2797 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2799 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2802 if (!elseYield.empty()) {
2805 rewriter.
create<YieldOp>(loc, elseYield);
2816 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2817 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2818 RemoveStaticCondition, RemoveUnusedResults,
2819 ReplaceIfYieldWithConditionOrValue>(context);
2822 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2823 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2824 Block *IfOp::elseBlock() {
2825 Region &r = getElseRegion();
2830 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2836 void ParallelOp::build(
2846 ParallelOp::getOperandSegmentSizeAttr(),
2848 static_cast<int32_t>(upperBounds.size()),
2849 static_cast<int32_t>(steps.size()),
2850 static_cast<int32_t>(initVals.size())}));
2854 unsigned numIVs = steps.size();
2860 if (bodyBuilderFn) {
2862 bodyBuilderFn(builder, result.
location,
2867 if (initVals.empty())
2868 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2871 void ParallelOp::build(
2878 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2881 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2885 wrapper = wrappedBuilderFn;
2887 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2896 if (stepValues.empty())
2898 "needs at least one tuple element for lowerBound, upperBound and step");
2901 for (
Value stepValue : stepValues)
2904 return emitOpError(
"constant step operand must be positive");
2908 Block *body = getBody();
2910 return emitOpError() <<
"expects the same number of induction variables: "
2912 <<
" as bound and step values: " << stepValues.size();
2914 if (!arg.getType().isIndex())
2916 "expects arguments for the induction variable to be of index type");
2919 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2920 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2925 auto resultsSize = getResults().size();
2926 auto reductionsSize = reduceOp.getReductions().size();
2927 auto initValsSize = getInitVals().size();
2928 if (resultsSize != reductionsSize)
2929 return emitOpError() <<
"expects number of results: " << resultsSize
2930 <<
" to be the same as number of reductions: "
2932 if (resultsSize != initValsSize)
2933 return emitOpError() <<
"expects number of results: " << resultsSize
2934 <<
" to be the same as number of initial values: "
2938 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2939 auto resultType = getOperation()->getResult(i).getType();
2940 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2941 if (resultType != reductionOperandType)
2942 return reduceOp.emitOpError()
2943 <<
"expects type of " << i
2944 <<
"-th reduction operand: " << reductionOperandType
2945 <<
" to be the same as the " << i
2946 <<
"-th result type: " << resultType;
2994 for (
auto &iv : ivs)
3001 ParallelOp::getOperandSegmentSizeAttr(),
3003 static_cast<int32_t>(upper.size()),
3004 static_cast<int32_t>(steps.size()),
3005 static_cast<int32_t>(initVals.size())}));
3014 ParallelOp::ensureTerminator(*body, builder, result.
location);
3019 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3020 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3021 if (!getInitVals().empty())
3022 p <<
" init (" << getInitVals() <<
")";
3027 (*this)->getAttrs(),
3028 ParallelOp::getOperandSegmentSizeAttr());
3033 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3037 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3041 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3045 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3050 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3052 return ParallelOp();
3053 assert(ivArg.getOwner() &&
"unlinked block argument");
3054 auto *containingOp = ivArg.getOwner()->getParentOp();
3055 return dyn_cast<ParallelOp>(containingOp);
3060 struct ParallelOpSingleOrZeroIterationDimsFolder
3064 LogicalResult matchAndRewrite(ParallelOp op,
3071 for (
auto [lb, ub, step, iv] :
3072 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3073 op.getInductionVars())) {
3075 if (numIterations.has_value()) {
3077 if (*numIterations == 0) {
3078 rewriter.
replaceOp(op, op.getInitVals());
3083 if (*numIterations == 1) {
3088 newLowerBounds.push_back(lb);
3089 newUpperBounds.push_back(ub);
3090 newSteps.push_back(step);
3093 if (newLowerBounds.size() == op.getLowerBound().size())
3096 if (newLowerBounds.empty()) {
3100 results.reserve(op.getInitVals().size());
3101 for (
auto &bodyOp : op.getBody()->without_terminator())
3102 rewriter.
clone(bodyOp, mapping);
3103 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3104 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3105 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3106 auto initValIndex = results.size();
3107 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3111 rewriter.
clone(reduceBodyOp, mapping);
3114 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3115 results.push_back(result);
3123 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3124 newSteps, op.getInitVals(),
nullptr);
3130 newOp.getRegion().begin(), mapping);
3131 rewriter.
replaceOp(op, newOp.getResults());
3139 LogicalResult matchAndRewrite(ParallelOp op,
3141 Block &outerBody = *op.getBody();
3145 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3150 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3151 llvm::is_contained(innerOp.getUpperBound(), val) ||
3152 llvm::is_contained(innerOp.getStep(), val))
3156 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3161 Block &innerBody = *innerOp.getBody();
3162 assert(iterVals.size() ==
3170 builder.
clone(op, mapping);
3173 auto concatValues = [](
const auto &first,
const auto &second) {
3175 ret.reserve(first.size() + second.size());
3176 ret.assign(first.begin(), first.end());
3177 ret.append(second.begin(), second.end());
3181 auto newLowerBounds =
3182 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3183 auto newUpperBounds =
3184 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3185 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3188 newSteps, std::nullopt,
3199 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3208 void ParallelOp::getSuccessorRegions(
3226 for (
Value v : operands) {
3235 LogicalResult ReduceOp::verifyRegions() {
3238 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3239 auto type = getOperands()[i].getType();
3242 return emitOpError() << i <<
"-th reduction has an empty body";
3245 return arg.getType() != type;
3247 return emitOpError() <<
"expected two block arguments with type " << type
3248 <<
" in the " << i <<
"-th reduction region";
3252 return emitOpError(
"reduction bodies must be terminated with an "
3253 "'scf.reduce.return' op");
3272 Block *reductionBody = getOperation()->getBlock();
3274 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3276 if (expectedResultType != getResult().
getType())
3277 return emitOpError() <<
"must have type " << expectedResultType
3278 <<
" (the type of the reduction inputs)";
3288 ValueRange inits, BodyBuilderFn beforeBuilder,
3289 BodyBuilderFn afterBuilder) {
3297 beforeArgLocs.reserve(inits.size());
3298 for (
Value operand : inits) {
3299 beforeArgLocs.push_back(operand.getLoc());
3304 inits.getTypes(), beforeArgLocs);
3313 resultTypes, afterArgLocs);
3319 ConditionOp WhileOp::getConditionOp() {
3320 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3323 YieldOp WhileOp::getYieldOp() {
3324 return cast<YieldOp>(getAfterBody()->getTerminator());
3327 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3328 return getYieldOp().getResultsMutable();
3332 return getBeforeBody()->getArguments();
3336 return getAfterBody()->getArguments();
3340 return getBeforeArguments();
3344 assert(point == getBefore() &&
3345 "WhileOp is expected to branch only to the first region");
3353 regions.emplace_back(&getBefore(), getBefore().getArguments());
3357 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3358 "there are only two regions in a WhileOp");
3360 if (point == getAfter()) {
3361 regions.emplace_back(&getBefore(), getBefore().getArguments());
3365 regions.emplace_back(getResults());
3366 regions.emplace_back(&getAfter(), getAfter().getArguments());
3370 return {&getBefore(), &getAfter()};
3391 FunctionType functionType;
3396 result.
addTypes(functionType.getResults());
3398 if (functionType.getNumInputs() != operands.size()) {
3400 <<
"expected as many input types as operands "
3401 <<
"(expected " << operands.size() <<
" got "
3402 << functionType.getNumInputs() <<
")";
3412 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3413 regionArgs[i].type = functionType.getInput(i);
3415 return failure(parser.
parseRegion(*before, regionArgs) ||
3435 template <
typename OpTy>
3438 if (left.size() != right.size())
3439 return op.emitOpError(
"expects the same number of ") << message;
3441 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3442 if (left[i] != right[i]) {
3445 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3446 <<
" and " << right[i];
3455 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3457 "expects the 'before' region to terminate with 'scf.condition'");
3458 if (!beforeTerminator)
3461 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3463 "expects the 'after' region to terminate with 'scf.yield'");
3464 return success(afterTerminator !=
nullptr);
3490 LogicalResult matchAndRewrite(WhileOp op,
3492 auto term = op.getConditionOp();
3496 Value constantTrue =
nullptr;
3498 bool replaced =
false;
3499 for (
auto yieldedAndBlockArgs :
3500 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3501 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3502 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3504 constantTrue = rewriter.
create<arith::ConstantOp>(
3505 op.getLoc(), term.getCondition().getType(),
3514 return success(replaced);
3566 struct RemoveLoopInvariantArgsFromBeforeBlock
3570 LogicalResult matchAndRewrite(WhileOp op,
3572 Block &afterBlock = *op.getAfterBody();
3574 ConditionOp condOp = op.getConditionOp();
3579 bool canSimplify =
false;
3580 for (
const auto &it :
3582 auto index =
static_cast<unsigned>(it.index());
3583 auto [initVal, yieldOpArg] = it.value();
3586 if (yieldOpArg == initVal) {
3595 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3596 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3597 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3598 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3611 for (
const auto &it :
3613 auto index =
static_cast<unsigned>(it.index());
3614 auto [initVal, yieldOpArg] = it.value();
3618 if (yieldOpArg == initVal) {
3619 beforeBlockInitValMap.insert({index, initVal});
3627 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3628 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3629 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3630 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3631 beforeBlockInitValMap.insert({index, initVal});
3636 newInitArgs.emplace_back(initVal);
3637 newYieldOpArgs.emplace_back(yieldOpArg);
3638 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3648 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3651 &newWhile.getBefore(), {},
3654 Block &beforeBlock = *op.getBeforeBody();
3661 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3664 if (beforeBlockInitValMap.count(i) != 0)
3665 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3667 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3670 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3672 newWhile.getAfter().begin());
3674 rewriter.
replaceOp(op, newWhile.getResults());
3719 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3722 LogicalResult matchAndRewrite(WhileOp op,
3724 Block &beforeBlock = *op.getBeforeBody();
3725 ConditionOp condOp = op.getConditionOp();
3728 bool canSimplify =
false;
3729 for (
Value condOpArg : condOpArgs) {
3749 auto index =
static_cast<unsigned>(it.index());
3750 Value condOpArg = it.value();
3755 condOpInitValMap.insert({index, condOpArg});
3757 newCondOpArgs.emplace_back(condOpArg);
3758 newAfterBlockType.emplace_back(condOpArg.
getType());
3759 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3770 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3773 Block &newAfterBlock =
3775 newAfterBlockType, newAfterBlockArgLocs);
3777 Block &afterBlock = *op.getAfterBody();
3784 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3785 Value afterBlockArg, result;
3788 if (condOpInitValMap.count(i) != 0) {
3789 afterBlockArg = condOpInitValMap[i];
3790 result = afterBlockArg;
3792 afterBlockArg = newAfterBlock.getArgument(
j);
3793 result = newWhile.getResult(
j);
3796 newAfterBlockArgs[i] = afterBlockArg;
3797 newWhileResults[i] = result;
3800 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3802 newWhile.getBefore().begin());
3804 rewriter.
replaceOp(op, newWhileResults);
3838 LogicalResult matchAndRewrite(WhileOp op,
3840 auto term = op.getConditionOp();
3841 auto afterArgs = op.getAfterArguments();
3842 auto termArgs = term.getArgs();
3849 bool needUpdate =
false;
3850 for (
const auto &it :
3852 auto i =
static_cast<unsigned>(it.index());
3853 Value result = std::get<0>(it.value());
3854 Value afterArg = std::get<1>(it.value());
3855 Value termArg = std::get<2>(it.value());
3859 newResultsIndices.emplace_back(i);
3860 newTermArgs.emplace_back(termArg);
3861 newResultTypes.emplace_back(result.
getType());
3862 newArgLocs.emplace_back(result.
getLoc());
3877 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3880 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3887 newResults[it.value()] = newWhile.getResult(it.index());
3888 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3892 newWhile.getBefore().begin());
3894 Block &afterBlock = *op.getAfterBody();
3895 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3927 LogicalResult matchAndRewrite(scf::WhileOp op,
3929 using namespace scf;
3930 auto cond = op.getConditionOp();
3931 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3935 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3936 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3937 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3940 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3941 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3945 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3948 if (cmp2.getPredicate() == cmp.getPredicate())
3949 samePredicate =
true;
3950 else if (cmp2.getPredicate() ==
3952 samePredicate =
false;
3970 LogicalResult matchAndRewrite(WhileOp op,
3973 if (!llvm::any_of(op.getBeforeArguments(),
3974 [](
Value arg) { return arg.use_empty(); }))
3977 YieldOp yield = op.getYieldOp();
3982 llvm::BitVector argsToErase;
3984 size_t argsCount = op.getBeforeArguments().size();
3985 newYields.reserve(argsCount);
3986 newInits.reserve(argsCount);
3987 argsToErase.reserve(argsCount);
3988 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3989 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3990 if (beforeArg.use_empty()) {
3991 argsToErase.push_back(
true);
3993 argsToErase.push_back(
false);
3994 newYields.emplace_back(yieldValue);
3995 newInits.emplace_back(initValue);
3999 Block &beforeBlock = *op.getBeforeBody();
4000 Block &afterBlock = *op.getAfterBody();
4006 rewriter.
create<WhileOp>(loc, op.getResultTypes(), newInits,
4008 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4009 Block &newAfterBlock = *newWhileOp.getAfterBody();
4015 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4016 newBeforeBlock.getArguments());
4020 rewriter.
replaceOp(op, newWhileOp.getResults());
4029 LogicalResult matchAndRewrite(WhileOp op,
4031 ConditionOp condOp = op.getConditionOp();
4035 for (
Value arg : condOpArgs)
4036 argsSet.insert(arg);
4038 if (argsSet.size() == condOpArgs.size())
4041 llvm::SmallDenseMap<Value, unsigned> argsMap;
4043 argsMap.reserve(condOpArgs.size());
4044 newArgs.reserve(condOpArgs.size());
4045 for (
Value arg : condOpArgs) {
4046 if (!argsMap.count(arg)) {
4047 auto pos =
static_cast<unsigned>(argsMap.size());
4048 argsMap.insert({arg, pos});
4049 newArgs.emplace_back(arg);
4056 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4057 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4059 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4060 Block &newAfterBlock = *newWhileOp.getAfterBody();
4065 auto it = argsMap.find(arg);
4066 assert(it != argsMap.end());
4067 auto pos = it->second;
4068 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4069 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4077 Block &beforeBlock = *op.getBeforeBody();
4078 Block &afterBlock = *op.getAfterBody();
4080 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4081 newBeforeBlock.getArguments());
4082 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4090 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4092 if (args1.size() != args2.size())
4093 return std::nullopt;
4097 auto it = llvm::find(args2, arg1);
4098 if (it == args2.end())
4099 return std::nullopt;
4101 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4108 llvm::SmallDenseSet<Value> set;
4109 for (
Value arg : args) {
4110 if (!set.insert(arg).second)
4123 LogicalResult matchAndRewrite(WhileOp loop,
4125 auto oldBefore = loop.getBeforeBody();
4126 ConditionOp oldTerm = loop.getConditionOp();
4127 ValueRange beforeArgs = oldBefore->getArguments();
4129 if (beforeArgs == termArgs)
4132 if (hasDuplicates(termArgs))
4135 auto mapping = getArgsMapping(beforeArgs, termArgs);
4146 auto oldAfter = loop.getAfterBody();
4150 newResultTypes[
j] = loop.getResult(i).getType();
4152 auto newLoop = rewriter.
create<WhileOp>(
4153 loop.getLoc(), newResultTypes, loop.getInits(),
4155 auto newBefore = newLoop.getBeforeBody();
4156 auto newAfter = newLoop.getAfterBody();
4161 newResults[i] = newLoop.getResult(
j);
4162 newAfterArgs[i] = newAfter->getArgument(
j);
4166 newBefore->getArguments());
4178 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4179 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4180 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4181 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4195 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4198 caseValues.push_back(value);
4207 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4209 p <<
"case " << value <<
' ';
4215 if (getCases().size() != getCaseRegions().size()) {
4216 return emitOpError(
"has ")
4217 << getCaseRegions().size() <<
" case regions but "
4218 << getCases().size() <<
" case values";
4222 for (int64_t value : getCases())
4223 if (!valueSet.insert(value).second)
4224 return emitOpError(
"has duplicate case value: ") << value;
4226 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4228 return emitOpError(
"expected region to end with scf.yield, but got ")
4231 if (yield.getNumOperands() != getNumResults()) {
4232 return (emitOpError(
"expected each region to return ")
4233 << getNumResults() <<
" values, but " << name <<
" returns "
4234 << yield.getNumOperands())
4235 .attachNote(yield.getLoc())
4236 <<
"see yield operation here";
4238 for (
auto [idx, result, operand] :
4239 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4240 yield.getOperandTypes())) {
4241 if (result == operand)
4243 return (emitOpError(
"expected result #")
4244 << idx <<
" of each region to be " << result)
4245 .attachNote(yield.getLoc())
4246 << name <<
" returns " << operand <<
" here";
4251 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4254 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4260 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4262 Block &scf::IndexSwitchOp::getDefaultBlock() {
4263 return getDefaultRegion().
front();
4266 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4267 assert(idx < getNumCases() &&
"case index out-of-bounds");
4268 return getCaseRegions()[idx].front();
4271 void IndexSwitchOp::getSuccessorRegions(
4275 successors.emplace_back(getResults());
4279 llvm::copy(getRegions(), std::back_inserter(successors));
4282 void IndexSwitchOp::getEntrySuccessorRegions(
4285 FoldAdaptor adaptor(operands, *
this);
4288 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4290 llvm::copy(getRegions(), std::back_inserter(successors));
4296 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4297 if (caseValue == arg.getInt()) {
4298 successors.emplace_back(&caseRegion);
4302 successors.emplace_back(&getDefaultRegion());
4305 void IndexSwitchOp::getRegionInvocationBounds(
4307 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4308 if (!operandValue) {
4314 unsigned liveIndex = getNumRegions() - 1;
4315 const auto *it = llvm::find(getCases(), operandValue.getInt());
4316 if (it != getCases().end())
4317 liveIndex = std::distance(getCases().begin(), it);
4318 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4319 bounds.emplace_back(0, i == liveIndex);
4330 if (!maybeCst.has_value())
4332 int64_t cst = *maybeCst;
4333 int64_t caseIdx, e = op.getNumCases();
4334 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4335 if (cst == op.getCases()[caseIdx])
4339 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4340 : op.getDefaultRegion();
4341 Block &source = r.front();
4364 #define GET_OP_CLASSES
4365 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.