26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
56 auto retValOp = dyn_cast<scf::YieldOp>(op);
60 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
71 void SCFDialect::initialize() {
74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
76 addInterfaces<SCFInlinerInterface>();
77 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78 InParallelOp, ReduceReturnOp>();
79 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81 ForallOp, InParallelOp, WhileOp, YieldOp>();
82 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
87 builder.
create<scf::YieldOp>(loc);
92 template <
typename TerminatorTy>
94 StringRef errorMessage) {
97 terminatorOperation = ®ion.
front().
back();
98 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
102 if (terminatorOperation)
103 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
115 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
161 if (getRegion().empty())
162 return emitOpError(
"region needs to have at least one block");
163 if (getRegion().front().getNumArguments() > 0)
164 return emitOpError(
"region cannot have any arguments");
187 if (!llvm::hasSingleElement(op.
getRegion()))
236 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->
getParentOp()))
246 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
248 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
249 yieldOp.getResults());
258 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
270 void ExecuteRegionOp::getSuccessorRegions(
288 assert((point.
isParent() || point == getParentOp().getAfter()) &&
289 "condition op can only exit the loop or branch to the after"
292 return getArgsMutable();
295 void ConditionOp::getSuccessorRegions(
297 FoldAdaptor adaptor(operands, *
this);
299 WhileOp whileOp = getParentOp();
303 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
304 if (!boolAttr || boolAttr.getValue())
305 regions.emplace_back(&whileOp.getAfter(),
306 whileOp.getAfter().getArguments());
307 if (!boolAttr || !boolAttr.getValue())
308 regions.emplace_back(whileOp.getResults());
317 BodyBuilderFn bodyBuilder) {
322 for (
Value v : iterArgs)
328 for (
Value v : iterArgs)
334 if (iterArgs.empty() && !bodyBuilder) {
335 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
336 }
else if (bodyBuilder) {
346 if (getInitArgs().size() != getNumResults())
348 "mismatch in number of loop-carried values and defined values");
356 if (getInductionVar().getType() !=
getLowerBound().getType())
358 "expected induction variable to be same type as bounds and step");
360 if (getNumRegionIterArgs() != getNumResults())
362 "mismatch in number of basic block args and defined values");
364 auto initArgs = getInitArgs();
365 auto iterArgs = getRegionIterArgs();
366 auto opResults = getResults();
368 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369 if (std::get<0>(e).getType() != std::get<2>(e).getType())
370 return emitOpError() <<
"types mismatch between " << i
371 <<
"th iter operand and defined value";
372 if (std::get<1>(e).getType() != std::get<2>(e).getType())
373 return emitOpError() <<
"types mismatch between " << i
374 <<
"th iter region arg and defined value";
381 std::optional<Value> ForOp::getSingleInductionVar() {
382 return getInductionVar();
385 std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
389 std::optional<OpFoldResult> ForOp::getSingleStep() {
393 std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
397 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
402 std::optional<int64_t> tripCount =
404 if (!tripCount.has_value() || tripCount != 1)
408 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
415 llvm::append_range(bbArgReplacements, getInitArgs());
419 getOperation()->getIterator(), bbArgReplacements);
435 StringRef prefix =
"") {
436 assert(blocksArgs.size() == initializers.size() &&
437 "expected same length of arguments and initializers");
438 if (initializers.empty())
442 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
443 p << std::get<0>(it) <<
" = " << std::get<1>(it);
449 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
453 if (!getInitArgs().empty())
454 p <<
" -> (" << getInitArgs().getTypes() <<
')';
456 if (
Type t = getInductionVar().getType(); !t.
isIndex())
457 p <<
" : " << t <<
' ';
460 !getInitArgs().empty());
482 regionArgs.push_back(inductionVariable);
492 if (regionArgs.size() != result.
types.size() + 1)
495 "mismatch in number of loop-carried values and defined values");
504 regionArgs.front().type = type;
510 for (
auto argOperandType :
511 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
512 Type type = std::get<2>(argOperandType);
513 std::get<0>(argOperandType).type = type;
525 ForOp::ensureTerminator(*body, builder, result.
location);
537 return getBody()->getArguments().drop_front(getNumInductionVars());
541 return getInitArgsMutable();
545 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
547 bool replaceInitOperandUsesInLoop,
552 auto inits = llvm::to_vector(getInitArgs());
553 inits.append(newInitOperands.begin(), newInitOperands.end());
554 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
559 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
561 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
566 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567 assert(newInitOperands.size() == newYieldedValues.size() &&
568 "expected as many new yield values as new iter operands");
570 yieldOp.getResultsMutable().append(newYieldedValues);
576 newLoop.getBody()->getArguments().take_front(
577 getBody()->getNumArguments()));
579 if (replaceInitOperandUsesInLoop) {
582 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
593 newLoop->getResults().take_front(getNumResults()));
594 return cast<LoopLikeOpInterface>(newLoop.getOperation());
598 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
601 assert(ivArg.getOwner() &&
"unlinked block argument");
602 auto *containingOp = ivArg.getOwner()->getParentOp();
603 return dyn_cast_or_null<ForOp>(containingOp);
607 return getInitArgs();
624 for (
auto [lb, ub, step] :
625 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
627 if (!tripCount.has_value() || *tripCount != 1)
636 return getBody()->getArguments().drop_front(getRank());
640 return getOutputsMutable();
646 scf::InParallelOp terminator = forallOp.getTerminator();
651 bbArgReplacements.append(forallOp.getOutputs().begin(),
652 forallOp.getOutputs().end());
656 forallOp->getIterator(), bbArgReplacements);
661 results.reserve(forallOp.getResults().size());
662 for (
auto &yieldingOp : terminator.getYieldingOps()) {
663 auto parallelInsertSliceOp =
664 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
666 Value dst = parallelInsertSliceOp.getDest();
667 Value src = parallelInsertSliceOp.getSource();
668 if (llvm::isa<TensorType>(src.
getType())) {
669 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
670 forallOp.getLoc(), dst.
getType(), src, dst,
671 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672 parallelInsertSliceOp.getStrides(),
673 parallelInsertSliceOp.getStaticOffsets(),
674 parallelInsertSliceOp.getStaticSizes(),
675 parallelInsertSliceOp.getStaticStrides()));
677 llvm_unreachable(
"unsupported terminator");
692 assert(lbs.size() == ubs.size() &&
693 "expected the same number of lower and upper bounds");
694 assert(lbs.size() == steps.size() &&
695 "expected the same number of lower bounds and steps");
700 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
702 assert(results.size() == iterArgs.size() &&
703 "loop nest body must return as many values as loop has iteration "
705 return LoopNest{{}, std::move(results)};
713 loops.reserve(lbs.size());
714 ivs.reserve(lbs.size());
717 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
718 auto loop = builder.
create<scf::ForOp>(
719 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725 currentIterArgs = args;
726 currentLoc = nestedLoc;
732 loops.push_back(loop);
736 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
738 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
745 ? bodyBuilder(builder, currentLoc, ivs,
746 loops.back().getRegionIterArgs())
748 assert(results.size() == iterArgs.size() &&
749 "loop nest body must return as many values as loop has iteration "
752 builder.
create<scf::YieldOp>(loc, results);
756 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757 return LoopNest{std::move(loops), std::move(nestResults)};
765 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
770 bodyBuilder(nestedBuilder, nestedLoc, ivs);
794 bool canonicalize =
false;
801 int64_t numResults = forOp.getNumResults();
803 keepMask.reserve(numResults);
806 newBlockTransferArgs.reserve(1 + numResults);
807 newBlockTransferArgs.push_back(
Value());
808 newIterArgs.reserve(forOp.getInitArgs().size());
809 newYieldValues.reserve(numResults);
810 newResultValues.reserve(numResults);
811 for (
auto it : llvm::zip(forOp.getInitArgs(),
812 forOp.getRegionIterArgs(),
814 forOp.getYieldedValues()
822 bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
823 (std::get<1>(it).use_empty() &&
824 (std::get<0>(it) == std::get<3>(it) ||
825 std::get<2>(it).use_empty())));
826 keepMask.push_back(!forwarded);
827 canonicalize |= forwarded;
829 newBlockTransferArgs.push_back(std::get<0>(it));
830 newResultValues.push_back(std::get<0>(it));
833 newIterArgs.push_back(std::get<0>(it));
834 newYieldValues.push_back(std::get<3>(it));
835 newBlockTransferArgs.push_back(
Value());
836 newResultValues.push_back(
Value());
842 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
843 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
844 forOp.getStep(), newIterArgs);
845 newForOp->
setAttrs(forOp->getAttrs());
846 Block &newBlock = newForOp.getRegion().
front();
850 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
852 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
853 Value &newResultVal = newResultValues[idx];
854 assert((blockTransferArg && newResultVal) ||
855 (!blockTransferArg && !newResultVal));
856 if (!blockTransferArg) {
857 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
858 newResultVal = newForOp.getResult(collapsedIdx++);
864 "unexpected argument size mismatch");
869 if (newIterArgs.empty()) {
870 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
873 rewriter.
replaceOp(forOp, newResultValues);
878 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
882 filteredOperands.reserve(newResultValues.size());
883 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
885 filteredOperands.push_back(mergedTerminator.getOperand(idx));
886 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
890 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
891 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
892 cloneFilteredTerminator(mergedYieldOp);
893 rewriter.
eraseOp(mergedYieldOp);
894 rewriter.
replaceOp(forOp, newResultValues);
902 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
903 IntegerAttr clb, cub;
905 llvm::APInt lbValue = clb.getValue();
906 llvm::APInt ubValue = cub.getValue();
907 return (ubValue - lbValue).getSExtValue();
916 return diff.getSExtValue();
930 if (op.getLowerBound() == op.getUpperBound()) {
931 rewriter.
replaceOp(op, op.getInitArgs());
935 std::optional<int64_t> diff =
936 computeConstDiff(op.getLowerBound(), op.getUpperBound());
942 rewriter.
replaceOp(op, op.getInitArgs());
946 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
952 llvm::APInt stepValue = *maybeStepValue;
953 if (stepValue.sge(*diff)) {
955 blockArgs.reserve(op.getInitArgs().size() + 1);
956 blockArgs.push_back(op.getLowerBound());
957 llvm::append_range(blockArgs, op.getInitArgs());
964 if (!llvm::hasSingleElement(block))
968 if (llvm::any_of(op.getYieldedValues(),
969 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
971 rewriter.
replaceOp(op, op.getYieldedValues());
983 assert(llvm::isa<RankedTensorType>(oldType) &&
984 llvm::isa<RankedTensorType>(newType) &&
985 "expected ranked tensor types");
988 ForOp forOp = cast<ForOp>(operand.
getOwner());
990 "expected an iter OpOperand");
992 "Expected a different type");
994 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
996 newIterOperands.push_back(replacement);
999 newIterOperands.push_back(opOperand.get());
1003 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
1004 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1005 forOp.getStep(), newIterOperands);
1006 newForOp->
setAttrs(forOp->getAttrs());
1007 Block &newBlock = newForOp.getRegion().
front();
1015 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1017 Value castIn = rewriter.
create<tensor::CastOp>(newForOp.getLoc(), oldType,
1019 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
1023 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1026 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1029 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
1031 newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1033 newYieldOperands[yieldIdx] = castOut;
1034 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1035 rewriter.
eraseOp(clonedYieldOp);
1040 newResults[yieldIdx] = rewriter.
create<tensor::CastOp>(
1041 newForOp.getLoc(), oldType, newResults[yieldIdx]);
1077 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.
getResults())) {
1078 OpOperand &iterOpOperand = std::get<0>(it);
1080 if (!incomingCast ||
1081 incomingCast.getSource().getType() == incomingCast.getType())
1086 incomingCast.getDest().getType(),
1087 incomingCast.getSource().getType()))
1089 if (!std::get<1>(it).hasOneUse())
1094 op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1095 incomingCast.getSource()));
1106 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1110 std::optional<APInt> ForOp::getConstantStep() {
1113 return step.getValue();
1117 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1118 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1124 if (
auto constantStep = getConstantStep())
1125 if (*constantStep == 1)
1138 unsigned numLoops = getRank();
1140 if (getNumResults() != getOutputs().size())
1141 return emitOpError(
"produces ")
1142 << getNumResults() <<
" results, but has only "
1143 << getOutputs().size() <<
" outputs";
1146 auto *body = getBody();
1148 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1149 for (int64_t i = 0; i < numLoops; ++i)
1151 return emitOpError(
"expects ")
1152 << i <<
"-th block argument to be an index";
1153 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1155 return emitOpError(
"type mismatch between ")
1156 << i <<
"-th output and corresponding block argument";
1157 if (getMapping().has_value() && !getMapping()->empty()) {
1158 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1159 return emitOpError() <<
"mapping attribute size must match op rank";
1160 for (
auto map : getMapping()->getValue()) {
1161 if (!isa<DeviceMappingAttrInterface>(map))
1162 return emitOpError()
1170 getStaticLowerBound(),
1171 getDynamicLowerBound())))
1174 getStaticUpperBound(),
1175 getDynamicUpperBound())))
1178 getStaticStep(), getDynamicStep())))
1186 p <<
" (" << getInductionVars();
1187 if (isNormalized()) {
1208 if (!getRegionOutArgs().empty())
1209 p <<
"-> (" << getResultTypes() <<
") ";
1210 p.printRegion(getRegion(),
1212 getNumResults() > 0);
1213 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1214 getStaticLowerBoundAttrName(),
1215 getStaticUpperBoundAttrName(),
1216 getStaticStepAttrName()});
1221 auto indexType = b.getIndexType();
1241 unsigned numLoops = ivs.size();
1276 if (outOperands.size() != result.
types.size())
1278 "mismatch between out operands and types");
1288 std::unique_ptr<Region> region = std::make_unique<Region>();
1289 for (
auto &iv : ivs) {
1290 iv.type = b.getIndexType();
1291 regionArgs.push_back(iv);
1294 auto &out = it.value();
1295 out.type = result.
types[it.index()];
1296 regionArgs.push_back(out);
1302 ForallOp::ensureTerminator(*region, b, result.
location);
1314 {static_cast<int32_t>(dynamicLbs.size()),
1315 static_cast<int32_t>(dynamicUbs.size()),
1316 static_cast<int32_t>(dynamicSteps.size()),
1317 static_cast<int32_t>(outOperands.size())}));
1322 void ForallOp::build(
1326 std::optional<ArrayAttr> mapping,
1347 "operandSegmentSizes",
1349 static_cast<int32_t>(dynamicUbs.size()),
1350 static_cast<int32_t>(dynamicSteps.size()),
1351 static_cast<int32_t>(outputs.size())}));
1352 if (mapping.has_value()) {
1371 if (!bodyBuilderFn) {
1372 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1379 void ForallOp::build(
1382 std::optional<ArrayAttr> mapping,
1384 unsigned numLoops = ubs.size();
1387 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1391 bool ForallOp::isNormalized() {
1395 return intValue.has_value() && intValue == val;
1398 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1407 ForallOp>::ensureTerminator(region, builder, loc);
1414 InParallelOp ForallOp::getTerminator() {
1415 return cast<InParallelOp>(getBody()->getTerminator());
1418 std::optional<Value> ForallOp::getSingleInductionVar() {
1420 return std::nullopt;
1421 return getInductionVar(0);
1424 std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1426 return std::nullopt;
1427 return getMixedLowerBound()[0];
1430 std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1432 return std::nullopt;
1433 return getMixedUpperBound()[0];
1436 std::optional<OpFoldResult> ForallOp::getSingleStep() {
1438 return std::nullopt;
1439 return getMixedStep()[0];
1443 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1446 assert(tidxArg.getOwner() &&
"unlinked block argument");
1447 auto *containingOp = tidxArg.getOwner()->getParentOp();
1448 return dyn_cast<ForallOp>(containingOp);
1458 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1462 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1465 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1489 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1490 op.setStaticLowerBound(staticLowerBound);
1494 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1495 op.setStaticUpperBound(staticUpperBound);
1498 op.getDynamicStepMutable().assign(dynamicStep);
1499 op.setStaticStep(staticStep);
1501 op->
setAttr(ForallOp::getOperandSegmentSizeAttr(),
1503 {static_cast<int32_t>(dynamicLowerBound.size()),
1504 static_cast<int32_t>(dynamicUpperBound.size()),
1505 static_cast<int32_t>(dynamicStep.size()),
1506 static_cast<int32_t>(op.getNumResults())}));
1512 struct ForallOpSingleOrZeroIterationDimsFolder
1519 if (op.getMapping().has_value())
1527 for (
auto [lb, ub, step, iv] :
1528 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1529 op.getMixedStep(), op.getInductionVars())) {
1531 if (numIterations.has_value()) {
1533 if (*numIterations == 0) {
1534 rewriter.
replaceOp(op, op.getOutputs());
1539 if (*numIterations == 1) {
1544 newMixedLowerBounds.push_back(lb);
1545 newMixedUpperBounds.push_back(ub);
1546 newMixedSteps.push_back(step);
1549 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1551 op,
"no dimensions have 0 or 1 iterations");
1555 if (newMixedLowerBounds.empty()) {
1562 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1563 newMixedUpperBounds, newMixedSteps,
1564 op.getOutputs(), std::nullopt,
nullptr);
1565 newOp.getBodyRegion().getBlocks().clear();
1570 newOp.getStaticLowerBoundAttrName(),
1571 newOp.getStaticUpperBoundAttrName(),
1572 newOp.getStaticStepAttrName()};
1573 for (
const auto &namedAttr : op->
getAttrs()) {
1574 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1577 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1581 newOp.getRegion().
begin(), mapping);
1582 rewriter.
replaceOp(op, newOp.getResults());
1587 struct FoldTensorCastOfOutputIntoForallOp
1598 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1601 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1608 castOp.getSource().getType())) {
1612 tensorCastProducers[en.index()] =
1613 TypeCast{castOp.getSource().getType(), castOp.getType()};
1614 newOutputTensors[en.index()] = castOp.getSource();
1617 if (tensorCastProducers.empty())
1622 auto newForallOp = rewriter.
create<ForallOp>(
1623 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1624 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1626 auto castBlockArgs =
1627 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1628 for (
auto [index, cast] : tensorCastProducers) {
1629 Value &oldTypeBBArg = castBlockArgs[index];
1630 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1631 nestedLoc, cast.dstType, oldTypeBBArg);
1636 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1637 ivsBlockArgs.append(castBlockArgs);
1639 bbArgs.front().getParentBlock(), ivsBlockArgs);
1645 auto terminator = newForallOp.getTerminator();
1646 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1647 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1648 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1649 insertSliceOp.getDestMutable().assign(outputBlockArg);
1655 for (
auto &item : tensorCastProducers) {
1656 Value &oldTypeResult = castResults[item.first];
1657 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1660 rewriter.
replaceOp(forallOp, castResults);
1669 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1670 ForallOpControlOperandsFolder,
1671 ForallOpSingleOrZeroIterationDimsFolder>(context);
1700 scf::ForallOp forallOp =
1701 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1703 return this->emitOpError(
"expected forall op parent");
1706 for (
Operation &op : getRegion().front().getOperations()) {
1707 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1708 return this->emitOpError(
"expected only ")
1709 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1713 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1715 if (!llvm::is_contained(regionOutArgs, dest))
1716 return op.
emitOpError(
"may only insert into an output block argument");
1733 std::unique_ptr<Region> region = std::make_unique<Region>();
1737 if (region->empty())
1747 OpResult InParallelOp::getParentResult(int64_t idx) {
1748 return getOperation()->getParentOp()->getResult(idx);
1752 return llvm::to_vector<4>(
1753 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1755 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1756 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1761 return getRegion().front().getOperations();
1769 assert(a &&
"expected non-empty operation");
1770 assert(b &&
"expected non-empty operation");
1775 if (ifOp->isProperAncestor(b))
1778 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1779 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1781 ifOp = ifOp->getParentOfType<IfOp>();
1789 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1790 IfOp::Adaptor adaptor,
1792 if (adaptor.getRegions().empty())
1794 Region *r = &adaptor.getThenRegion();
1800 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
1803 TypeRange types = yieldOp.getOperandTypes();
1804 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1811 return build(builder, result, resultTypes, cond,
false,
1817 bool addElseBlock) {
1818 assert((!addElseBlock || addThenBlock) &&
1819 "must not create else block w/o then block");
1834 bool withElseRegion) {
1835 build(builder, result,
TypeRange{}, cond, withElseRegion);
1847 if (resultTypes.empty())
1848 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
1852 if (withElseRegion) {
1854 if (resultTypes.empty())
1855 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
1862 assert(thenBuilder &&
"the builder callback for 'then' must be present");
1869 thenBuilder(builder, result.
location);
1875 elseBuilder(builder, result.
location);
1884 inferredReturnTypes))) {
1885 result.
addTypes(inferredReturnTypes);
1890 if (getNumResults() != 0 && getElseRegion().empty())
1891 return emitOpError(
"must have an else block if defining values");
1929 bool printBlockTerminators =
false;
1931 p <<
" " << getCondition();
1932 if (!getResults().empty()) {
1933 p <<
" -> (" << getResultTypes() <<
")";
1935 printBlockTerminators =
true;
1940 printBlockTerminators);
1943 auto &elseRegion = getElseRegion();
1944 if (!elseRegion.
empty()) {
1948 printBlockTerminators);
1965 Region *elseRegion = &this->getElseRegion();
1966 if (elseRegion->
empty())
1974 FoldAdaptor adaptor(operands, *
this);
1975 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
1976 if (!boolAttr || boolAttr.getValue())
1977 regions.emplace_back(&getThenRegion());
1980 if (!boolAttr || !boolAttr.getValue()) {
1981 if (!getElseRegion().empty())
1982 regions.emplace_back(&getElseRegion());
1984 regions.emplace_back(getResults());
1991 if (getElseRegion().empty())
1994 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2001 getConditionMutable().assign(xorStmt.getLhs());
2005 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2006 getElseRegion().getBlocks());
2007 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2008 getThenRegion().getBlocks(), thenBlock);
2012 void IfOp::getRegionInvocationBounds(
2015 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2018 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2019 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2022 invocationBounds.assign(2, {0, 1});
2038 llvm::transform(usedResults, std::back_inserter(usedOperands),
2043 [&]() { yieldOp->setOperands(usedOperands); });
2050 llvm::copy_if(op.
getResults(), std::back_inserter(usedResults),
2051 [](
OpResult result) { return !result.use_empty(); });
2059 llvm::transform(usedResults, std::back_inserter(newTypes),
2064 rewriter.
create<IfOp>(op.
getLoc(), newTypes, op.getCondition());
2070 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2071 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2076 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2093 else if (!op.getElseRegion().empty())
2112 auto cond = op.getCondition();
2113 auto thenYieldArgs = op.thenYield().
getOperands();
2114 auto elseYieldArgs = op.elseYield().
getOperands();
2117 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2120 nonHoistable.push_back(trueVal.getType());
2127 IfOp replacement = rewriter.
create<IfOp>(op.
getLoc(), nonHoistable, cond,
2129 if (replacement.thenBlock())
2130 rewriter.
eraseBlock(replacement.thenBlock());
2131 replacement.getThenRegion().takeBody(op.getThenRegion());
2132 replacement.getElseRegion().takeBody(op.getElseRegion());
2135 assert(thenYieldArgs.size() == results.size());
2136 assert(elseYieldArgs.size() == results.size());
2141 for (
const auto &it :
2143 Value trueVal = std::get<0>(it.value());
2144 Value falseVal = std::get<1>(it.value());
2147 results[it.index()] = replacement.getResult(trueYields.size());
2148 trueYields.push_back(trueVal);
2149 falseYields.push_back(falseVal);
2150 }
else if (trueVal == falseVal)
2151 results[it.index()] = trueVal;
2153 results[it.index()] = rewriter.
create<arith::SelectOp>(
2154 op.
getLoc(), cond, trueVal, falseVal);
2191 bool changed =
false;
2196 Value constantTrue =
nullptr;
2197 Value constantFalse =
nullptr;
2200 llvm::make_early_inc_range(op.getCondition().
getUses())) {
2205 constantTrue = rewriter.
create<arith::ConstantOp>(
2209 [&]() { use.
set(constantTrue); });
2215 constantFalse = rewriter.
create<arith::ConstantOp>(
2219 [&]() { use.
set(constantFalse); });
2263 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2273 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2275 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2278 op.getOperation()->getIterator());
2279 bool changed =
false;
2281 for (
auto [trueResult, falseResult, opResult] :
2282 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2284 if (trueResult == falseResult) {
2285 if (!opResult.use_empty()) {
2286 opResult.replaceAllUsesWith(trueResult);
2297 bool trueVal = trueYield.
getValue();
2298 bool falseVal = falseYield.
getValue();
2299 if (!trueVal && falseVal) {
2300 if (!opResult.use_empty()) {
2301 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2303 op.
getLoc(), op.getCondition(),
2313 if (trueVal && !falseVal) {
2314 if (!opResult.use_empty()) {
2315 opResult.replaceAllUsesWith(op.getCondition());
2350 Block *parent = nextIf->getBlock();
2351 if (nextIf == &parent->
front())
2354 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2362 Block *nextThen =
nullptr;
2363 Block *nextElse =
nullptr;
2364 if (nextIf.getCondition() == prevIf.getCondition()) {
2365 nextThen = nextIf.thenBlock();
2366 if (!nextIf.getElseRegion().empty())
2367 nextElse = nextIf.elseBlock();
2369 if (arith::XOrIOp notv =
2370 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2371 if (notv.getLhs() == prevIf.getCondition() &&
2373 nextElse = nextIf.thenBlock();
2374 if (!nextIf.getElseRegion().empty())
2375 nextThen = nextIf.elseBlock();
2378 if (arith::XOrIOp notv =
2379 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2380 if (notv.getLhs() == nextIf.getCondition() &&
2382 nextElse = nextIf.thenBlock();
2383 if (!nextIf.getElseRegion().empty())
2384 nextThen = nextIf.elseBlock();
2388 if (!nextThen && !nextElse)
2392 if (!prevIf.getElseRegion().empty())
2393 prevElseYielded = prevIf.elseYield().getOperands();
2396 for (
auto it : llvm::zip(prevIf.getResults(),
2397 prevIf.thenYield().getOperands(), prevElseYielded))
2399 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2403 use.
set(std::get<1>(it));
2408 use.
set(std::get<2>(it));
2414 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2416 IfOp combinedIf = rewriter.
create<IfOp>(
2417 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2418 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2421 combinedIf.getThenRegion(),
2422 combinedIf.getThenRegion().begin());
2425 YieldOp thenYield = combinedIf.thenYield();
2426 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2427 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2431 llvm::append_range(mergedYields, thenYield2.getOperands());
2432 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2438 combinedIf.getElseRegion(),
2439 combinedIf.getElseRegion().begin());
2442 if (combinedIf.getElseRegion().empty()) {
2444 combinedIf.getElseRegion(),
2445 combinedIf.getElseRegion().
begin());
2447 YieldOp elseYield = combinedIf.elseYield();
2448 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2449 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2454 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2456 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2465 if (pair.index() < prevIf.getNumResults())
2466 prevValues.push_back(pair.value());
2468 nextValues.push_back(pair.value());
2483 if (ifOp.getNumResults())
2485 Block *elseBlock = ifOp.elseBlock();
2486 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2490 newIfOp.getThenRegion().begin());
2517 auto nestedOps = op.thenBlock()->without_terminator();
2519 if (!llvm::hasSingleElement(nestedOps))
2523 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2526 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2530 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2536 llvm::append_range(elseYield, op.elseYield().
getOperands());
2550 if (tup.value().getDefiningOp() == nestedIf) {
2551 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2552 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2553 elseYield[tup.index()]) {
2558 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2571 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2574 elseYieldsToUpgradeToSelect.push_back(tup.index());
2578 Value newCondition = rewriter.
create<arith::AndIOp>(
2579 loc, op.getCondition(), nestedIf.getCondition());
2584 llvm::append_range(results, newIf.getResults());
2587 for (
auto idx : elseYieldsToUpgradeToSelect)
2588 results[idx] = rewriter.
create<arith::SelectOp>(
2589 op.
getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2591 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2594 if (!elseYield.empty()) {
2597 rewriter.
create<YieldOp>(loc, elseYield);
2608 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2609 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2610 RemoveStaticCondition, RemoveUnusedResults,
2611 ReplaceIfYieldWithConditionOrValue>(context);
// Returns a pointer to the last block of this op's 'then' region.
2614 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
// Returns the last operation of the 'then' block as a YieldOp. Uses `cast`,
// so the last operation is expected to already be a YieldOp.
2615 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2616 Block *IfOp::elseBlock() {
2617 Region &r = getElseRegion();
// Returns the last operation of the 'else' block as a YieldOp. Dereferences
// the result of elseBlock() unconditionally and uses `cast`, so the last
// operation is expected to already be a YieldOp.
// NOTE(review): presumably callers must ensure an else block exists before
// calling this — confirm against elseBlock().
2622 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2628 void ParallelOp::build(
2638 ParallelOp::getOperandSegmentSizeAttr(),
2640 static_cast<int32_t>(upperBounds.size()),
2641 static_cast<int32_t>(steps.size()),
2642 static_cast<int32_t>(initVals.size())}));
2646 unsigned numIVs = steps.size();
2652 if (bodyBuilderFn) {
2654 bodyBuilderFn(builder, result.
location,
2659 if (initVals.empty())
2660 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2663 void ParallelOp::build(
2670 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2673 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2677 wrapper = wrappedBuilderFn;
2679 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2688 if (stepValues.empty())
2690 "needs at least one tuple element for lowerBound, upperBound and step");
2693 for (
Value stepValue : stepValues)
2696 return emitOpError(
"constant step operand must be positive");
2700 Block *body = getBody();
2702 return emitOpError() <<
"expects the same number of induction variables: "
2704 <<
" as bound and step values: " << stepValues.size();
2706 if (!arg.getType().isIndex())
2708 "expects arguments for the induction variable to be of index type");
2711 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2712 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2717 auto resultsSize = getResults().size();
2718 auto reductionsSize = reduceOp.getReductions().size();
2719 auto initValsSize = getInitVals().size();
2720 if (resultsSize != reductionsSize)
2721 return emitOpError() <<
"expects number of results: " << resultsSize
2722 <<
" to be the same as number of reductions: "
2724 if (resultsSize != initValsSize)
2725 return emitOpError() <<
"expects number of results: " << resultsSize
2726 <<
" to be the same as number of initial values: "
2730 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2731 auto resultType = getOperation()->getResult(i).getType();
2732 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2733 if (resultType != reductionOperandType)
2734 return reduceOp.emitOpError()
2735 <<
"expects type of " << i
2736 <<
"-th reduction operand: " << reductionOperandType
2737 <<
" to be the same as the " << i
2738 <<
"-th result type: " << resultType;
2786 for (
auto &iv : ivs)
2793 ParallelOp::getOperandSegmentSizeAttr(),
2795 static_cast<int32_t>(upper.size()),
2796 static_cast<int32_t>(steps.size()),
2797 static_cast<int32_t>(initVals.size())}));
2806 ParallelOp::ensureTerminator(*body, builder, result.
location);
2811 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2812 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2813 if (!getInitVals().empty())
2814 p <<
" init (" << getInitVals() <<
")";
2819 (*this)->getAttrs(),
2820 ParallelOp::getOperandSegmentSizeAttr());
2825 std::optional<Value> ParallelOp::getSingleInductionVar() {
2826 if (getNumLoops() != 1)
2827 return std::nullopt;
2828 return getBody()->getArgument(0);
2831 std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
2832 if (getNumLoops() != 1)
2833 return std::nullopt;
2837 std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
2838 if (getNumLoops() != 1)
2839 return std::nullopt;
2843 std::optional<OpFoldResult> ParallelOp::getSingleStep() {
2844 if (getNumLoops() != 1)
2845 return std::nullopt;
2846 return getStep()[0];
2850 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2852 return ParallelOp();
2853 assert(ivArg.getOwner() &&
"unlinked block argument");
2854 auto *containingOp = ivArg.getOwner()->getParentOp();
2855 return dyn_cast<ParallelOp>(containingOp);
2860 struct ParallelOpSingleOrZeroIterationDimsFolder
2871 for (
auto [lb, ub, step, iv] :
2872 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2873 op.getInductionVars())) {
2875 if (numIterations.has_value()) {
2877 if (*numIterations == 0) {
2878 rewriter.
replaceOp(op, op.getInitVals());
2883 if (*numIterations == 1) {
2888 newLowerBounds.push_back(lb);
2889 newUpperBounds.push_back(ub);
2890 newSteps.push_back(step);
2893 if (newLowerBounds.size() == op.getLowerBound().size())
2896 if (newLowerBounds.empty()) {
2900 results.reserve(op.getInitVals().size());
2901 for (
auto &bodyOp : op.getBody()->without_terminator())
2902 rewriter.
clone(bodyOp, mapping);
2903 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
2904 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
2905 Block &reduceBlock = reduceOp.getReductions()[i].
front();
2906 auto initValIndex = results.size();
2907 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
2911 rewriter.
clone(reduceBodyOp, mapping);
2914 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
2915 results.push_back(result);
2923 rewriter.
create<ParallelOp>(op.
getLoc(), newLowerBounds, newUpperBounds,
2924 newSteps, op.getInitVals(),
nullptr);
2930 newOp.getRegion().
begin(), mapping);
2931 rewriter.
replaceOp(op, newOp.getResults());
2941 Block &outerBody = *op.getBody();
2945 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
2950 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2951 llvm::is_contained(innerOp.getUpperBound(), val) ||
2952 llvm::is_contained(innerOp.getStep(), val))
2956 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2961 Block &innerBody = *innerOp.getBody();
2962 assert(iterVals.size() ==
2970 builder.
clone(op, mapping);
2973 auto concatValues = [](
const auto &first,
const auto &second) {
2975 ret.reserve(first.size() + second.size());
2976 ret.assign(first.begin(), first.end());
2977 ret.append(second.begin(), second.end());
2981 auto newLowerBounds =
2982 concatValues(op.getLowerBound(), innerOp.getLowerBound());
2983 auto newUpperBounds =
2984 concatValues(op.getUpperBound(), innerOp.getUpperBound());
2985 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2988 newSteps, std::nullopt,
2999 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3008 void ParallelOp::getSuccessorRegions(
3026 for (
Value v : operands) {
3038 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3039 auto type = getOperands()[i].getType();
3042 return emitOpError() << i <<
"-th reduction has an empty body";
3045 return arg.getType() != type;
3047 return emitOpError() <<
"expected two block arguments with type " << type
3048 <<
" in the " << i <<
"-th reduction region";
3052 return emitOpError(
"reduction bodies must be terminated with an "
3053 "'scf.reduce.return' op");
3072 Block *reductionBody = getOperation()->getBlock();
3074 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3076 if (expectedResultType != getResult().getType())
3077 return emitOpError() <<
"must have type " << expectedResultType
3078 <<
" (the type of the reduction inputs)";
3088 ValueRange operands, BodyBuilderFn beforeBuilder,
3089 BodyBuilderFn afterBuilder) {
3097 beforeArgLocs.reserve(operands.size());
3098 for (
Value operand : operands) {
3099 beforeArgLocs.push_back(operand.getLoc());
3104 beforeRegion, {}, operands.getTypes(), beforeArgLocs);
3113 resultTypes, afterArgLocs);
3119 ConditionOp WhileOp::getConditionOp() {
3120 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3123 YieldOp WhileOp::getYieldOp() {
3124 return cast<YieldOp>(getAfterBody()->getTerminator());
3127 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3128 return getYieldOp().getResultsMutable();
3132 return getBeforeBody()->getArguments();
3136 return getAfterBody()->getArguments();
3140 return getBeforeArguments();
3144 assert(point == getBefore() &&
3145 "WhileOp is expected to branch only to the first region");
3153 regions.emplace_back(&getBefore(), getBefore().getArguments());
3157 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3158 "there are only two regions in a WhileOp");
3160 if (point == getAfter()) {
3161 regions.emplace_back(&getBefore(), getBefore().getArguments());
3165 regions.emplace_back(getResults());
3166 regions.emplace_back(&getAfter(), getAfter().getArguments());
3170 return {&getBefore(), &getAfter()};
3191 FunctionType functionType;
3196 result.
addTypes(functionType.getResults());
3198 if (functionType.getNumInputs() != operands.size()) {
3200 <<
"expected as many input types as operands "
3201 <<
"(expected " << operands.size() <<
" got "
3202 << functionType.getNumInputs() <<
")";
3212 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3213 regionArgs[i].type = functionType.getInput(i);
3235 template <
typename OpTy>
3238 if (left.size() != right.size())
3239 return op.
emitOpError(
"expects the same number of ") << message;
3241 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3242 if (left[i] != right[i]) {
3245 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3246 <<
" and " << right[i];
3255 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3257 "expects the 'before' region to terminate with 'scf.condition'");
3258 if (!beforeTerminator)
3261 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3263 "expects the 'after' region to terminate with 'scf.yield'");
3264 return success(afterTerminator !=
nullptr);
3292 auto term = op.getConditionOp();
3296 Value constantTrue =
nullptr;
3298 bool replaced =
false;
3299 for (
auto yieldedAndBlockArgs :
3300 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3301 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3302 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3304 constantTrue = rewriter.
create<arith::ConstantOp>(
3305 op.
getLoc(), term.getCondition().getType(),
3366 struct RemoveLoopInvariantArgsFromBeforeBlock
3372 Block &afterBlock = *op.getAfterBody();
3374 ConditionOp condOp = op.getConditionOp();
3379 bool canSimplify =
false;
3380 for (
const auto &it :
3382 auto index =
static_cast<unsigned>(it.index());
3383 auto [initVal, yieldOpArg] = it.value();
3386 if (yieldOpArg == initVal) {
3395 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3396 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3397 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3398 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3411 for (
const auto &it :
3413 auto index =
static_cast<unsigned>(it.index());
3414 auto [initVal, yieldOpArg] = it.value();
3418 if (yieldOpArg == initVal) {
3419 beforeBlockInitValMap.insert({index, initVal});
3427 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3428 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3429 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3430 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3431 beforeBlockInitValMap.insert({index, initVal});
3436 newInitArgs.emplace_back(initVal);
3437 newYieldOpArgs.emplace_back(yieldOpArg);
3438 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3451 &newWhile.getBefore(), {},
3454 Block &beforeBlock = *op.getBeforeBody();
3461 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3464 if (beforeBlockInitValMap.count(i) != 0)
3465 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3467 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3470 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3472 newWhile.getAfter().begin());
3474 rewriter.
replaceOp(op, newWhile.getResults());
3519 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3524 Block &beforeBlock = *op.getBeforeBody();
3525 ConditionOp condOp = op.getConditionOp();
3528 bool canSimplify =
false;
3529 for (
Value condOpArg : condOpArgs) {
3549 auto index =
static_cast<unsigned>(it.index());
3550 Value condOpArg = it.value();
3555 condOpInitValMap.insert({index, condOpArg});
3557 newCondOpArgs.emplace_back(condOpArg);
3558 newAfterBlockType.emplace_back(condOpArg.
getType());
3559 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3570 auto newWhile = rewriter.
create<WhileOp>(op.
getLoc(), newAfterBlockType,
3573 Block &newAfterBlock =
3575 newAfterBlockType, newAfterBlockArgLocs);
3577 Block &afterBlock = *op.getAfterBody();
3584 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3585 Value afterBlockArg, result;
3588 if (condOpInitValMap.count(i) != 0) {
3589 afterBlockArg = condOpInitValMap[i];
3590 result = afterBlockArg;
3592 afterBlockArg = newAfterBlock.getArgument(
j);
3593 result = newWhile.getResult(
j);
3596 newAfterBlockArgs[i] = afterBlockArg;
3597 newWhileResults[i] = result;
3600 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3602 newWhile.getBefore().begin());
3604 rewriter.
replaceOp(op, newWhileResults);
3640 auto term = op.getConditionOp();
3641 auto afterArgs = op.getAfterArguments();
3642 auto termArgs = term.getArgs();
3649 bool needUpdate =
false;
3650 for (
const auto &it :
3652 auto i =
static_cast<unsigned>(it.index());
3653 Value result = std::get<0>(it.value());
3654 Value afterArg = std::get<1>(it.value());
3655 Value termArg = std::get<2>(it.value());
3659 newResultsIndices.emplace_back(i);
3660 newTermArgs.emplace_back(termArg);
3661 newResultTypes.emplace_back(result.
getType());
3662 newArgLocs.emplace_back(result.
getLoc());
3677 rewriter.
create<WhileOp>(op.
getLoc(), newResultTypes, op.getInits());
3680 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3687 newResults[it.value()] = newWhile.getResult(it.index());
3688 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3692 newWhile.getBefore().begin());
3694 Block &afterBlock = *op.getAfterBody();
3695 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3729 using namespace scf;
3730 auto cond = op.getConditionOp();
3731 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3734 bool changed =
false;
3735 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3736 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3737 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3740 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3741 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3745 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3748 if (cmp2.getPredicate() == cmp.getPredicate())
3749 samePredicate =
true;
3750 else if (cmp2.getPredicate() ==
3752 samePredicate =
false;
3773 if (!llvm::any_of(op.getBeforeArguments(),
3774 [](
Value arg) { return arg.use_empty(); }))
3777 YieldOp yield = op.getYieldOp();
3782 llvm::BitVector argsToErase;
3784 size_t argsCount = op.getBeforeArguments().size();
3785 newYields.reserve(argsCount);
3786 newInits.reserve(argsCount);
3787 argsToErase.reserve(argsCount);
3788 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3789 op.getBeforeArguments(), yield.
getOperands(), op.getInits())) {
3790 if (beforeArg.use_empty()) {
3791 argsToErase.push_back(
true);
3793 argsToErase.push_back(
false);
3794 newYields.emplace_back(yieldValue);
3795 newInits.emplace_back(initValue);
3799 Block &beforeBlock = *op.getBeforeBody();
3800 Block &afterBlock = *op.getAfterBody();
3808 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3809 Block &newAfterBlock = *newWhileOp.getAfterBody();
3815 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
3816 newBeforeBlock.getArguments());
3820 rewriter.
replaceOp(op, newWhileOp.getResults());
3831 ConditionOp condOp = op.getConditionOp();
3835 for (
Value arg : condOpArgs)
3836 argsSet.insert(arg);
3838 if (argsSet.size() == condOpArgs.size())
3841 llvm::SmallDenseMap<Value, unsigned> argsMap;
3843 argsMap.reserve(condOpArgs.size());
3844 newArgs.reserve(condOpArgs.size());
3845 for (
Value arg : condOpArgs) {
3846 if (!argsMap.count(arg)) {
3847 auto pos =
static_cast<unsigned>(argsMap.size());
3848 argsMap.insert({arg, pos});
3849 newArgs.emplace_back(arg);
3856 auto newWhileOp = rewriter.
create<scf::WhileOp>(
3857 loc, argsRange.getTypes(), op.getInits(),
nullptr,
3859 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3860 Block &newAfterBlock = *newWhileOp.getAfterBody();
3865 auto it = argsMap.find(arg);
3866 assert(it != argsMap.end());
3867 auto pos = it->second;
3868 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
3869 resultsMapping.emplace_back(newWhileOp->getResult(pos));
3877 Block &beforeBlock = *op.getBeforeBody();
3878 Block &afterBlock = *op.getAfterBody();
3880 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
3881 newBeforeBlock.getArguments());
3882 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
3890 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
3892 if (args1.size() != args2.size())
3893 return std::nullopt;
3897 auto it = llvm::find(args2, arg1);
3898 if (it == args2.end())
3899 return std::nullopt;
3901 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
3908 llvm::SmallDenseSet<Value> set;
3909 for (
Value arg : args) {
3910 if (set.contains(arg))
3927 auto oldBefore = loop.getBeforeBody();
3928 ConditionOp oldTerm = loop.getConditionOp();
3929 ValueRange beforeArgs = oldBefore->getArguments();
3931 if (beforeArgs == termArgs)
3934 if (hasDuplicates(termArgs))
3937 auto mapping = getArgsMapping(beforeArgs, termArgs);
3948 auto oldAfter = loop.getAfterBody();
3952 newResultTypes[
j] = loop.getResult(i).getType();
3954 auto newLoop = rewriter.
create<WhileOp>(
3955 loop.getLoc(), newResultTypes, loop.getInits(),
3957 auto newBefore = newLoop.getBeforeBody();
3958 auto newAfter = newLoop.getAfterBody();
3963 newResults[i] = newLoop.getResult(
j);
3964 newAfterArgs[i] = newAfter->getArgument(
j);
3968 newBefore->getArguments());
3980 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
3981 RemoveLoopInvariantValueYielded, WhileConditionTruth,
3982 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3983 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
3997 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4000 caseValues.push_back(value);
4009 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4011 p <<
"case " << value <<
' ';
4017 if (getCases().size() != getCaseRegions().size()) {
4018 return emitOpError(
"has ")
4019 << getCaseRegions().size() <<
" case regions but "
4020 << getCases().size() <<
" case values";
4024 for (int64_t value : getCases())
4025 if (!valueSet.insert(value).second)
4026 return emitOpError(
"has duplicate case value: ") << value;
4028 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4030 return emitOpError(
"expected region to end with scf.yield, but got ")
4033 if (yield.getNumOperands() != getNumResults()) {
4034 return (emitOpError(
"expected each region to return ")
4035 << getNumResults() <<
" values, but " << name <<
" returns "
4036 << yield.getNumOperands())
4037 .attachNote(yield.getLoc())
4038 <<
"see yield operation here";
4040 for (
auto [idx, result, operand] :
4041 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4042 yield.getOperandTypes())) {
4043 if (result == operand)
4045 return (emitOpError(
"expected result #")
4046 << idx <<
" of each region to be " << result)
4047 .attachNote(yield.getLoc())
4048 << name <<
" returns " << operand <<
" here";
4053 if (
failed(verifyRegion(getDefaultRegion(),
"default region")))
4056 if (
failed(verifyRegion(caseRegion,
"case region #" + Twine(idx))))
// Returns the number of case values held by this op (the size of getCases()).
4062 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4064 Block &scf::IndexSwitchOp::getDefaultBlock() {
4065 return getDefaultRegion().
front();
4068 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4069 assert(idx < getNumCases() &&
"case index out-of-bounds");
4070 return getCaseRegions()[idx].front();
4073 void IndexSwitchOp::getSuccessorRegions(
4077 successors.emplace_back(getResults());
4081 llvm::copy(getRegions(), std::back_inserter(successors));
4084 void IndexSwitchOp::getEntrySuccessorRegions(
4087 FoldAdaptor adaptor(operands, *
this);
4090 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4092 llvm::copy(getRegions(), std::back_inserter(successors));
4098 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4099 if (caseValue == arg.getInt()) {
4100 successors.emplace_back(&caseRegion);
4104 successors.emplace_back(&getDefaultRegion());
4107 void IndexSwitchOp::getRegionInvocationBounds(
4109 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4110 if (!operandValue) {
4116 unsigned liveIndex = getNumRegions() - 1;
4117 const auto *it = llvm::find(getCases(), operandValue.getInt());
4118 if (it != getCases().end())
4119 liveIndex = std::distance(getCases().begin(), it);
4120 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4121 bounds.emplace_back(0, i == liveIndex);
4127 if (!maybeCst.has_value())
4129 int64_t cst = *maybeCst;
4130 int64_t caseIdx, e = getNumCases();
4131 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4132 if (cst == getCases()[caseIdx])
4136 Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4137 : getDefaultRegion();
4142 Block *pDestination = (*this)->getBlock();
4157 #define GET_OP_CLASSES
4158 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values list has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.