24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/TypeSwitch.h"
31 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
54 auto retValOp = dyn_cast<scf::YieldOp>(op);
58 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
59 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
69 void SCFDialect::initialize() {
72 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74 addInterfaces<SCFInlinerInterface>();
79 builder.
create<scf::YieldOp>(loc);
84 template <
typename TerminatorTy>
86 StringRef errorMessage) {
89 terminatorOperation = &region.
front().
back();
90 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
94 if (terminatorOperation)
95 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
107 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
153 if (getRegion().empty())
154 return emitOpError(
"region needs to have at least one block");
155 if (getRegion().front().getNumArguments() > 0)
156 return emitOpError(
"region cannot have any arguments");
179 if (!llvm::hasSingleElement(op.
getRegion()))
228 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->
getParentOp()))
238 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
240 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
241 yieldOp.getResults());
250 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
262 void ExecuteRegionOp::getSuccessorRegions(
280 assert((point.
isParent() || point == getParentOp().getAfter()) &&
281 "condition op can only exit the loop or branch to the after"
284 return getArgsMutable();
287 void ConditionOp::getSuccessorRegions(
289 FoldAdaptor adaptor(operands, *
this);
291 WhileOp whileOp = getParentOp();
295 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
296 if (!boolAttr || boolAttr.getValue())
297 regions.emplace_back(&whileOp.getAfter(),
298 whileOp.getAfter().getArguments());
299 if (!boolAttr || !boolAttr.getValue())
300 regions.emplace_back(whileOp.getResults());
309 BodyBuilderFn bodyBuilder) {
312 for (
Value v : iterArgs)
319 for (
Value v : iterArgs)
325 if (iterArgs.empty() && !bodyBuilder) {
326 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
327 }
else if (bodyBuilder) {
338 return emitOpError(
"constant step operand must be positive");
341 if (getInitArgs().size() != getNumResults())
343 "mismatch in number of loop-carried values and defined values");
351 if (getInductionVar().getType() !=
getLowerBound().getType())
353 "expected induction variable to be same type as bounds and step");
355 if (getNumRegionIterArgs() != getNumResults())
357 "mismatch in number of basic block args and defined values");
359 auto initArgs = getInitArgs();
360 auto iterArgs = getRegionIterArgs();
361 auto opResults = getResults();
363 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
364 if (std::get<0>(e).getType() != std::get<2>(e).getType())
365 return emitOpError() <<
"types mismatch between " << i
366 <<
"th iter operand and defined value";
367 if (std::get<1>(e).getType() != std::get<2>(e).getType())
368 return emitOpError() <<
"types mismatch between " << i
369 <<
"th iter region arg and defined value";
376 std::optional<Value> ForOp::getSingleInductionVar() {
377 return getInductionVar();
380 std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
384 std::optional<OpFoldResult> ForOp::getSingleStep() {
388 std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
/// Returns the results of this `ForOp`, exactly as produced by `getResults()`.
/// The range is returned unconditionally, so the resulting optional always
/// holds a value; this function never yields `std::nullopt`.
392 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
397 std::optional<int64_t> tripCount =
399 if (!tripCount.has_value() || tripCount != 1)
403 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
410 llvm::append_range(bbArgReplacements, getInitArgs());
414 getOperation()->getIterator(), bbArgReplacements);
430 StringRef prefix =
"") {
431 assert(blocksArgs.size() == initializers.size() &&
432 "expected same length of arguments and initializers");
433 if (initializers.empty())
437 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
438 p << std::get<0>(it) <<
" = " << std::get<1>(it);
444 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
448 if (!getInitArgs().empty())
449 p <<
" -> (" << getInitArgs().getTypes() <<
')';
451 if (
Type t = getInductionVar().getType(); !t.
isIndex())
452 p <<
" : " << t <<
' ';
455 !getInitArgs().empty());
477 regionArgs.push_back(inductionVariable);
487 if (regionArgs.size() != result.
types.size() + 1)
490 "mismatch in number of loop-carried values and defined values");
499 regionArgs.front().type = type;
505 for (
auto argOperandType :
506 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
507 Type type = std::get<2>(argOperandType);
508 std::get<0>(argOperandType).type = type;
520 ForOp::ensureTerminator(*body, builder, result.
location);
532 return getInitArgsMutable();
536 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
538 bool replaceInitOperandUsesInLoop,
543 auto inits = llvm::to_vector(getInitArgs());
544 inits.append(newInitOperands.begin(), newInitOperands.end());
545 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
550 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
552 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
557 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
558 assert(newInitOperands.size() == newYieldedValues.size() &&
559 "expected as many new yield values as new iter operands");
561 yieldOp.getResultsMutable().append(newYieldedValues);
567 newLoop.getBody()->getArguments().take_front(
568 getBody()->getNumArguments()));
570 if (replaceInitOperandUsesInLoop) {
573 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
584 newLoop->getResults().take_front(getNumResults()));
585 return cast<LoopLikeOpInterface>(newLoop.getOperation());
589 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
592 assert(ivArg.getOwner() &&
"unlinked block argument");
593 auto *containingOp = ivArg.getOwner()->getParentOp();
594 return dyn_cast_or_null<ForOp>(containingOp);
598 return getInitArgs();
615 for (
auto [lb, ub, step] :
616 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
618 if (!tripCount.has_value() || *tripCount != 1)
629 scf::InParallelOp terminator = forallOp.getTerminator();
634 bbArgReplacements.append(forallOp.getOutputs().begin(),
635 forallOp.getOutputs().end());
639 forallOp->getIterator(), bbArgReplacements);
644 results.reserve(forallOp.getResults().size());
645 for (
auto &yieldingOp : terminator.getYieldingOps()) {
646 auto parallelInsertSliceOp =
647 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
649 Value dst = parallelInsertSliceOp.getDest();
650 Value src = parallelInsertSliceOp.getSource();
651 if (llvm::isa<TensorType>(src.
getType())) {
652 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
653 forallOp.getLoc(), dst.
getType(), src, dst,
654 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
655 parallelInsertSliceOp.getStrides(),
656 parallelInsertSliceOp.getStaticOffsets(),
657 parallelInsertSliceOp.getStaticSizes(),
658 parallelInsertSliceOp.getStaticStrides()));
660 llvm_unreachable(
"unsupported terminator");
675 assert(lbs.size() == ubs.size() &&
676 "expected the same number of lower and upper bounds");
677 assert(lbs.size() == steps.size() &&
678 "expected the same number of lower bounds and steps");
683 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
685 assert(results.size() == iterArgs.size() &&
686 "loop nest body must return as many values as loop has iteration "
688 return LoopNest{{}, std::move(results)};
696 loops.reserve(lbs.size());
697 ivs.reserve(lbs.size());
700 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
701 auto loop = builder.
create<scf::ForOp>(
702 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
708 currentIterArgs = args;
709 currentLoc = nestedLoc;
715 loops.push_back(loop);
719 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
721 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
728 ? bodyBuilder(builder, currentLoc, ivs,
729 loops.back().getRegionIterArgs())
731 assert(results.size() == iterArgs.size() &&
732 "loop nest body must return as many values as loop has iteration "
735 builder.
create<scf::YieldOp>(loc, results);
739 llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
740 return LoopNest{std::move(loops), std::move(nestResults)};
748 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
753 bodyBuilder(nestedBuilder, nestedLoc, ivs);
777 bool canonicalize =
false;
784 int64_t numResults = forOp.getNumResults();
786 keepMask.reserve(numResults);
789 newBlockTransferArgs.reserve(1 + numResults);
790 newBlockTransferArgs.push_back(
Value());
791 newIterArgs.reserve(forOp.getInitArgs().size());
792 newYieldValues.reserve(numResults);
793 newResultValues.reserve(numResults);
794 for (
auto it : llvm::zip(forOp.getInitArgs(),
795 forOp.getRegionIterArgs(),
797 forOp.getYieldedValues()
805 bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
806 (std::get<1>(it).use_empty() &&
807 (std::get<0>(it) == std::get<3>(it) ||
808 std::get<2>(it).use_empty())));
809 keepMask.push_back(!forwarded);
810 canonicalize |= forwarded;
812 newBlockTransferArgs.push_back(std::get<0>(it));
813 newResultValues.push_back(std::get<0>(it));
816 newIterArgs.push_back(std::get<0>(it));
817 newYieldValues.push_back(std::get<3>(it));
818 newBlockTransferArgs.push_back(
Value());
819 newResultValues.push_back(
Value());
825 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
826 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
827 forOp.getStep(), newIterArgs);
828 newForOp->
setAttrs(forOp->getAttrs());
829 Block &newBlock = newForOp.getRegion().
front();
833 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
835 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
836 Value &newResultVal = newResultValues[idx];
837 assert((blockTransferArg && newResultVal) ||
838 (!blockTransferArg && !newResultVal));
839 if (!blockTransferArg) {
840 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
841 newResultVal = newForOp.getResult(collapsedIdx++);
847 "unexpected argument size mismatch");
852 if (newIterArgs.empty()) {
853 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
856 rewriter.
replaceOp(forOp, newResultValues);
861 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
865 filteredOperands.reserve(newResultValues.size());
866 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
868 filteredOperands.push_back(mergedTerminator.getOperand(idx));
869 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
873 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
874 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
875 cloneFilteredTerminator(mergedYieldOp);
876 rewriter.
eraseOp(mergedYieldOp);
877 rewriter.
replaceOp(forOp, newResultValues);
885 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
886 IntegerAttr clb, cub;
888 llvm::APInt lbValue = clb.getValue();
889 llvm::APInt ubValue = cub.getValue();
890 return (ubValue - lbValue).getSExtValue();
899 return diff.getSExtValue();
913 if (op.getLowerBound() == op.getUpperBound()) {
914 rewriter.
replaceOp(op, op.getInitArgs());
918 std::optional<int64_t> diff =
919 computeConstDiff(op.getLowerBound(), op.getUpperBound());
925 rewriter.
replaceOp(op, op.getInitArgs());
929 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
935 llvm::APInt stepValue = *maybeStepValue;
936 if (stepValue.sge(*diff)) {
938 blockArgs.reserve(op.getInitArgs().size() + 1);
939 blockArgs.push_back(op.getLowerBound());
940 llvm::append_range(blockArgs, op.getInitArgs());
947 if (!llvm::hasSingleElement(block))
951 if (llvm::any_of(op.getYieldedValues(),
952 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
954 rewriter.
replaceOp(op, op.getYieldedValues());
966 assert(llvm::isa<RankedTensorType>(oldType) &&
967 llvm::isa<RankedTensorType>(newType) &&
968 "expected ranked tensor types");
971 ForOp forOp = cast<ForOp>(operand.
getOwner());
973 "expected an iter OpOperand");
975 "Expected a different type");
977 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
979 newIterOperands.push_back(replacement);
982 newIterOperands.push_back(opOperand.get());
986 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
987 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
988 forOp.getStep(), newIterOperands);
989 newForOp->
setAttrs(forOp->getAttrs());
990 Block &newBlock = newForOp.getRegion().
front();
998 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1000 Value castIn = rewriter.
create<tensor::CastOp>(newForOp.getLoc(), oldType,
1002 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
1006 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1009 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1012 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
1014 newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1016 newYieldOperands[yieldIdx] = castOut;
1017 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1018 rewriter.
eraseOp(clonedYieldOp);
1023 newResults[yieldIdx] = rewriter.
create<tensor::CastOp>(
1024 newForOp.getLoc(), oldType, newResults[yieldIdx]);
1060 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.
getResults())) {
1061 OpOperand &iterOpOperand = std::get<0>(it);
1063 if (!incomingCast ||
1064 incomingCast.getSource().getType() == incomingCast.getType())
1069 incomingCast.getDest().getType(),
1070 incomingCast.getSource().getType()))
1072 if (!std::get<1>(it).hasOneUse())
1077 op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1078 incomingCast.getSource()));
1140 assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
1141 "unexpected multiple blocks");
1146 unsigned idx = bbArg.getArgNumber() - 1;
1148 cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
1149 Value yieldVal = yieldOp->getOperand(idx);
1150 auto tensorLoadOp = yieldVal.
getDefiningOp<bufferization::ToTensorOp>();
1151 bool isTensor = llvm::isa<TensorType>(bbArg.getType());
1153 bufferization::ToMemrefOp tensorToMemref;
1155 if (bbArg.hasOneUse())
1157 dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
1158 if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
1161 if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
1168 if (tensorLoadOp->getNextNode() != yieldOp)
1172 if (tensorToMemref) {
1175 tensorToMemref, tensorToMemref.getMemref().getType(),
1176 tensorToMemref.getTensor());
1181 Value newTensorLoad = rewriter.
create<bufferization::ToTensorOp>(
1182 loc, tensorLoadOp.getMemref());
1183 Value forOpResult = forOp.getResult(bbArg.getArgNumber() - 1);
1184 replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1189 yieldOp.setOperand(idx, bbArg);
1192 if (replacements.empty())
1200 newResults.reserve(forOp.getNumResults());
1201 for (
Value v : forOp.getResults()) {
1202 auto it = replacements.find(v);
1203 newResults.push_back((it != replacements.end()) ? it->second : v);
1207 return op.get() != newResults[idx++];
1216 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1217 LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1220 std::optional<APInt> ForOp::getConstantStep() {
1223 return step.getValue();
1228 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1234 if (
auto constantStep = getConstantStep())
1235 if (*constantStep == 1)
1248 unsigned numLoops = getRank();
1250 if (getNumResults() != getOutputs().size())
1251 return emitOpError(
"produces ")
1252 << getNumResults() <<
" results, but has only "
1253 << getOutputs().size() <<
" outputs";
1256 auto *body = getBody();
1258 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1259 for (int64_t i = 0; i < numLoops; ++i)
1261 return emitOpError(
"expects ")
1262 << i <<
"-th block argument to be an index";
1263 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1265 return emitOpError(
"type mismatch between ")
1266 << i <<
"-th output and corresponding block argument";
1267 if (getMapping().has_value() && !getMapping()->empty()) {
1268 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1269 return emitOpError() <<
"mapping attribute size must match op rank";
1270 for (
auto map : getMapping()->getValue()) {
1271 if (!isa<DeviceMappingAttrInterface>(map))
1272 return emitOpError()
1280 getStaticLowerBound(),
1281 getDynamicLowerBound())))
1284 getStaticUpperBound(),
1285 getDynamicUpperBound())))
1288 getStaticStep(), getDynamicStep())))
1296 p <<
" (" << getInductionVars();
1297 if (isNormalized()) {
1318 if (!getRegionOutArgs().empty())
1319 p <<
"-> (" << getResultTypes() <<
") ";
1320 p.printRegion(getRegion(),
1322 getNumResults() > 0);
1323 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1324 getStaticLowerBoundAttrName(),
1325 getStaticUpperBoundAttrName(),
1326 getStaticStepAttrName()});
1331 auto indexType = b.getIndexType();
1351 unsigned numLoops = ivs.size();
1386 if (outOperands.size() != result.
types.size())
1388 "mismatch between out operands and types");
1398 std::unique_ptr<Region> region = std::make_unique<Region>();
1399 for (
auto &iv : ivs) {
1400 iv.type = b.getIndexType();
1401 regionArgs.push_back(iv);
1404 auto &out = it.value();
1405 out.type = result.
types[it.index()];
1406 regionArgs.push_back(out);
1412 ForallOp::ensureTerminator(*region, b, result.
location);
1424 {static_cast<int32_t>(dynamicLbs.size()),
1425 static_cast<int32_t>(dynamicUbs.size()),
1426 static_cast<int32_t>(dynamicSteps.size()),
1427 static_cast<int32_t>(outOperands.size())}));
1432 void ForallOp::build(
1436 std::optional<ArrayAttr> mapping,
1457 "operandSegmentSizes",
1459 static_cast<int32_t>(dynamicUbs.size()),
1460 static_cast<int32_t>(dynamicSteps.size()),
1461 static_cast<int32_t>(outputs.size())}));
1462 if (mapping.has_value()) {
1481 if (!bodyBuilderFn) {
1482 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1487 auto terminator = llvm::dyn_cast<InParallelOp>(bodyBlock.
getTerminator());
1488 assert(terminator &&
1489 "expected bodyBuilderFn to create InParallelOp terminator");
1494 void ForallOp::build(
1497 std::optional<ArrayAttr> mapping,
1499 unsigned numLoops = ubs.size();
1502 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1506 bool ForallOp::isNormalized() {
1510 return intValue.has_value() && intValue == val;
1513 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1522 ForallOp>::ensureTerminator(region, builder, loc);
1529 InParallelOp ForallOp::getTerminator() {
1530 return cast<InParallelOp>(getBody()->getTerminator());
1533 std::optional<Value> ForallOp::getSingleInductionVar() {
1535 return std::nullopt;
1536 return getInductionVar(0);
1539 std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1541 return std::nullopt;
1542 return getMixedLowerBound()[0];
1545 std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1547 return std::nullopt;
1548 return getMixedUpperBound()[0];
1551 std::optional<OpFoldResult> ForallOp::getSingleStep() {
1553 return std::nullopt;
1554 return getMixedStep()[0];
1558 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1561 assert(tidxArg.getOwner() &&
"unlinked block argument");
1562 auto *containingOp = tidxArg.getOwner()->getParentOp();
1563 return dyn_cast<ForallOp>(containingOp);
1573 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1577 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1580 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1604 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1605 op.setStaticLowerBound(staticLowerBound);
1609 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1610 op.setStaticUpperBound(staticUpperBound);
1613 op.getDynamicStepMutable().assign(dynamicStep);
1614 op.setStaticStep(staticStep);
1616 op->
setAttr(ForallOp::getOperandSegmentSizeAttr(),
1618 {static_cast<int32_t>(dynamicLowerBound.size()),
1619 static_cast<int32_t>(dynamicUpperBound.size()),
1620 static_cast<int32_t>(dynamicStep.size()),
1621 static_cast<int32_t>(op.getNumResults())}));
1627 struct ForallOpSingleOrZeroIterationDimsFolder
1634 if (op.getMapping().has_value())
1642 for (
auto [lb, ub, step, iv] :
1643 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1644 op.getMixedStep(), op.getInductionVars())) {
1646 if (numIterations.has_value()) {
1648 if (*numIterations == 0) {
1649 rewriter.
replaceOp(op, op.getOutputs());
1654 if (*numIterations == 1) {
1659 newMixedLowerBounds.push_back(lb);
1660 newMixedUpperBounds.push_back(ub);
1661 newMixedSteps.push_back(step);
1664 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1666 op,
"no dimensions have 0 or 1 iterations");
1670 if (newMixedLowerBounds.empty()) {
1677 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1678 newMixedUpperBounds, newMixedSteps,
1679 op.getOutputs(), std::nullopt,
nullptr);
1680 newOp.getBodyRegion().getBlocks().clear();
1685 newOp.getStaticLowerBoundAttrName(),
1686 newOp.getStaticUpperBoundAttrName(),
1687 newOp.getStaticStepAttrName()};
1688 for (
const auto &namedAttr : op->
getAttrs()) {
1689 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1692 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1696 newOp.getRegion().
begin(), mapping);
1697 rewriter.
replaceOp(op, newOp.getResults());
1702 struct FoldTensorCastOfOutputIntoForallOp
1713 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1716 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1723 castOp.getSource().getType())) {
1727 tensorCastProducers[en.index()] =
1728 TypeCast{castOp.getSource().getType(), castOp.getType()};
1729 newOutputTensors[en.index()] = castOp.getSource();
1732 if (tensorCastProducers.empty())
1737 auto newForallOp = rewriter.
create<ForallOp>(
1738 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1739 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1741 auto castBlockArgs =
1742 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1743 for (
auto [index, cast] : tensorCastProducers) {
1744 Value &oldTypeBBArg = castBlockArgs[index];
1745 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1746 nestedLoc, cast.dstType, oldTypeBBArg);
1751 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1752 ivsBlockArgs.append(castBlockArgs);
1754 bbArgs.front().getParentBlock(), ivsBlockArgs);
1760 auto terminator = newForallOp.getTerminator();
1761 for (
auto [yieldingOp, outputBlockArg] :
1762 llvm::zip(terminator.getYieldingOps(),
1763 newForallOp.getOutputBlockArguments())) {
1764 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1765 insertSliceOp.getDestMutable().assign(outputBlockArg);
1771 for (
auto &item : tensorCastProducers) {
1772 Value &oldTypeResult = castResults[item.first];
1773 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1776 rewriter.
replaceOp(forallOp, castResults);
1785 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1786 ForallOpControlOperandsFolder,
1787 ForallOpSingleOrZeroIterationDimsFolder>(context);
1816 scf::ForallOp forallOp =
1817 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1819 return this->emitOpError(
"expected forall op parent");
1822 for (
Operation &op : getRegion().front().getOperations()) {
1823 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1824 return this->emitOpError(
"expected only ")
1825 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1829 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1831 if (!llvm::is_contained(regionOutArgs, dest))
1832 return op.
emitOpError(
"may only insert into an output block argument");
1849 std::unique_ptr<Region> region = std::make_unique<Region>();
1853 if (region->empty())
1863 OpResult InParallelOp::getParentResult(int64_t idx) {
1864 return getOperation()->getParentOp()->getResult(idx);
1868 return llvm::to_vector<4>(
1869 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1871 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1872 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1877 return getRegion().front().getOperations();
1885 assert(a &&
"expected non-empty operation");
1886 assert(b &&
"expected non-empty operation");
1891 if (ifOp->isProperAncestor(b))
1894 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1895 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1897 ifOp = ifOp->getParentOfType<IfOp>();
1905 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1906 IfOp::Adaptor adaptor,
1908 if (adaptor.getRegions().empty())
1910 Region *r = &adaptor.getThenRegion();
1916 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
1919 TypeRange types = yieldOp.getOperandTypes();
1920 inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1927 return build(builder, result, resultTypes, cond,
false,
1933 bool addElseBlock) {
1934 assert((!addElseBlock || addThenBlock) &&
1935 "must not create else block w/o then block");
1950 bool withElseRegion) {
1951 build(builder, result,
TypeRange{}, cond, withElseRegion);
1963 if (resultTypes.empty())
1964 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
1968 if (withElseRegion) {
1970 if (resultTypes.empty())
1971 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
1978 assert(thenBuilder &&
"the builder callback for 'then' must be present");
1985 thenBuilder(builder, result.
location);
1991 elseBuilder(builder, result.
location);
2000 inferredReturnTypes))) {
2001 result.
addTypes(inferredReturnTypes);
2006 if (getNumResults() != 0 && getElseRegion().empty())
2007 return emitOpError(
"must have an else block if defining values");
2045 bool printBlockTerminators =
false;
2047 p <<
" " << getCondition();
2048 if (!getResults().empty()) {
2049 p <<
" -> (" << getResultTypes() <<
")";
2051 printBlockTerminators =
true;
2056 printBlockTerminators);
2059 auto &elseRegion = getElseRegion();
2060 if (!elseRegion.
empty()) {
2064 printBlockTerminators);
2081 Region *elseRegion = &this->getElseRegion();
2082 if (elseRegion->
empty())
2090 FoldAdaptor adaptor(operands, *
this);
2091 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2092 if (!boolAttr || boolAttr.getValue())
2093 regions.emplace_back(&getThenRegion());
2096 if (!boolAttr || !boolAttr.getValue()) {
2097 if (!getElseRegion().empty())
2098 regions.emplace_back(&getElseRegion());
2100 regions.emplace_back(getResults());
2107 if (getElseRegion().empty())
2110 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2117 getConditionMutable().assign(xorStmt.getLhs());
2121 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2122 getElseRegion().getBlocks());
2123 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2124 getThenRegion().getBlocks(), thenBlock);
2128 void IfOp::getRegionInvocationBounds(
2131 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2134 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2135 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2138 invocationBounds.assign(2, {0, 1});
2154 llvm::transform(usedResults, std::back_inserter(usedOperands),
2159 [&]() { yieldOp->setOperands(usedOperands); });
2166 llvm::copy_if(op.
getResults(), std::back_inserter(usedResults),
2167 [](
OpResult result) { return !result.use_empty(); });
2175 llvm::transform(usedResults, std::back_inserter(newTypes),
2180 rewriter.
create<IfOp>(op.
getLoc(), newTypes, op.getCondition());
2186 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2187 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2192 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2209 else if (!op.getElseRegion().empty())
2228 auto cond = op.getCondition();
2229 auto thenYieldArgs = op.thenYield().
getOperands();
2230 auto elseYieldArgs = op.elseYield().
getOperands();
2233 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2236 nonHoistable.push_back(trueVal.getType());
2243 IfOp replacement = rewriter.
create<IfOp>(op.
getLoc(), nonHoistable, cond,
2245 if (replacement.thenBlock())
2246 rewriter.
eraseBlock(replacement.thenBlock());
2247 replacement.getThenRegion().takeBody(op.getThenRegion());
2248 replacement.getElseRegion().takeBody(op.getElseRegion());
2251 assert(thenYieldArgs.size() == results.size());
2252 assert(elseYieldArgs.size() == results.size());
2257 for (
const auto &it :
2259 Value trueVal = std::get<0>(it.value());
2260 Value falseVal = std::get<1>(it.value());
2263 results[it.index()] = replacement.getResult(trueYields.size());
2264 trueYields.push_back(trueVal);
2265 falseYields.push_back(falseVal);
2266 }
else if (trueVal == falseVal)
2267 results[it.index()] = trueVal;
2269 results[it.index()] = rewriter.
create<arith::SelectOp>(
2270 op.
getLoc(), cond, trueVal, falseVal);
2307 bool changed =
false;
2312 Value constantTrue =
nullptr;
2313 Value constantFalse =
nullptr;
2316 llvm::make_early_inc_range(op.getCondition().
getUses())) {
2321 constantTrue = rewriter.
create<arith::ConstantOp>(
2325 [&]() { use.
set(constantTrue); });
2331 constantFalse = rewriter.
create<arith::ConstantOp>(
2335 [&]() { use.
set(constantFalse); });
2379 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2389 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2391 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2394 op.getOperation()->getIterator());
2395 bool changed =
false;
2397 for (
auto [trueResult, falseResult, opResult] :
2398 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2400 if (trueResult == falseResult) {
2401 if (!opResult.use_empty()) {
2402 opResult.replaceAllUsesWith(trueResult);
2413 bool trueVal = trueYield.
getValue();
2414 bool falseVal = falseYield.
getValue();
2415 if (!trueVal && falseVal) {
2416 if (!opResult.use_empty()) {
2417 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2419 op.
getLoc(), op.getCondition(),
2429 if (trueVal && !falseVal) {
2430 if (!opResult.use_empty()) {
2431 opResult.replaceAllUsesWith(op.getCondition());
2466 Block *parent = nextIf->getBlock();
2467 if (nextIf == &parent->
front())
2470 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2478 Block *nextThen =
nullptr;
2479 Block *nextElse =
nullptr;
2480 if (nextIf.getCondition() == prevIf.getCondition()) {
2481 nextThen = nextIf.thenBlock();
2482 if (!nextIf.getElseRegion().empty())
2483 nextElse = nextIf.elseBlock();
2485 if (arith::XOrIOp notv =
2486 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2487 if (notv.getLhs() == prevIf.getCondition() &&
2489 nextElse = nextIf.thenBlock();
2490 if (!nextIf.getElseRegion().empty())
2491 nextThen = nextIf.elseBlock();
2494 if (arith::XOrIOp notv =
2495 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2496 if (notv.getLhs() == nextIf.getCondition() &&
2498 nextElse = nextIf.thenBlock();
2499 if (!nextIf.getElseRegion().empty())
2500 nextThen = nextIf.elseBlock();
2504 if (!nextThen && !nextElse)
2508 if (!prevIf.getElseRegion().empty())
2509 prevElseYielded = prevIf.elseYield().getOperands();
2512 for (
auto it : llvm::zip(prevIf.getResults(),
2513 prevIf.thenYield().getOperands(), prevElseYielded))
2515 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2519 use.
set(std::get<1>(it));
2524 use.
set(std::get<2>(it));
2530 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2532 IfOp combinedIf = rewriter.
create<IfOp>(
2533 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2534 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2537 combinedIf.getThenRegion(),
2538 combinedIf.getThenRegion().begin());
2541 YieldOp thenYield = combinedIf.thenYield();
2542 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2543 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2547 llvm::append_range(mergedYields, thenYield2.getOperands());
2548 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2554 combinedIf.getElseRegion(),
2555 combinedIf.getElseRegion().begin());
2558 if (combinedIf.getElseRegion().empty()) {
2560 combinedIf.getElseRegion(),
2561 combinedIf.getElseRegion().
begin());
2563 YieldOp elseYield = combinedIf.elseYield();
2564 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2565 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2570 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2572 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2581 if (pair.index() < prevIf.getNumResults())
2582 prevValues.push_back(pair.value());
2584 nextValues.push_back(pair.value());
2599 if (ifOp.getNumResults())
2601 Block *elseBlock = ifOp.elseBlock();
2602 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2606 newIfOp.getThenRegion().begin());
2633 auto nestedOps = op.thenBlock()->without_terminator();
2635 if (!llvm::hasSingleElement(nestedOps))
2639 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2642 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2646 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2652 llvm::append_range(elseYield, op.elseYield().
getOperands());
2666 if (tup.value().getDefiningOp() == nestedIf) {
2667 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2668 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2669 elseYield[tup.index()]) {
2674 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2687 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2690 elseYieldsToUpgradeToSelect.push_back(tup.index());
2694 Value newCondition = rewriter.
create<arith::AndIOp>(
2695 loc, op.getCondition(), nestedIf.getCondition());
2700 llvm::append_range(results, newIf.getResults());
2703 for (
auto idx : elseYieldsToUpgradeToSelect)
2704 results[idx] = rewriter.
create<arith::SelectOp>(
2705 op.
getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2707 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2710 if (!elseYield.empty()) {
2713 rewriter.
create<YieldOp>(loc, elseYield);
2724 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2725 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2726 RemoveStaticCondition, RemoveUnusedResults,
2727 ReplaceIfYieldWithConditionOrValue>(context);
2730 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2731 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2732 Block *IfOp::elseBlock() {
2733 Region &r = getElseRegion();
2738 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2744 void ParallelOp::build(
2754 ParallelOp::getOperandSegmentSizeAttr(),
2756 static_cast<int32_t>(upperBounds.size()),
2757 static_cast<int32_t>(steps.size()),
2758 static_cast<int32_t>(initVals.size())}));
2762 unsigned numIVs = steps.size();
2768 if (bodyBuilderFn) {
2770 bodyBuilderFn(builder, result.
location,
2774 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2777 void ParallelOp::build(
2784 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2787 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2791 wrapper = wrappedBuilderFn;
2793 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2802 if (stepValues.empty())
2804 "needs at least one tuple element for lowerBound, upperBound and step");
2807 for (
Value stepValue : stepValues)
2810 return emitOpError(
"constant step operand must be positive");
2814 Block *body = getBody();
2816 return emitOpError() <<
"expects the same number of induction variables: "
2818 <<
" as bound and step values: " << stepValues.size();
2820 if (!arg.getType().isIndex())
2822 "expects arguments for the induction variable to be of index type");
2825 auto yield = verifyAndGetTerminator<scf::YieldOp>(
2826 *
this, getRegion(),
"expects body to terminate with 'scf.yield'");
2829 if (yield->getNumOperands() != 0)
2830 return yield.emitOpError() <<
"not allowed to have operands inside '"
2831 << ParallelOp::getOperationName() <<
"'";
2835 auto resultsSize = getResults().size();
2836 auto reductionsSize = reductions.size();
2837 auto initValsSize = getInitVals().size();
2838 if (resultsSize != reductionsSize)
2839 return emitOpError() <<
"expects number of results: " << resultsSize
2840 <<
" to be the same as number of reductions: "
2842 if (resultsSize != initValsSize)
2843 return emitOpError() <<
"expects number of results: " << resultsSize
2844 <<
" to be the same as number of initial values: "
2848 for (
auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2849 auto resultType = std::get<0>(resultAndReduce).getType();
2850 auto reduceOp = std::get<1>(resultAndReduce);
2851 auto reduceType = reduceOp.getOperand().getType();
2852 if (resultType != reduceType)
2853 return reduceOp.emitOpError()
2854 <<
"expects type of reduce: " << reduceType
2855 <<
" to be the same as result type: " << resultType;
2903 for (
auto &iv : ivs)
2910 ParallelOp::getOperandSegmentSizeAttr(),
2912 static_cast<int32_t>(upper.size()),
2913 static_cast<int32_t>(steps.size()),
2914 static_cast<int32_t>(initVals.size())}));
2923 ForOp::ensureTerminator(*body, builder, result.
location);
2928 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2929 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2930 if (!getInitVals().empty())
2931 p <<
" init (" << getInitVals() <<
")";
2936 (*this)->getAttrs(),
2937 ParallelOp::getOperandSegmentSizeAttr());
2942 std::optional<Value> ParallelOp::getSingleInductionVar() {
2943 if (getNumLoops() != 1)
2944 return std::nullopt;
2945 return getBody()->getArgument(0);
2948 std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
2949 if (getNumLoops() != 1)
2950 return std::nullopt;
2954 std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
2955 if (getNumLoops() != 1)
2956 return std::nullopt;
2960 std::optional<OpFoldResult> ParallelOp::getSingleStep() {
2961 if (getNumLoops() != 1)
2962 return std::nullopt;
2963 return getStep()[0];
2967 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2969 return ParallelOp();
2970 assert(ivArg.getOwner() &&
"unlinked block argument");
2971 auto *containingOp = ivArg.getOwner()->getParentOp();
2972 return dyn_cast<ParallelOp>(containingOp);
2977 struct ParallelOpSingleOrZeroIterationDimsFolder
2988 for (
auto [lb, ub, step, iv] :
2989 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2990 op.getInductionVars())) {
2992 if (numIterations.has_value()) {
2994 if (*numIterations == 0) {
2995 rewriter.
replaceOp(op, op.getInitVals());
3000 if (*numIterations == 1) {
3005 newLowerBounds.push_back(lb);
3006 newUpperBounds.push_back(ub);
3007 newSteps.push_back(step);
3010 if (newLowerBounds.size() == op.getLowerBound().size())
3013 if (newLowerBounds.empty()) {
3017 results.reserve(op.getInitVals().size());
3018 for (
auto &bodyOp : op.getBody()->without_terminator()) {
3019 auto reduce = dyn_cast<ReduceOp>(bodyOp);
3021 rewriter.
clone(bodyOp, mapping);
3024 Block &reduceBlock =
reduce.getReductionOperator().front();
3025 auto initValIndex = results.size();
3026 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3030 rewriter.
clone(reduceBodyOp, mapping);
3033 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3034 results.push_back(result);
3041 rewriter.
create<ParallelOp>(op.
getLoc(), newLowerBounds, newUpperBounds,
3042 newSteps, op.getInitVals(),
nullptr);
3046 newOp.getRegion().
begin(), mapping);
3047 rewriter.
replaceOp(op, newOp.getResults());
3057 Block &outerBody = *op.getBody();
3061 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3066 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3067 llvm::is_contained(innerOp.getUpperBound(), val) ||
3068 llvm::is_contained(innerOp.getStep(), val))
3072 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3077 Block &innerBody = *innerOp.getBody();
3078 assert(iterVals.size() ==
3086 builder.
clone(op, mapping);
3089 auto concatValues = [](
const auto &first,
const auto &second) {
3091 ret.reserve(first.size() + second.size());
3092 ret.assign(first.begin(), first.end());
3093 ret.append(second.begin(), second.end());
3097 auto newLowerBounds =
3098 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3099 auto newUpperBounds =
3100 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3101 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3104 newSteps, std::nullopt,
3115 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3124 void ParallelOp::getSuccessorRegions(
3137 void ReduceOp::build(
3140 auto type = operand.
getType();
3154 auto type = getOperand().getType();
3155 Block &block = getReductionOperator().
front();
3157 return emitOpError(
"the block inside reduce should not be empty");
3160 return arg.getType() != type;
3162 return emitOpError() <<
"expects two arguments to reduce block of type "
3167 return emitOpError(
"the block inside reduce should be terminated with a "
3168 "'scf.reduce.return' op");
3195 p <<
"(" << getOperand() <<
") ";
3196 p <<
" : " << getOperand().getType() <<
' ';
3207 auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
3208 Type reduceType = reduceOp.getOperand().getType();
3209 if (reduceType != getResult().getType())
3210 return emitOpError() <<
"needs to have type " << reduceType
3211 <<
" (the type of the enclosing ReduceOp)";
3221 ValueRange operands, BodyBuilderFn beforeBuilder,
3222 BodyBuilderFn afterBuilder) {
3230 beforeArgLocs.reserve(operands.size());
3231 for (
Value operand : operands) {
3232 beforeArgLocs.push_back(operand.getLoc());
3237 beforeRegion, {}, operands.getTypes(), beforeArgLocs);
3246 resultTypes, afterArgLocs);
3252 ConditionOp WhileOp::getConditionOp() {
3253 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3256 YieldOp WhileOp::getYieldOp() {
3257 return cast<YieldOp>(getAfterBody()->getTerminator());
3261 return getYieldOp().getResultsMutable();
3265 return getBeforeBody()->getArguments();
3269 return getAfterBody()->getArguments();
3273 return getBeforeArguments();
3277 assert(point == getBefore() &&
3278 "WhileOp is expected to branch only to the first region");
3286 regions.emplace_back(&getBefore(), getBefore().getArguments());
3290 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3291 "there are only two regions in a WhileOp");
3293 if (point == getAfter()) {
3294 regions.emplace_back(&getBefore(), getBefore().getArguments());
3298 regions.emplace_back(getResults());
3299 regions.emplace_back(&getAfter(), getAfter().getArguments());
3303 return {&getBefore(), &getAfter()};
3324 FunctionType functionType;
3329 result.
addTypes(functionType.getResults());
3331 if (functionType.getNumInputs() != operands.size()) {
3333 <<
"expected as many input types as operands "
3334 <<
"(expected " << operands.size() <<
" got "
3335 << functionType.getNumInputs() <<
")";
3345 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3346 regionArgs[i].type = functionType.getInput(i);
3368 template <
typename OpTy>
3371 if (left.size() != right.size())
3372 return op.
emitOpError(
"expects the same number of ") << message;
3374 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3375 if (left[i] != right[i]) {
3378 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3379 <<
" and " << right[i];
3388 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3390 "expects the 'before' region to terminate with 'scf.condition'");
3391 if (!beforeTerminator)
3394 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3396 "expects the 'after' region to terminate with 'scf.yield'");
3397 return success(afterTerminator !=
nullptr);
3425 auto term = op.getConditionOp();
3429 Value constantTrue =
nullptr;
3431 bool replaced =
false;
3432 for (
auto yieldedAndBlockArgs :
3433 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3434 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3435 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3437 constantTrue = rewriter.
create<arith::ConstantOp>(
3438 op.
getLoc(), term.getCondition().getType(),
3499 struct RemoveLoopInvariantArgsFromBeforeBlock
3505 Block &afterBlock = *op.getAfterBody();
3507 ConditionOp condOp = op.getConditionOp();
3512 bool canSimplify =
false;
3513 for (
const auto &it :
3515 auto index =
static_cast<unsigned>(it.index());
3516 auto [initVal, yieldOpArg] = it.value();
3519 if (yieldOpArg == initVal) {
3528 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3529 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3530 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3531 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3544 for (
const auto &it :
3546 auto index =
static_cast<unsigned>(it.index());
3547 auto [initVal, yieldOpArg] = it.value();
3551 if (yieldOpArg == initVal) {
3552 beforeBlockInitValMap.insert({index, initVal});
3560 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3561 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3562 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3563 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3564 beforeBlockInitValMap.insert({index, initVal});
3569 newInitArgs.emplace_back(initVal);
3570 newYieldOpArgs.emplace_back(yieldOpArg);
3571 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3584 &newWhile.getBefore(), {},
3587 Block &beforeBlock = *op.getBeforeBody();
3594 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3597 if (beforeBlockInitValMap.count(i) != 0)
3598 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3600 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3603 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3605 newWhile.getAfter().begin());
3607 rewriter.
replaceOp(op, newWhile.getResults());
3652 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3657 Block &beforeBlock = *op.getBeforeBody();
3658 ConditionOp condOp = op.getConditionOp();
3661 bool canSimplify =
false;
3662 for (
Value condOpArg : condOpArgs) {
3682 auto index =
static_cast<unsigned>(it.index());
3683 Value condOpArg = it.value();
3688 condOpInitValMap.insert({index, condOpArg});
3690 newCondOpArgs.emplace_back(condOpArg);
3691 newAfterBlockType.emplace_back(condOpArg.
getType());
3692 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3703 auto newWhile = rewriter.
create<WhileOp>(op.
getLoc(), newAfterBlockType,
3706 Block &newAfterBlock =
3708 newAfterBlockType, newAfterBlockArgLocs);
3710 Block &afterBlock = *op.getAfterBody();
3717 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3718 Value afterBlockArg, result;
3721 if (condOpInitValMap.count(i) != 0) {
3722 afterBlockArg = condOpInitValMap[i];
3723 result = afterBlockArg;
3725 afterBlockArg = newAfterBlock.getArgument(
j);
3726 result = newWhile.getResult(
j);
3729 newAfterBlockArgs[i] = afterBlockArg;
3730 newWhileResults[i] = result;
3733 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3735 newWhile.getBefore().begin());
3737 rewriter.
replaceOp(op, newWhileResults);
3773 auto term = op.getConditionOp();
3774 auto afterArgs = op.getAfterArguments();
3775 auto termArgs = term.getArgs();
3782 bool needUpdate =
false;
3783 for (
const auto &it :
3785 auto i =
static_cast<unsigned>(it.index());
3786 Value result = std::get<0>(it.value());
3787 Value afterArg = std::get<1>(it.value());
3788 Value termArg = std::get<2>(it.value());
3792 newResultsIndices.emplace_back(i);
3793 newTermArgs.emplace_back(termArg);
3794 newResultTypes.emplace_back(result.
getType());
3795 newArgLocs.emplace_back(result.
getLoc());
3810 rewriter.
create<WhileOp>(op.
getLoc(), newResultTypes, op.getInits());
3813 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3820 newResults[it.value()] = newWhile.getResult(it.index());
3821 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3825 newWhile.getBefore().begin());
3827 Block &afterBlock = *op.getAfterBody();
3828 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3862 using namespace scf;
3863 auto cond = op.getConditionOp();
3864 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3867 bool changed =
false;
3868 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3869 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3870 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3873 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3874 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3878 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3881 if (cmp2.getPredicate() == cmp.getPredicate())
3882 samePredicate =
true;
3883 else if (cmp2.getPredicate() ==
3885 samePredicate =
false;
3906 if (!llvm::any_of(op.getBeforeArguments(),
3907 [](
Value arg) { return arg.use_empty(); }))
3910 YieldOp yield = op.getYieldOp();
3915 llvm::BitVector argsToErase;
3917 size_t argsCount = op.getBeforeArguments().size();
3918 newYields.reserve(argsCount);
3919 newInits.reserve(argsCount);
3920 argsToErase.reserve(argsCount);
3921 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3922 op.getBeforeArguments(), yield.
getOperands(), op.getInits())) {
3923 if (beforeArg.use_empty()) {
3924 argsToErase.push_back(
true);
3926 argsToErase.push_back(
false);
3927 newYields.emplace_back(yieldValue);
3928 newInits.emplace_back(initValue);
3932 Block &beforeBlock = *op.getBeforeBody();
3933 Block &afterBlock = *op.getAfterBody();
3941 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3942 Block &newAfterBlock = *newWhileOp.getAfterBody();
3948 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
3949 newBeforeBlock.getArguments());
3953 rewriter.
replaceOp(op, newWhileOp.getResults());
3964 ConditionOp condOp = op.getConditionOp();
3968 for (
Value arg : condOpArgs)
3969 argsSet.insert(arg);
3971 if (argsSet.size() == condOpArgs.size())
3974 llvm::SmallDenseMap<Value, unsigned> argsMap;
3976 argsMap.reserve(condOpArgs.size());
3977 newArgs.reserve(condOpArgs.size());
3978 for (
Value arg : condOpArgs) {
3979 if (!argsMap.count(arg)) {
3980 auto pos =
static_cast<unsigned>(argsMap.size());
3981 argsMap.insert({arg, pos});
3982 newArgs.emplace_back(arg);
3989 auto newWhileOp = rewriter.
create<scf::WhileOp>(
3990 loc, argsRange.getTypes(), op.getInits(),
nullptr,
3992 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3993 Block &newAfterBlock = *newWhileOp.getAfterBody();
3998 auto it = argsMap.find(arg);
3999 assert(it != argsMap.end());
4000 auto pos = it->second;
4001 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4002 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4010 Block &beforeBlock = *op.getBeforeBody();
4011 Block &afterBlock = *op.getAfterBody();
4013 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4014 newBeforeBlock.getArguments());
4015 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4024 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4025 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4026 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4027 WhileRemoveUnusedArgs>(context);
4041 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4044 caseValues.push_back(value);
4053 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4055 p <<
"case " << value <<
' ';
4061 if (getCases().size() != getCaseRegions().size()) {
4062 return emitOpError(
"has ")
4063 << getCaseRegions().size() <<
" case regions but "
4064 << getCases().size() <<
" case values";
4068 for (int64_t value : getCases())
4069 if (!valueSet.insert(value).second)
4070 return emitOpError(
"has duplicate case value: ") << value;
4072 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4074 return emitOpError(
"expected region to end with scf.yield, but got ")
4077 if (yield.getNumOperands() != getNumResults()) {
4078 return (emitOpError(
"expected each region to return ")
4079 << getNumResults() <<
" values, but " << name <<
" returns "
4080 << yield.getNumOperands())
4081 .attachNote(yield.getLoc())
4082 <<
"see yield operation here";
4084 for (
auto [idx, result, operand] :
4085 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4086 yield.getOperandTypes())) {
4087 if (result == operand)
4089 return (emitOpError(
"expected result #")
4090 << idx <<
" of each region to be " << result)
4091 .attachNote(yield.getLoc())
4092 << name <<
" returns " << operand <<
" here";
4097 if (
failed(verifyRegion(getDefaultRegion(),
"default region")))
4100 if (
failed(verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4106 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4108 Block &scf::IndexSwitchOp::getDefaultBlock() {
4109 return getDefaultRegion().
front();
4112 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4113 assert(idx < getNumCases() &&
"case index out-of-bounds");
4114 return getCaseRegions()[idx].front();
4117 void IndexSwitchOp::getSuccessorRegions(
4121 successors.emplace_back(getResults());
4125 llvm::copy(getRegions(), std::back_inserter(successors));
4128 void IndexSwitchOp::getEntrySuccessorRegions(
4131 FoldAdaptor adaptor(operands, *
this);
4134 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4136 llvm::copy(getRegions(), std::back_inserter(successors));
4142 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4143 if (caseValue == arg.getInt()) {
4144 successors.emplace_back(&caseRegion);
4148 successors.emplace_back(&getDefaultRegion());
4151 void IndexSwitchOp::getRegionInvocationBounds(
4153 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4154 if (!operandValue) {
4160 unsigned liveIndex = getNumRegions() - 1;
4161 const auto *it = llvm::find(getCases(), operandValue.getInt());
4162 if (it != getCases().end())
4163 liveIndex = std::distance(getCases().begin(), it);
4164 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4165 bounds.emplace_back(0, i == liveIndex);
4171 if (!maybeCst.has_value())
4173 int64_t cst = *maybeCst;
4174 int64_t caseIdx, e = getNumCases();
4175 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4176 if (cst == getCases()[caseIdx])
4180 Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4181 : getDefaultRegion();
4186 Block *pDestination = (*this)->getBlock();
4201 #define GET_OP_CLASSES
4202 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
void push_back(Block *block)
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false)
Returns "success" when any of the elements in ofrs is a constant value.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.