29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/SmallPtrSet.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/DebugLog.h"
38 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
61 auto retValOp = dyn_cast<scf::YieldOp>(op);
65 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
66 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
76 void SCFDialect::initialize() {
79 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
81 addInterfaces<SCFInlinerInterface>();
82 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
83 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
84 InParallelOp, ReduceReturnOp>();
85 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
86 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
87 ForallOp, InParallelOp, WhileOp, YieldOp>();
88 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
93 scf::YieldOp::create(builder, loc);
98 template <
typename TerminatorTy>
100 StringRef errorMessage) {
101 Operation *terminatorOperation =
nullptr;
103 terminatorOperation = ®ion.
front().
back();
104 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
108 if (terminatorOperation)
109 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
139 assert(region.
hasOneBlock() &&
"expected single-block region");
188 if (getRegion().empty())
189 return emitOpError(
"region needs to have at least one block");
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError(
"region cannot have any arguments");
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
263 if (op.getNoInline())
265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
268 Block *prevBlock = op->getBlock();
272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
274 for (
Block &blk : op.getRegion()) {
275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278 yieldOp.getResults());
286 for (
auto res : op.getResults())
287 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
299 void ExecuteRegionOp::getSuccessorRegions(
317 assert((point.
isParent() || point == getParentOp().getAfter()) &&
318 "condition op can only exit the loop or branch to the after"
321 return getArgsMutable();
324 void ConditionOp::getSuccessorRegions(
326 FoldAdaptor adaptor(operands, *
this);
328 WhileOp whileOp = getParentOp();
332 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
333 if (!boolAttr || boolAttr.getValue())
334 regions.emplace_back(&whileOp.getAfter(),
335 whileOp.getAfter().getArguments());
336 if (!boolAttr || !boolAttr.getValue())
337 regions.emplace_back(whileOp.getResults());
346 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
354 for (
Value v : initArgs)
360 for (
Value v : initArgs)
366 if (initArgs.empty() && !bodyBuilder) {
367 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
368 }
else if (bodyBuilder) {
378 if (getInitArgs().size() != getNumResults())
380 "mismatch in number of loop-carried values and defined values");
385 LogicalResult ForOp::verifyRegions() {
390 "expected induction variable to be same type as bounds and step");
392 if (getNumRegionIterArgs() != getNumResults())
394 "mismatch in number of basic block args and defined values");
396 auto initArgs = getInitArgs();
397 auto iterArgs = getRegionIterArgs();
398 auto opResults = getResults();
400 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
402 return emitOpError() <<
"types mismatch between " << i
403 <<
"th iter operand and defined value";
405 return emitOpError() <<
"types mismatch between " << i
406 <<
"th iter region arg and defined value";
413 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
417 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
421 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
425 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
// Returns every result of this scf.for op as the loop's results. The optional
// is always engaged here: the full result range is returned unconditionally.
// NOTE(review): presumably the LoopLikeOpInterface hook (ForOp is cast to
// LoopLikeOpInterface elsewhere in this file) — confirm against the .td.
429 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
434 std::optional<APInt> tripCount = getStaticTripCount();
435 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
438 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
441 if (*tripCount == 0) {
448 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
455 llvm::append_range(bbArgReplacements, getInitArgs());
459 getOperation()->getIterator(), bbArgReplacements);
475 StringRef prefix =
"") {
476 assert(blocksArgs.size() == initializers.size() &&
477 "expected same length of arguments and initializers");
478 if (initializers.empty())
482 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
483 p << std::get<0>(it) <<
" = " << std::get<1>(it);
489 if (getUnsignedCmp())
492 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
496 if (!getInitArgs().empty())
497 p <<
" -> (" << getInitArgs().getTypes() <<
')';
500 p <<
" : " << t <<
' ';
503 !getInitArgs().empty());
505 getUnsignedCmpAttrName().strref());
530 regionArgs.push_back(inductionVariable);
540 if (regionArgs.size() != result.
types.size() + 1)
543 "mismatch in number of loop-carried values and defined values");
552 regionArgs.front().type = type;
553 for (
auto [iterArg, type] :
554 llvm::zip_equal(llvm::drop_begin(regionArgs), result.
types))
561 ForOp::ensureTerminator(*body, builder, result.
location);
570 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
571 operands, result.
types)) {
572 Type type = std::get<2>(argOperandType);
573 std::get<0>(argOperandType).type = type;
590 return getBody()->getArguments().drop_front(getNumInductionVars());
594 return getInitArgsMutable();
597 FailureOr<LoopLikeOpInterface>
598 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
600 bool replaceInitOperandUsesInLoop,
605 auto inits = llvm::to_vector(getInitArgs());
606 inits.append(newInitOperands.begin(), newInitOperands.end());
607 scf::ForOp newLoop = scf::ForOp::create(
613 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
615 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
620 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
621 assert(newInitOperands.size() == newYieldedValues.size() &&
622 "expected as many new yield values as new iter operands");
624 yieldOp.getResultsMutable().append(newYieldedValues);
630 newLoop.getBody()->getArguments().take_front(
631 getBody()->getNumArguments()));
633 if (replaceInitOperandUsesInLoop) {
636 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
647 newLoop->getResults().take_front(getNumResults()));
648 return cast<LoopLikeOpInterface>(newLoop.getOperation());
652 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
655 assert(ivArg.getOwner() &&
"unlinked block argument");
656 auto *containingOp = ivArg.getOwner()->getParentOp();
657 return dyn_cast_or_null<ForOp>(containingOp);
661 return getInitArgs();
678 for (
auto [lb, ub, step] :
679 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
682 if (!tripCount.has_value() || *tripCount != 1)
691 return getBody()->getArguments().drop_front(getRank());
695 return getOutputsMutable();
701 scf::InParallelOp terminator = forallOp.getTerminator();
706 bbArgReplacements.append(forallOp.getOutputs().begin(),
707 forallOp.getOutputs().end());
711 forallOp->getIterator(), bbArgReplacements);
716 results.reserve(forallOp.getResults().size());
717 for (
auto &yieldingOp : terminator.getYieldingOps()) {
718 auto parallelInsertSliceOp =
719 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
720 if (!parallelInsertSliceOp)
723 Value dst = parallelInsertSliceOp.getDest();
724 Value src = parallelInsertSliceOp.getSource();
725 if (llvm::isa<TensorType>(src.
getType())) {
726 results.push_back(tensor::InsertSliceOp::create(
727 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
728 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
729 parallelInsertSliceOp.getStrides(),
730 parallelInsertSliceOp.getStaticOffsets(),
731 parallelInsertSliceOp.getStaticSizes(),
732 parallelInsertSliceOp.getStaticStrides()));
734 llvm_unreachable(
"unsupported terminator");
749 assert(lbs.size() == ubs.size() &&
750 "expected the same number of lower and upper bounds");
751 assert(lbs.size() == steps.size() &&
752 "expected the same number of lower bounds and steps");
757 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
759 assert(results.size() == iterArgs.size() &&
760 "loop nest body must return as many values as loop has iteration "
762 return LoopNest{{}, std::move(results)};
770 loops.reserve(lbs.size());
771 ivs.reserve(lbs.size());
774 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
775 auto loop = scf::ForOp::create(
776 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
782 currentIterArgs = args;
783 currentLoc = nestedLoc;
789 loops.push_back(loop);
793 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
795 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
802 ? bodyBuilder(builder, currentLoc, ivs,
803 loops.back().getRegionIterArgs())
805 assert(results.size() == iterArgs.size() &&
806 "loop nest body must return as many values as loop has iteration "
809 scf::YieldOp::create(builder, loc, results);
813 llvm::append_range(nestResults, loops.front().getResults());
814 return LoopNest{std::move(loops), std::move(nestResults)};
827 bodyBuilder(nestedBuilder, nestedLoc, ivs);
836 assert(operand.
getOwner() == forOp);
841 "expected an iter OpOperand");
843 "Expected a different type");
845 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
847 newIterOperands.push_back(replacement);
850 newIterOperands.push_back(opOperand.get());
854 scf::ForOp newForOp = scf::ForOp::create(
855 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
856 forOp.getStep(), newIterOperands,
nullptr,
857 forOp.getUnsignedCmp());
858 newForOp->setAttrs(forOp->getAttrs());
859 Block &newBlock = newForOp.getRegion().
front();
867 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
869 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
870 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
874 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
877 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
880 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
881 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
882 clonedYieldOp.getOperand(yieldIdx));
884 newYieldOperands[yieldIdx] = castOut;
885 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
886 rewriter.
eraseOp(clonedYieldOp);
891 newResults[yieldIdx] =
892 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
912 LogicalResult matchAndRewrite(scf::ForOp forOp,
914 bool canonicalize =
false;
921 int64_t numResults = forOp.getNumResults();
923 keepMask.reserve(numResults);
926 newBlockTransferArgs.reserve(1 + numResults);
927 newBlockTransferArgs.push_back(
Value());
928 newIterArgs.reserve(forOp.getInitArgs().size());
929 newYieldValues.reserve(numResults);
930 newResultValues.reserve(numResults);
932 for (
auto [init, arg, result, yielded] :
933 llvm::zip(forOp.getInitArgs(),
934 forOp.getRegionIterArgs(),
936 forOp.getYieldedValues()
943 bool forwarded = (arg == yielded) || (init == yielded) ||
944 (arg.use_empty() && result.use_empty());
947 keepMask.push_back(
false);
948 newBlockTransferArgs.push_back(init);
949 newResultValues.push_back(init);
955 if (
auto it = initYieldToArg.find({init, yielded});
956 it != initYieldToArg.end()) {
958 keepMask.push_back(
false);
959 auto [sameArg, sameResult] = it->second;
963 newBlockTransferArgs.push_back(init);
964 newResultValues.push_back(init);
969 initYieldToArg.insert({{init, yielded}, {arg, result}});
970 keepMask.push_back(
true);
971 newIterArgs.push_back(init);
972 newYieldValues.push_back(yielded);
973 newBlockTransferArgs.push_back(
Value());
974 newResultValues.push_back(
Value());
980 scf::ForOp newForOp =
981 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
982 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
983 nullptr, forOp.getUnsignedCmp());
984 newForOp->setAttrs(forOp->getAttrs());
985 Block &newBlock = newForOp.getRegion().
front();
989 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
991 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
992 Value &newResultVal = newResultValues[idx];
993 assert((blockTransferArg && newResultVal) ||
994 (!blockTransferArg && !newResultVal));
995 if (!blockTransferArg) {
996 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
997 newResultVal = newForOp.getResult(collapsedIdx++);
1003 "unexpected argument size mismatch");
1008 if (newIterArgs.empty()) {
1009 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1012 rewriter.
replaceOp(forOp, newResultValues);
1017 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1021 filteredOperands.reserve(newResultValues.size());
1022 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1024 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1025 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1029 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1030 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1031 cloneFilteredTerminator(mergedYieldOp);
1032 rewriter.
eraseOp(mergedYieldOp);
1033 rewriter.
replaceOp(forOp, newResultValues);
1044 LogicalResult matchAndRewrite(ForOp op,
1046 std::optional<APInt> tripCount = op.getStaticTripCount();
1047 if (!tripCount.has_value())
1049 "can't compute constant trip count");
1051 if (tripCount->isZero()) {
1052 LDBG() <<
"SimplifyTrivialLoops tripCount is 0 for loop "
1054 rewriter.
replaceOp(op, op.getInitArgs());
1058 if (tripCount->getSExtValue() == 1) {
1059 LDBG() <<
"SimplifyTrivialLoops tripCount is 1 for loop "
1062 blockArgs.reserve(op.getInitArgs().size() + 1);
1063 blockArgs.push_back(op.getLowerBound());
1064 llvm::append_range(blockArgs, op.getInitArgs());
1071 if (!llvm::hasSingleElement(block))
1075 if (llvm::any_of(op.getYieldedValues(),
1076 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1078 LDBG() <<
"SimplifyTrivialLoops empty body loop allows replacement with "
1079 "yield operands for loop "
1081 rewriter.
replaceOp(op, op.getYieldedValues());
1115 LogicalResult matchAndRewrite(ForOp op,
1117 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1118 OpOperand &iterOpOperand = std::get<0>(it);
1120 if (!incomingCast ||
1121 incomingCast.getSource().getType() == incomingCast.getType())
1126 incomingCast.getDest().getType(),
1127 incomingCast.getSource().getType()))
1129 if (!std::get<1>(it).hasOneUse())
1135 rewriter, op, iterOpOperand, incomingCast.getSource(),
1137 return tensor::CastOp::create(b, loc, type, source);
1149 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1153 std::optional<APInt> ForOp::getConstantStep() {
1156 return step.getValue();
1160 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1161 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1167 if (
auto constantStep = getConstantStep())
1168 if (*constantStep == 1)
1176 std::optional<APInt> ForOp::getStaticTripCount() {
1186 unsigned numLoops = getRank();
1188 if (getNumResults() != getOutputs().size())
1189 return emitOpError(
"produces ")
1190 << getNumResults() <<
" results, but has only "
1191 << getOutputs().size() <<
" outputs";
1194 auto *body = getBody();
1196 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1197 for (int64_t i = 0; i < numLoops; ++i)
1199 return emitOpError(
"expects ")
1200 << i <<
"-th block argument to be an index";
1201 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1203 return emitOpError(
"type mismatch between ")
1204 << i <<
"-th output and corresponding block argument";
1205 if (getMapping().has_value() && !getMapping()->empty()) {
1206 if (getDeviceMappingAttrs().size() != numLoops)
1207 return emitOpError() <<
"mapping attribute size must match op rank";
1208 if (
failed(getDeviceMaskingAttr()))
1210 <<
" supports at most one device masking attribute";
1216 getStaticLowerBound(),
1217 getDynamicLowerBound())))
1220 getStaticUpperBound(),
1221 getDynamicUpperBound())))
1224 getStaticStep(), getDynamicStep())))
1232 p <<
" (" << getInductionVars();
1233 if (isNormalized()) {
1254 if (!getRegionOutArgs().empty())
1255 p <<
"-> (" << getResultTypes() <<
") ";
1256 p.printRegion(getRegion(),
1258 getNumResults() > 0);
1259 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1260 getStaticLowerBoundAttrName(),
1261 getStaticUpperBoundAttrName(),
1262 getStaticStepAttrName()});
1267 auto indexType = b.getIndexType();
1287 unsigned numLoops = ivs.size();
1322 if (outOperands.size() != result.
types.size())
1324 "mismatch between out operands and types");
1334 std::unique_ptr<Region> region = std::make_unique<Region>();
1335 for (
auto &iv : ivs) {
1336 iv.type = b.getIndexType();
1337 regionArgs.push_back(iv);
1340 auto &out = it.value();
1341 out.type = result.
types[it.index()];
1342 regionArgs.push_back(out);
1348 ForallOp::ensureTerminator(*region, b, result.
location);
1360 {static_cast<int32_t>(dynamicLbs.size()),
1361 static_cast<int32_t>(dynamicUbs.size()),
1362 static_cast<int32_t>(dynamicSteps.size()),
1363 static_cast<int32_t>(outOperands.size())}));
1368 void ForallOp::build(
1372 std::optional<ArrayAttr> mapping,
1393 "operandSegmentSizes",
1395 static_cast<int32_t>(dynamicUbs.size()),
1396 static_cast<int32_t>(dynamicSteps.size()),
1397 static_cast<int32_t>(outputs.size())}));
1398 if (mapping.has_value()) {
1417 if (!bodyBuilderFn) {
1418 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1425 void ForallOp::build(
1428 std::optional<ArrayAttr> mapping,
1430 unsigned numLoops = ubs.size();
1433 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1437 bool ForallOp::isNormalized() {
1441 return intValue.has_value() && intValue == val;
1444 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1447 InParallelOp ForallOp::getTerminator() {
1448 return cast<InParallelOp>(getBody()->getTerminator());
1454 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1455 storeOps.push_back(parallelOp);
1465 for (
auto attr : getMapping()->getValue()) {
1466 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1473 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1474 DeviceMaskingAttrInterface res;
1477 for (
auto attr : getMapping()->getValue()) {
1478 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1487 bool ForallOp::usesLinearMapping() {
1491 return ifaces.front().isLinearMapping();
1494 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1499 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1501 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1505 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1507 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1511 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1517 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1520 assert(tidxArg.getOwner() &&
"unlinked block argument");
1521 auto *containingOp = tidxArg.getOwner()->getParentOp();
1522 return dyn_cast<ForallOp>(containingOp);
1530 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1532 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1536 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1539 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1548 LogicalResult matchAndRewrite(ForallOp op,
1563 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1564 op.setStaticLowerBound(staticLowerBound);
1568 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1569 op.setStaticUpperBound(staticUpperBound);
1572 op.getDynamicStepMutable().assign(dynamicStep);
1573 op.setStaticStep(staticStep);
1575 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1577 {static_cast<int32_t>(dynamicLowerBound.size()),
1578 static_cast<int32_t>(dynamicUpperBound.size()),
1579 static_cast<int32_t>(dynamicStep.size()),
1580 static_cast<int32_t>(op.getNumResults())}));
1662 LogicalResult matchAndRewrite(ForallOp forallOp,
1681 for (
OpResult result : forallOp.getResults()) {
1682 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1683 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1684 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1685 resultToDelete.insert(result);
1687 resultToReplace.push_back(result);
1688 newOuts.push_back(opOperand->
get());
1694 if (resultToDelete.empty())
1702 for (
OpResult result : resultToDelete) {
1703 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1704 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1706 forallOp.getCombiningOps(blockArg);
1707 for (
Operation *combiningOp : combiningOps)
1708 rewriter.
eraseOp(combiningOp);
1713 auto newForallOp = scf::ForallOp::create(
1714 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1715 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1716 forallOp.getMapping(),
1721 Block *loopBody = forallOp.getBody();
1722 Block *newLoopBody = newForallOp.getBody();
1727 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1734 for (
OpResult result : forallOp.getResults()) {
1735 if (resultToDelete.count(result)) {
1736 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1738 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1741 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1745 for (
auto &&[oldResult, newResult] :
1746 llvm::zip(resultToReplace, newForallOp->getResults()))
1752 for (
OpResult oldResult : resultToDelete)
1754 forallOp.getTiedOpOperand(oldResult)->get());
1759 struct ForallOpSingleOrZeroIterationDimsFolder
1763 LogicalResult matchAndRewrite(ForallOp op,
1766 if (op.getMapping().has_value() && !op.getMapping()->empty())
1774 for (
auto [lb, ub, step, iv] :
1775 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1776 op.getMixedStep(), op.getInductionVars())) {
1777 auto numIterations =
1779 if (numIterations.has_value()) {
1781 if (*numIterations == 0) {
1782 rewriter.
replaceOp(op, op.getOutputs());
1787 if (*numIterations == 1) {
1792 newMixedLowerBounds.push_back(lb);
1793 newMixedUpperBounds.push_back(ub);
1794 newMixedSteps.push_back(step);
1798 if (newMixedLowerBounds.empty()) {
1804 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1806 op,
"no dimensions have 0 or 1 iterations");
1811 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1812 newMixedUpperBounds, newMixedSteps,
1813 op.getOutputs(), std::nullopt,
nullptr);
1814 newOp.getBodyRegion().getBlocks().clear();
1819 newOp.getStaticLowerBoundAttrName(),
1820 newOp.getStaticUpperBoundAttrName(),
1821 newOp.getStaticStepAttrName()};
1822 for (
const auto &namedAttr : op->getAttrs()) {
1823 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1826 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1830 newOp.getRegion().begin(), mapping);
1831 rewriter.
replaceOp(op, newOp.getResults());
1837 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1840 LogicalResult matchAndRewrite(ForallOp op,
1844 for (
auto [lb, ub, step, iv] :
1845 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1846 op.getMixedStep(), op.getInductionVars())) {
1849 auto numIterations =
1851 if (!numIterations.has_value() || numIterations.value() != 1) {
1862 struct FoldTensorCastOfOutputIntoForallOp
1871 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1873 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1876 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1883 castOp.getSource().getType())) {
1887 tensorCastProducers[en.index()] =
1888 TypeCast{castOp.getSource().getType(), castOp.getType()};
1889 newOutputTensors[en.index()] = castOp.getSource();
1892 if (tensorCastProducers.empty())
1897 auto newForallOp = ForallOp::create(
1898 rewriter, loc, forallOp.getMixedLowerBound(),
1899 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1900 newOutputTensors, forallOp.getMapping(),
1902 auto castBlockArgs =
1903 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1904 for (auto [index, cast] : tensorCastProducers) {
1905 Value &oldTypeBBArg = castBlockArgs[index];
1906 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1907 cast.dstType, oldTypeBBArg);
1912 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1913 ivsBlockArgs.append(castBlockArgs);
1915 bbArgs.front().getParentBlock(), ivsBlockArgs);
1921 auto terminator = newForallOp.getTerminator();
1922 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1923 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1924 if (
auto parallelCombingingOp =
1925 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1926 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1933 for (
auto &item : tensorCastProducers) {
1934 Value &oldTypeResult = castResults[item.first];
1935 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1938 rewriter.
replaceOp(forallOp, castResults);
1947 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1948 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1949 ForallOpSingleOrZeroIterationDimsFolder,
1950 ForallOpReplaceConstantInductionVar>(context);
1981 scf::ForallOp forallOp =
1982 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1984 return this->emitOpError(
"expected forall op parent");
1986 for (
Operation &op : getRegion().front().getOperations()) {
1987 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1988 if (!parallelCombiningOp) {
1989 return this->emitOpError(
"expected only ParallelCombiningOpInterface")
1997 if (!llvm::is_contained(regionOutArgs, dest.get()))
1998 return op.emitOpError(
"may only insert into an output block argument");
2017 std::unique_ptr<Region> region = std::make_unique<Region>();
2021 if (region->empty())
2031 OpResult InParallelOp::getParentResult(int64_t idx) {
2032 return getOperation()->getParentOp()->getResult(idx);
2037 for (
Operation &yieldingOp : getYieldingOps()) {
2038 auto parallelCombiningOp =
2039 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2040 if (!parallelCombiningOp)
2043 parallelCombiningOp.getUpdatedDestinations())
2044 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2046 return updatedDests;
2050 return getRegion().front().getOperations();
2058 assert(a &&
"expected non-empty operation");
2059 assert(b &&
"expected non-empty operation");
2064 if (ifOp->isProperAncestor(b))
2067 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2068 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2070 ifOp = ifOp->getParentOfType<IfOp>();
2078 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2079 IfOp::Adaptor adaptor,
2081 if (adaptor.getRegions().empty())
2083 Region *r = &adaptor.getThenRegion();
2089 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2092 TypeRange types = yieldOp.getOperandTypes();
2093 llvm::append_range(inferredReturnTypes, types);
2099 return build(builder, result, resultTypes, cond,
false,
2105 bool addElseBlock) {
2106 assert((!addElseBlock || addThenBlock) &&
2107 "must not create else block w/o then block");
2122 bool withElseRegion) {
2123 build(builder, result,
TypeRange{}, cond, withElseRegion);
2135 if (resultTypes.empty())
2136 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2140 if (withElseRegion) {
2142 if (resultTypes.empty())
2143 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2150 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2157 thenBuilder(builder, result.
location);
2163 elseBuilder(builder, result.
location);
2170 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2172 inferredReturnTypes))) {
2173 result.
addTypes(inferredReturnTypes);
2178 if (getNumResults() != 0 && getElseRegion().empty())
2179 return emitOpError(
"must have an else block if defining values");
2217 bool printBlockTerminators =
false;
2219 p <<
" " << getCondition();
2220 if (!getResults().empty()) {
2221 p <<
" -> (" << getResultTypes() <<
")";
2223 printBlockTerminators =
true;
2228 printBlockTerminators);
2231 auto &elseRegion = getElseRegion();
2232 if (!elseRegion.
empty()) {
2236 printBlockTerminators);
2253 Region *elseRegion = &this->getElseRegion();
2254 if (elseRegion->
empty())
2262 FoldAdaptor adaptor(operands, *
this);
2263 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2264 if (!boolAttr || boolAttr.getValue())
2265 regions.emplace_back(&getThenRegion());
2268 if (!boolAttr || !boolAttr.getValue()) {
2269 if (!getElseRegion().empty())
2270 regions.emplace_back(&getElseRegion());
2272 regions.emplace_back(getResults());
2276 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2279 if (getElseRegion().empty())
2282 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2289 getConditionMutable().assign(xorStmt.getLhs());
2293 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2294 getElseRegion().getBlocks());
2295 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2296 getThenRegion().getBlocks(), thenBlock);
2300 void IfOp::getRegionInvocationBounds(
2303 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2306 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2307 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2310 invocationBounds.assign(2, {0, 1});
2326 llvm::transform(usedResults, std::back_inserter(usedOperands),
2331 [&]() { yieldOp->setOperands(usedOperands); });
2338 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2339 [](
OpResult result) { return !result.use_empty(); });
2342 if (usedResults.size() == op.getNumResults())
2347 llvm::transform(usedResults, std::back_inserter(newTypes),
2352 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2358 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2359 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2364 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2381 else if (!op.getElseRegion().empty())
2397 if (op->getNumResults() == 0)
2400 auto cond = op.getCondition();
2401 auto thenYieldArgs = op.thenYield().getOperands();
2402 auto elseYieldArgs = op.elseYield().getOperands();
2405 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2406 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2407 &op.getElseRegion() == falseVal.getParentRegion())
2408 nonHoistable.push_back(trueVal.getType());
2412 if (nonHoistable.size() == op->getNumResults())
2415 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2417 if (replacement.thenBlock())
2418 rewriter.
eraseBlock(replacement.thenBlock());
2419 replacement.getThenRegion().takeBody(op.getThenRegion());
2420 replacement.getElseRegion().takeBody(op.getElseRegion());
2423 assert(thenYieldArgs.size() == results.size());
2424 assert(elseYieldArgs.size() == results.size());
2429 for (
const auto &it :
2431 Value trueVal = std::get<0>(it.value());
2432 Value falseVal = std::get<1>(it.value());
2435 results[it.index()] = replacement.getResult(trueYields.size());
2436 trueYields.push_back(trueVal);
2437 falseYields.push_back(falseVal);
2438 }
else if (trueVal == falseVal)
2439 results[it.index()] = trueVal;
2441 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2442 cond, trueVal, falseVal);
2484 Value constantTrue =
nullptr;
2485 Value constantFalse =
nullptr;
2488 llvm::make_early_inc_range(op.getCondition().getUses())) {
2493 constantTrue = rewriter.
create<arith::ConstantOp>(
2497 [&]() { use.
set(constantTrue); });
2498 }
else if (op.getElseRegion().isAncestor(
2503 constantFalse = rewriter.
create<arith::ConstantOp>(
2507 [&]() { use.
set(constantFalse); });
2551 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2557 if (op.getNumResults() == 0)
2561 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2563 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2566 op.getOperation()->getIterator());
2569 for (
auto [trueResult, falseResult, opResult] :
2570 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2572 if (trueResult == falseResult) {
2573 if (!opResult.use_empty()) {
2574 opResult.replaceAllUsesWith(trueResult);
2585 bool trueVal = trueYield.
getValue();
2586 bool falseVal = falseYield.
getValue();
2587 if (!trueVal && falseVal) {
2588 if (!opResult.use_empty()) {
2589 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2590 Value notCond = arith::XOrIOp::create(
2591 rewriter, op.getLoc(), op.getCondition(),
2597 opResult.replaceAllUsesWith(notCond);
2601 if (trueVal && !falseVal) {
2602 if (!opResult.use_empty()) {
2603 opResult.replaceAllUsesWith(op.getCondition());
2638 Block *parent = nextIf->getBlock();
2639 if (nextIf == &parent->
front())
2642 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2650 Block *nextThen =
nullptr;
2651 Block *nextElse =
nullptr;
2652 if (nextIf.getCondition() == prevIf.getCondition()) {
2653 nextThen = nextIf.thenBlock();
2654 if (!nextIf.getElseRegion().empty())
2655 nextElse = nextIf.elseBlock();
2657 if (arith::XOrIOp notv =
2658 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2659 if (notv.getLhs() == prevIf.getCondition() &&
2661 nextElse = nextIf.thenBlock();
2662 if (!nextIf.getElseRegion().empty())
2663 nextThen = nextIf.elseBlock();
2666 if (arith::XOrIOp notv =
2667 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2668 if (notv.getLhs() == nextIf.getCondition() &&
2670 nextElse = nextIf.thenBlock();
2671 if (!nextIf.getElseRegion().empty())
2672 nextThen = nextIf.elseBlock();
2676 if (!nextThen && !nextElse)
2680 if (!prevIf.getElseRegion().empty())
2681 prevElseYielded = prevIf.elseYield().getOperands();
2684 for (
auto it : llvm::zip(prevIf.getResults(),
2685 prevIf.thenYield().getOperands(), prevElseYielded))
2687 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2691 use.
set(std::get<1>(it));
2696 use.
set(std::get<2>(it));
2702 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2704 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2705 prevIf.getCondition(),
false);
2706 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2709 combinedIf.getThenRegion(),
2710 combinedIf.getThenRegion().begin());
2713 YieldOp thenYield = combinedIf.thenYield();
2714 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2715 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2719 llvm::append_range(mergedYields, thenYield2.getOperands());
2720 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2726 combinedIf.getElseRegion(),
2727 combinedIf.getElseRegion().begin());
2730 if (combinedIf.getElseRegion().empty()) {
2732 combinedIf.getElseRegion(),
2733 combinedIf.getElseRegion().
begin());
2735 YieldOp elseYield = combinedIf.elseYield();
2736 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2737 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2742 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2744 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2753 if (pair.index() < prevIf.getNumResults())
2754 prevValues.push_back(pair.value());
2756 nextValues.push_back(pair.value());
2771 if (ifOp.getNumResults())
2773 Block *elseBlock = ifOp.elseBlock();
2774 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2778 newIfOp.getThenRegion().begin());
2805 auto nestedOps = op.thenBlock()->without_terminator();
2807 if (!llvm::hasSingleElement(nestedOps))
2811 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2814 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2818 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2824 llvm::append_range(elseYield, op.elseYield().getOperands());
2838 if (tup.value().getDefiningOp() == nestedIf) {
2839 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2840 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2841 elseYield[tup.index()]) {
2846 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2859 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2862 elseYieldsToUpgradeToSelect.push_back(tup.index());
2866 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2867 nestedIf.getCondition());
2868 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2872 llvm::append_range(results, newIf.getResults());
2875 for (
auto idx : elseYieldsToUpgradeToSelect)
2877 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2878 thenYield[idx], elseYield[idx]);
2880 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2883 if (!elseYield.empty()) {
2886 YieldOp::create(rewriter, loc, elseYield);
2897 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2898 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2899 RemoveStaticCondition, RemoveUnusedResults,
2900 ReplaceIfYieldWithConditionOrValue>(context);
2903 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2904 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2905 Block *IfOp::elseBlock() {
2906 Region &r = getElseRegion();
2911 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2917 void ParallelOp::build(
2927 ParallelOp::getOperandSegmentSizeAttr(),
2929 static_cast<int32_t>(upperBounds.size()),
2930 static_cast<int32_t>(steps.size()),
2931 static_cast<int32_t>(initVals.size())}));
2935 unsigned numIVs = steps.size();
2941 if (bodyBuilderFn) {
2943 bodyBuilderFn(builder, result.
location,
2948 if (initVals.empty())
2949 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2952 void ParallelOp::build(
2959 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2962 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2966 wrapper = wrappedBuilderFn;
2968 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2977 if (stepValues.empty())
2979 "needs at least one tuple element for lowerBound, upperBound and step");
2982 for (
Value stepValue : stepValues)
2985 return emitOpError(
"constant step operand must be positive");
2989 Block *body = getBody();
2991 return emitOpError() <<
"expects the same number of induction variables: "
2993 <<
" as bound and step values: " << stepValues.size();
2995 if (!arg.getType().isIndex())
2997 "expects arguments for the induction variable to be of index type");
3000 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
3001 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
3006 auto resultsSize = getResults().size();
3007 auto reductionsSize = reduceOp.getReductions().size();
3008 auto initValsSize = getInitVals().size();
3009 if (resultsSize != reductionsSize)
3010 return emitOpError() <<
"expects number of results: " << resultsSize
3011 <<
" to be the same as number of reductions: "
3013 if (resultsSize != initValsSize)
3014 return emitOpError() <<
"expects number of results: " << resultsSize
3015 <<
" to be the same as number of initial values: "
3019 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3020 auto resultType = getOperation()->getResult(i).getType();
3021 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3022 if (resultType != reductionOperandType)
3023 return reduceOp.emitOpError()
3024 <<
"expects type of " << i
3025 <<
"-th reduction operand: " << reductionOperandType
3026 <<
" to be the same as the " << i
3027 <<
"-th result type: " << resultType;
3043 OpAsmParser::Delimiter::Paren) ||
3050 OpAsmParser::Delimiter::Paren) ||
3058 OpAsmParser::Delimiter::Paren) ||
3075 for (
auto &iv : ivs)
3082 ParallelOp::getOperandSegmentSizeAttr(),
3084 static_cast<int32_t>(upper.size()),
3085 static_cast<int32_t>(steps.size()),
3086 static_cast<int32_t>(initVals.size())}));
3095 ParallelOp::ensureTerminator(*body, builder, result.
location);
3100 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3101 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3102 if (!getInitVals().empty())
3103 p <<
" init (" << getInitVals() <<
")";
3108 (*this)->getAttrs(),
3109 ParallelOp::getOperandSegmentSizeAttr());
3114 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3118 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3122 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3126 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3131 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3133 return ParallelOp();
3134 assert(ivArg.getOwner() &&
"unlinked block argument");
3135 auto *containingOp = ivArg.getOwner()->getParentOp();
3136 return dyn_cast<ParallelOp>(containingOp);
3141 struct ParallelOpSingleOrZeroIterationDimsFolder
3152 for (
auto [lb, ub, step, iv] :
3153 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3154 op.getInductionVars())) {
3155 auto numIterations =
3157 if (numIterations.has_value()) {
3159 if (*numIterations == 0) {
3160 rewriter.
replaceOp(op, op.getInitVals());
3165 if (*numIterations == 1) {
3170 newLowerBounds.push_back(lb);
3171 newUpperBounds.push_back(ub);
3172 newSteps.push_back(step);
3175 if (newLowerBounds.size() == op.getLowerBound().size())
3178 if (newLowerBounds.empty()) {
3182 results.reserve(op.getInitVals().size());
3183 for (
auto &bodyOp : op.getBody()->without_terminator())
3184 rewriter.
clone(bodyOp, mapping);
3185 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3186 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3187 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3188 auto initValIndex = results.size();
3189 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3193 rewriter.
clone(reduceBodyOp, mapping);
3196 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3197 results.push_back(result);
3205 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3206 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3212 newOp.getRegion().begin(), mapping);
3213 rewriter.
replaceOp(op, newOp.getResults());
3223 Block &outerBody = *op.getBody();
3227 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3232 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3233 llvm::is_contained(innerOp.getUpperBound(), val) ||
3234 llvm::is_contained(innerOp.getStep(), val))
3238 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3243 Block &innerBody = *innerOp.getBody();
3244 assert(iterVals.size() ==
3252 builder.
clone(op, mapping);
3255 auto concatValues = [](
const auto &first,
const auto &second) {
3257 ret.reserve(first.size() + second.size());
3258 ret.assign(first.begin(), first.end());
3259 ret.append(second.begin(), second.end());
3263 auto newLowerBounds =
3264 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3265 auto newUpperBounds =
3266 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3267 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3281 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3290 void ParallelOp::getSuccessorRegions(
3308 for (
Value v : operands) {
3317 LogicalResult ReduceOp::verifyRegions() {
3320 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3321 auto type = getOperands()[i].getType();
3324 return emitOpError() << i <<
"-th reduction has an empty body";
3327 return arg.getType() != type;
3329 return emitOpError() <<
"expected two block arguments with type " << type
3330 <<
" in the " << i <<
"-th reduction region";
3334 return emitOpError(
"reduction bodies must be terminated with an "
3335 "'scf.reduce.return' op");
3354 Block *reductionBody = getOperation()->getBlock();
3356 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3358 if (expectedResultType != getResult().
getType())
3359 return emitOpError() <<
"must have type " << expectedResultType
3360 <<
" (the type of the reduction inputs)";
3370 ValueRange inits, BodyBuilderFn beforeBuilder,
3371 BodyBuilderFn afterBuilder) {
3379 beforeArgLocs.reserve(inits.size());
3380 for (
Value operand : inits) {
3381 beforeArgLocs.push_back(operand.getLoc());
3386 inits.getTypes(), beforeArgLocs);
3395 resultTypes, afterArgLocs);
3401 ConditionOp WhileOp::getConditionOp() {
3402 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3405 YieldOp WhileOp::getYieldOp() {
3406 return cast<YieldOp>(getAfterBody()->getTerminator());
3409 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3410 return getYieldOp().getResultsMutable();
3414 return getBeforeBody()->getArguments();
3418 return getAfterBody()->getArguments();
3422 return getBeforeArguments();
3426 assert(point == getBefore() &&
3427 "WhileOp is expected to branch only to the first region");
3435 regions.emplace_back(&getBefore(), getBefore().getArguments());
3439 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3440 "there are only two regions in a WhileOp");
3442 if (point == getAfter()) {
3443 regions.emplace_back(&getBefore(), getBefore().getArguments());
3447 regions.emplace_back(getResults());
3448 regions.emplace_back(&getAfter(), getAfter().getArguments());
3452 return {&getBefore(), &getAfter()};
3473 FunctionType functionType;
3478 result.
addTypes(functionType.getResults());
3480 if (functionType.getNumInputs() != operands.size()) {
3482 <<
"expected as many input types as operands " <<
"(expected "
3483 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3493 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3494 regionArgs[i].type = functionType.getInput(i);
3496 return failure(parser.
parseRegion(*before, regionArgs) ||
3516 template <
typename OpTy>
3519 if (left.size() != right.size())
3520 return op.emitOpError(
"expects the same number of ") << message;
3522 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3523 if (left[i] != right[i]) {
3526 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3527 <<
" and " << right[i];
3536 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3538 "expects the 'before' region to terminate with 'scf.condition'");
3539 if (!beforeTerminator)
3542 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3544 "expects the 'after' region to terminate with 'scf.yield'");
3545 return success(afterTerminator !=
nullptr);
3573 auto term = op.getConditionOp();
3577 Value constantTrue =
nullptr;
3579 bool replaced =
false;
3580 for (
auto yieldedAndBlockArgs :
3581 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3582 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3583 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3585 constantTrue = arith::ConstantOp::create(
3586 rewriter, op.getLoc(), term.getCondition().getType(),
3595 return success(replaced);
3647 struct RemoveLoopInvariantArgsFromBeforeBlock
3653 Block &afterBlock = *op.getAfterBody();
3655 ConditionOp condOp = op.getConditionOp();
3660 bool canSimplify =
false;
3661 for (
const auto &it :
3663 auto index =
static_cast<unsigned>(it.index());
3664 auto [initVal, yieldOpArg] = it.value();
3667 if (yieldOpArg == initVal) {
3676 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3677 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3678 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3679 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3692 for (
const auto &it :
3694 auto index =
static_cast<unsigned>(it.index());
3695 auto [initVal, yieldOpArg] = it.value();
3699 if (yieldOpArg == initVal) {
3700 beforeBlockInitValMap.insert({index, initVal});
3708 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3709 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3710 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3711 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3712 beforeBlockInitValMap.insert({index, initVal});
3717 newInitArgs.emplace_back(initVal);
3718 newYieldOpArgs.emplace_back(yieldOpArg);
3719 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3728 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3732 &newWhile.getBefore(), {},
3733 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3735 Block &beforeBlock = *op.getBeforeBody();
3742 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3745 if (beforeBlockInitValMap.count(i) != 0)
3746 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3748 newBeforeBlockArgs[i] = newBeforeBlock.
getArgument(
j++);
3751 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3753 newWhile.getAfter().begin());
3755 rewriter.
replaceOp(op, newWhile.getResults());
3800 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3805 Block &beforeBlock = *op.getBeforeBody();
3806 ConditionOp condOp = op.getConditionOp();
3809 bool canSimplify =
false;
3810 for (
Value condOpArg : condOpArgs) {
3830 auto index =
static_cast<unsigned>(it.index());
3831 Value condOpArg = it.value();
3836 condOpInitValMap.insert({index, condOpArg});
3838 newCondOpArgs.emplace_back(condOpArg);
3839 newAfterBlockType.emplace_back(condOpArg.
getType());
3840 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3851 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3854 Block &newAfterBlock =
3856 newAfterBlockType, newAfterBlockArgLocs);
3858 Block &afterBlock = *op.getAfterBody();
3865 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3866 Value afterBlockArg, result;
3869 if (condOpInitValMap.count(i) != 0) {
3870 afterBlockArg = condOpInitValMap[i];
3871 result = afterBlockArg;
3874 result = newWhile.getResult(
j);
3877 newAfterBlockArgs[i] = afterBlockArg;
3878 newWhileResults[i] = result;
3881 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3883 newWhile.getBefore().begin());
3885 rewriter.
replaceOp(op, newWhileResults);
3921 auto term = op.getConditionOp();
3922 auto afterArgs = op.getAfterArguments();
3923 auto termArgs = term.getArgs();
3930 bool needUpdate =
false;
3931 for (
const auto &it :
3933 auto i =
static_cast<unsigned>(it.index());
3934 Value result = std::get<0>(it.value());
3935 Value afterArg = std::get<1>(it.value());
3936 Value termArg = std::get<2>(it.value());
3940 newResultsIndices.emplace_back(i);
3941 newTermArgs.emplace_back(termArg);
3942 newResultTypes.emplace_back(result.
getType());
3943 newArgLocs.emplace_back(result.
getLoc());
3958 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
3961 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3968 newResults[it.value()] = newWhile.getResult(it.index());
3969 newAfterBlockArgs[it.value()] = newAfterBlock.
getArgument(it.index());
3973 newWhile.getBefore().begin());
3975 Block &afterBlock = *op.getAfterBody();
3976 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4010 using namespace scf;
4011 auto cond = op.getConditionOp();
4012 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4016 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4017 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
4018 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4021 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4022 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4026 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4029 if (cmp2.getPredicate() == cmp.getPredicate())
4030 samePredicate =
true;
4031 else if (cmp2.getPredicate() ==
4033 samePredicate =
false;
4051 LogicalResult matchAndRewrite(WhileOp op,
4054 if (!llvm::any_of(op.getBeforeArguments(),
4055 [](
Value arg) { return arg.use_empty(); }))
4058 YieldOp yield = op.getYieldOp();
4063 llvm::BitVector argsToErase;
4065 size_t argsCount = op.getBeforeArguments().size();
4066 newYields.reserve(argsCount);
4067 newInits.reserve(argsCount);
4068 argsToErase.reserve(argsCount);
4069 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4070 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4071 if (beforeArg.use_empty()) {
4072 argsToErase.push_back(
true);
4074 argsToErase.push_back(
false);
4075 newYields.emplace_back(yieldValue);
4076 newInits.emplace_back(initValue);
4080 Block &beforeBlock = *op.getBeforeBody();
4081 Block &afterBlock = *op.getAfterBody();
4087 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4089 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4090 Block &newAfterBlock = *newWhileOp.getAfterBody();
4096 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4101 rewriter.
replaceOp(op, newWhileOp.getResults());
4108 using OpRewritePattern::OpRewritePattern;
4110 LogicalResult matchAndRewrite(WhileOp op,
4112 ConditionOp condOp = op.getConditionOp();
4117 if (argsSet.size() == condOpArgs.size())
4120 llvm::SmallDenseMap<Value, unsigned> argsMap;
4122 argsMap.reserve(condOpArgs.size());
4123 newArgs.reserve(condOpArgs.size());
4124 for (
Value arg : condOpArgs) {
4125 if (!argsMap.count(arg)) {
4126 auto pos =
static_cast<unsigned>(argsMap.size());
4127 argsMap.insert({arg, pos});
4128 newArgs.emplace_back(arg);
4136 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4139 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4140 Block &newAfterBlock = *newWhileOp.getAfterBody();
4145 auto it = argsMap.find(arg);
4146 assert(it != argsMap.end());
4147 auto pos = it->second;
4148 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4149 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4157 Block &beforeBlock = *op.getBeforeBody();
4158 Block &afterBlock = *op.getAfterBody();
4160 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4162 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4170 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4172 if (args1.size() != args2.size())
4173 return std::nullopt;
4177 auto it = llvm::find(args2, arg1);
4178 if (it == args2.end())
4179 return std::nullopt;
4181 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4188 llvm::SmallDenseSet<Value> set;
4189 for (
Value arg : args) {
4190 if (!set.insert(arg).second)
4201 using OpRewritePattern::OpRewritePattern;
4203 LogicalResult matchAndRewrite(WhileOp loop,
4205 auto oldBefore = loop.getBeforeBody();
4206 ConditionOp oldTerm = loop.getConditionOp();
4207 ValueRange beforeArgs = oldBefore->getArguments();
4209 if (beforeArgs == termArgs)
4212 if (hasDuplicates(termArgs))
4215 auto mapping = getArgsMapping(beforeArgs, termArgs);
4226 auto oldAfter = loop.getAfterBody();
4230 newResultTypes[
j] = loop.getResult(i).getType();
4232 auto newLoop = WhileOp::create(
4233 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4235 auto newBefore = newLoop.getBeforeBody();
4236 auto newAfter = newLoop.getAfterBody();
4241 newResults[i] = newLoop.getResult(
j);
4242 newAfterArgs[i] = newAfter->getArgument(
j);
4246 newBefore->getArguments());
4258 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4259 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4260 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4261 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4275 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4278 caseValues.push_back(value);
4287 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4289 p <<
"case " << value <<
' ';
4295 if (getCases().size() != getCaseRegions().size()) {
4296 return emitOpError(
"has ")
4297 << getCaseRegions().size() <<
" case regions but "
4298 << getCases().size() <<
" case values";
4302 for (int64_t value : getCases())
4303 if (!valueSet.insert(value).second)
4304 return emitOpError(
"has duplicate case value: ") << value;
4306 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4308 return emitOpError(
"expected region to end with scf.yield, but got ")
4311 if (yield.getNumOperands() != getNumResults()) {
4312 return (emitOpError(
"expected each region to return ")
4313 << getNumResults() <<
" values, but " << name <<
" returns "
4314 << yield.getNumOperands())
4315 .attachNote(yield.getLoc())
4316 <<
"see yield operation here";
4318 for (
auto [idx, result, operand] :
4321 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
4322 if (result == operand.getType())
4324 return (emitOpError(
"expected result #")
4325 << idx <<
" of each region to be " << result)
4326 .attachNote(yield.getLoc())
4327 << name <<
" returns " << operand.getType() <<
" here";
4341 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4343 Block &scf::IndexSwitchOp::getDefaultBlock() {
4344 return getDefaultRegion().
front();
4347 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4348 assert(idx < getNumCases() &&
"case index out-of-bounds");
4349 return getCaseRegions()[idx].front();
4352 void IndexSwitchOp::getSuccessorRegions(
4356 successors.emplace_back(getResults());
4360 llvm::append_range(successors, getRegions());
4363 void IndexSwitchOp::getEntrySuccessorRegions(
4366 FoldAdaptor adaptor(operands, *
this);
4369 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4371 llvm::append_range(successors, getRegions());
4377 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4378 if (caseValue == arg.getInt()) {
4379 successors.emplace_back(&caseRegion);
4383 successors.emplace_back(&getDefaultRegion());
4386 void IndexSwitchOp::getRegionInvocationBounds(
4388 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4389 if (!operandValue) {
4395 unsigned liveIndex = getNumRegions() - 1;
4396 const auto *it = llvm::find(getCases(), operandValue.getInt());
4397 if (it != getCases().end())
4398 liveIndex = std::distance(getCases().begin(), it);
4399 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4400 bounds.emplace_back(0, i == liveIndex);
4411 if (!maybeCst.has_value())
4413 int64_t cst = *maybeCst;
4414 int64_t caseIdx, e = op.getNumCases();
4415 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4416 if (cst == op.getCases()[caseIdx])
4420 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4421 : op.getDefaultRegion();
4445 #define GET_OP_CLASSES
4446 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of a AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.