26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
77 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78 InParallelOp, ReduceReturnOp>();
79 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81 ForallOp, InParallelOp, WhileOp, YieldOp>();
82 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
87 scf::YieldOp::create(builder, loc);
92 template <
typename TerminatorTy>
94 StringRef errorMessage) {
// Take the address of the last operation in the region's first block; the
// dyn_cast_or_null below checks whether it is the expected terminator type.
97 terminatorOperation = &region.front().back();
98 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
102 if (terminatorOperation)
103 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
115 assert(region.
hasOneBlock() &&
"expected single-block region");
164 if (getRegion().empty())
165 return emitOpError(
"region needs to have at least one block");
166 if (getRegion().front().getNumArguments() > 0)
167 return emitOpError(
"region cannot have any arguments");
190 if (!op.getRegion().hasOneBlock() || op.getNoInline())
239 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
242 Block *prevBlock = op->getBlock();
246 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
248 for (
Block &blk : op.getRegion()) {
249 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
251 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
252 yieldOp.getResults());
260 for (
auto res : op.getResults())
261 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
273 void ExecuteRegionOp::getSuccessorRegions(
291 assert((point.
isParent() || point == getParentOp().getAfter()) &&
292 "condition op can only exit the loop or branch to the after"
295 return getArgsMutable();
298 void ConditionOp::getSuccessorRegions(
300 FoldAdaptor adaptor(operands, *
this);
302 WhileOp whileOp = getParentOp();
306 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
307 if (!boolAttr || boolAttr.getValue())
308 regions.emplace_back(&whileOp.getAfter(),
309 whileOp.getAfter().getArguments());
310 if (!boolAttr || !boolAttr.getValue())
311 regions.emplace_back(whileOp.getResults());
320 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
328 for (
Value v : initArgs)
334 for (
Value v : initArgs)
340 if (initArgs.empty() && !bodyBuilder) {
341 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
342 }
else if (bodyBuilder) {
352 if (getInitArgs().size() != getNumResults())
354 "mismatch in number of loop-carried values and defined values");
359 LogicalResult ForOp::verifyRegions() {
364 "expected induction variable to be same type as bounds and step");
366 if (getNumRegionIterArgs() != getNumResults())
368 "mismatch in number of basic block args and defined values");
370 auto initArgs = getInitArgs();
371 auto iterArgs = getRegionIterArgs();
372 auto opResults = getResults();
374 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
376 return emitOpError() <<
"types mismatch between " << i
377 <<
"th iter operand and defined value";
379 return emitOpError() <<
"types mismatch between " << i
380 <<
"th iter region arg and defined value";
387 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
391 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
395 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
399 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
/// Returns the results of this ForOp. The optional is always engaged here:
/// the function unconditionally returns `getResults()` and never
/// `std::nullopt`.
403 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
408 std::optional<int64_t> tripCount =
410 if (!tripCount.has_value() || tripCount != 1)
414 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
421 llvm::append_range(bbArgReplacements, getInitArgs());
425 getOperation()->getIterator(), bbArgReplacements);
441 StringRef prefix =
"") {
442 assert(blocksArgs.size() == initializers.size() &&
443 "expected same length of arguments and initializers");
444 if (initializers.empty())
448 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
449 p << std::get<0>(it) <<
" = " << std::get<1>(it);
455 if (getUnsignedCmp())
458 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
462 if (!getInitArgs().empty())
463 p <<
" -> (" << getInitArgs().getTypes() <<
')';
466 p <<
" : " << t <<
' ';
469 !getInitArgs().empty());
471 getUnsignedCmpAttrName().strref());
496 regionArgs.push_back(inductionVariable);
506 if (regionArgs.size() != result.
types.size() + 1)
509 "mismatch in number of loop-carried values and defined values");
518 regionArgs.front().type = type;
519 for (
auto [iterArg, type] :
520 llvm::zip_equal(llvm::drop_begin(regionArgs), result.
types))
527 ForOp::ensureTerminator(*body, builder, result.
location);
536 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
537 operands, result.
types)) {
538 Type type = std::get<2>(argOperandType);
539 std::get<0>(argOperandType).type = type;
556 return getBody()->getArguments().drop_front(getNumInductionVars());
560 return getInitArgsMutable();
563 FailureOr<LoopLikeOpInterface>
564 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
566 bool replaceInitOperandUsesInLoop,
571 auto inits = llvm::to_vector(getInitArgs());
572 inits.append(newInitOperands.begin(), newInitOperands.end());
573 scf::ForOp newLoop = scf::ForOp::create(
579 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
581 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
586 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
587 assert(newInitOperands.size() == newYieldedValues.size() &&
588 "expected as many new yield values as new iter operands");
590 yieldOp.getResultsMutable().append(newYieldedValues);
596 newLoop.getBody()->getArguments().take_front(
597 getBody()->getNumArguments()));
599 if (replaceInitOperandUsesInLoop) {
602 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
613 newLoop->getResults().take_front(getNumResults()));
614 return cast<LoopLikeOpInterface>(newLoop.getOperation());
618 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
621 assert(ivArg.getOwner() &&
"unlinked block argument");
622 auto *containingOp = ivArg.getOwner()->getParentOp();
623 return dyn_cast_or_null<ForOp>(containingOp);
627 return getInitArgs();
644 for (
auto [lb, ub, step] :
645 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
647 if (!tripCount.has_value() || *tripCount != 1)
656 return getBody()->getArguments().drop_front(getRank());
660 return getOutputsMutable();
666 scf::InParallelOp terminator = forallOp.getTerminator();
671 bbArgReplacements.append(forallOp.getOutputs().begin(),
672 forallOp.getOutputs().end());
676 forallOp->getIterator(), bbArgReplacements);
681 results.reserve(forallOp.getResults().size());
682 for (
auto &yieldingOp : terminator.getYieldingOps()) {
683 auto parallelInsertSliceOp =
684 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
686 Value dst = parallelInsertSliceOp.getDest();
687 Value src = parallelInsertSliceOp.getSource();
688 if (llvm::isa<TensorType>(src.
getType())) {
689 results.push_back(tensor::InsertSliceOp::create(
690 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
691 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
692 parallelInsertSliceOp.getStrides(),
693 parallelInsertSliceOp.getStaticOffsets(),
694 parallelInsertSliceOp.getStaticSizes(),
695 parallelInsertSliceOp.getStaticStrides()));
697 llvm_unreachable(
"unsupported terminator");
712 assert(lbs.size() == ubs.size() &&
713 "expected the same number of lower and upper bounds");
714 assert(lbs.size() == steps.size() &&
715 "expected the same number of lower bounds and steps");
720 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
722 assert(results.size() == iterArgs.size() &&
723 "loop nest body must return as many values as loop has iteration "
725 return LoopNest{{}, std::move(results)};
733 loops.reserve(lbs.size());
734 ivs.reserve(lbs.size());
737 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
738 auto loop = scf::ForOp::create(
739 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
745 currentIterArgs = args;
746 currentLoc = nestedLoc;
752 loops.push_back(loop);
756 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
758 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
765 ? bodyBuilder(builder, currentLoc, ivs,
766 loops.back().getRegionIterArgs())
768 assert(results.size() == iterArgs.size() &&
769 "loop nest body must return as many values as loop has iteration "
772 scf::YieldOp::create(builder, loc, results);
776 llvm::append_range(nestResults, loops.front().getResults());
777 return LoopNest{std::move(loops), std::move(nestResults)};
790 bodyBuilder(nestedBuilder, nestedLoc, ivs);
799 assert(operand.
getOwner() == forOp);
804 "expected an iter OpOperand");
806 "Expected a different type");
808 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
810 newIterOperands.push_back(replacement);
813 newIterOperands.push_back(opOperand.get());
817 scf::ForOp newForOp = scf::ForOp::create(
818 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
819 forOp.getStep(), newIterOperands,
nullptr,
820 forOp.getUnsignedCmp());
821 newForOp->setAttrs(forOp->getAttrs());
822 Block &newBlock = newForOp.getRegion().
front();
830 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
832 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
833 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
837 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
840 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
843 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
844 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
845 clonedYieldOp.getOperand(yieldIdx));
847 newYieldOperands[yieldIdx] = castOut;
848 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
849 rewriter.
eraseOp(clonedYieldOp);
854 newResults[yieldIdx] =
855 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
875 LogicalResult matchAndRewrite(scf::ForOp forOp,
877 bool canonicalize =
false;
884 int64_t numResults = forOp.getNumResults();
886 keepMask.reserve(numResults);
889 newBlockTransferArgs.reserve(1 + numResults);
890 newBlockTransferArgs.push_back(
Value());
891 newIterArgs.reserve(forOp.getInitArgs().size());
892 newYieldValues.reserve(numResults);
893 newResultValues.reserve(numResults);
895 for (
auto [init, arg, result, yielded] :
896 llvm::zip(forOp.getInitArgs(),
897 forOp.getRegionIterArgs(),
899 forOp.getYieldedValues()
906 bool forwarded = (arg == yielded) || (init == yielded) ||
907 (arg.use_empty() && result.use_empty());
910 keepMask.push_back(
false);
911 newBlockTransferArgs.push_back(init);
912 newResultValues.push_back(init);
918 if (
auto it = initYieldToArg.find({init, yielded});
919 it != initYieldToArg.end()) {
921 keepMask.push_back(
false);
922 auto [sameArg, sameResult] = it->second;
926 newBlockTransferArgs.push_back(init);
927 newResultValues.push_back(init);
932 initYieldToArg.insert({{init, yielded}, {arg, result}});
933 keepMask.push_back(
true);
934 newIterArgs.push_back(init);
935 newYieldValues.push_back(yielded);
936 newBlockTransferArgs.push_back(
Value());
937 newResultValues.push_back(
Value());
943 scf::ForOp newForOp =
944 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
945 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
946 nullptr, forOp.getUnsignedCmp());
947 newForOp->setAttrs(forOp->getAttrs());
948 Block &newBlock = newForOp.getRegion().
front();
952 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
954 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
955 Value &newResultVal = newResultValues[idx];
956 assert((blockTransferArg && newResultVal) ||
957 (!blockTransferArg && !newResultVal));
958 if (!blockTransferArg) {
959 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
960 newResultVal = newForOp.getResult(collapsedIdx++);
966 "unexpected argument size mismatch");
971 if (newIterArgs.empty()) {
972 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
975 rewriter.
replaceOp(forOp, newResultValues);
980 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
984 filteredOperands.reserve(newResultValues.size());
985 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
987 filteredOperands.push_back(mergedTerminator.getOperand(idx));
988 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
992 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
993 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
994 cloneFilteredTerminator(mergedYieldOp);
995 rewriter.
eraseOp(mergedYieldOp);
996 rewriter.
replaceOp(forOp, newResultValues);
1004 static std::optional<APInt> computeConstDiff(
Value l,
Value u) {
1005 IntegerAttr clb, cub;
1007 llvm::APInt lbValue = clb.getValue();
1008 llvm::APInt ubValue = cub.getValue();
1009 return ubValue - lbValue;
1019 return std::nullopt;
1028 LogicalResult matchAndRewrite(ForOp op,
1032 if (op.getLowerBound() == op.getUpperBound()) {
1033 rewriter.
replaceOp(op, op.getInitArgs());
1037 std::optional<APInt> diff =
1038 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1043 bool zeroOrLessIterations =
1044 diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
1045 if (zeroOrLessIterations) {
1046 rewriter.
replaceOp(op, op.getInitArgs());
1050 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1051 if (!maybeStepValue)
1056 llvm::APInt stepValue = *maybeStepValue;
1057 if (stepValue.sge(*diff)) {
1059 blockArgs.reserve(op.getInitArgs().size() + 1);
1060 blockArgs.push_back(op.getLowerBound());
1061 llvm::append_range(blockArgs, op.getInitArgs());
1068 if (!llvm::hasSingleElement(block))
1072 if (llvm::any_of(op.getYieldedValues(),
1073 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1075 rewriter.
replaceOp(op, op.getYieldedValues());
1109 LogicalResult matchAndRewrite(ForOp op,
1111 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1112 OpOperand &iterOpOperand = std::get<0>(it);
1114 if (!incomingCast ||
1115 incomingCast.getSource().getType() == incomingCast.getType())
1120 incomingCast.getDest().getType(),
1121 incomingCast.getSource().getType()))
1123 if (!std::get<1>(it).hasOneUse())
1129 rewriter, op, iterOpOperand, incomingCast.getSource(),
1131 return tensor::CastOp::create(b, loc, type, source);
1143 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1147 std::optional<APInt> ForOp::getConstantStep() {
1150 return step.getValue();
1154 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1155 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1161 if (
auto constantStep = getConstantStep())
1162 if (*constantStep == 1)
1175 unsigned numLoops = getRank();
1177 if (getNumResults() != getOutputs().size())
1178 return emitOpError(
"produces ")
1179 << getNumResults() <<
" results, but has only "
1180 << getOutputs().size() <<
" outputs";
1183 auto *body = getBody();
1185 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1186 for (int64_t i = 0; i < numLoops; ++i)
1188 return emitOpError(
"expects ")
1189 << i <<
"-th block argument to be an index";
1190 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1192 return emitOpError(
"type mismatch between ")
1193 << i <<
"-th output and corresponding block argument";
1194 if (getMapping().has_value() && !getMapping()->empty()) {
1195 if (getDeviceMappingAttrs().size() != numLoops)
1196 return emitOpError() <<
"mapping attribute size must match op rank";
1197 if (
failed(getDeviceMaskingAttr()))
1199 <<
" supports at most one device masking attribute";
1205 getStaticLowerBound(),
1206 getDynamicLowerBound())))
1209 getStaticUpperBound(),
1210 getDynamicUpperBound())))
1213 getStaticStep(), getDynamicStep())))
1221 p <<
" (" << getInductionVars();
1222 if (isNormalized()) {
1243 if (!getRegionOutArgs().empty())
1244 p <<
"-> (" << getResultTypes() <<
") ";
1245 p.printRegion(getRegion(),
1247 getNumResults() > 0);
1248 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1249 getStaticLowerBoundAttrName(),
1250 getStaticUpperBoundAttrName(),
1251 getStaticStepAttrName()});
1256 auto indexType = b.getIndexType();
1276 unsigned numLoops = ivs.size();
1311 if (outOperands.size() != result.
types.size())
1313 "mismatch between out operands and types");
1323 std::unique_ptr<Region> region = std::make_unique<Region>();
1324 for (
auto &iv : ivs) {
1325 iv.type = b.getIndexType();
1326 regionArgs.push_back(iv);
1329 auto &out = it.value();
1330 out.type = result.
types[it.index()];
1331 regionArgs.push_back(out);
1337 ForallOp::ensureTerminator(*region, b, result.
location);
1349 {static_cast<int32_t>(dynamicLbs.size()),
1350 static_cast<int32_t>(dynamicUbs.size()),
1351 static_cast<int32_t>(dynamicSteps.size()),
1352 static_cast<int32_t>(outOperands.size())}));
1357 void ForallOp::build(
1361 std::optional<ArrayAttr> mapping,
1382 "operandSegmentSizes",
1384 static_cast<int32_t>(dynamicUbs.size()),
1385 static_cast<int32_t>(dynamicSteps.size()),
1386 static_cast<int32_t>(outputs.size())}));
1387 if (mapping.has_value()) {
1406 if (!bodyBuilderFn) {
1407 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1414 void ForallOp::build(
1417 std::optional<ArrayAttr> mapping,
1419 unsigned numLoops = ubs.size();
1422 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1426 bool ForallOp::isNormalized() {
1430 return intValue.has_value() && intValue == val;
1433 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1436 InParallelOp ForallOp::getTerminator() {
1437 return cast<InParallelOp>(getBody()->getTerminator());
1442 InParallelOp inParallelOp = getTerminator();
1443 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1444 if (
auto parallelInsertSliceOp =
1445 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1446 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1447 storeOps.push_back(parallelInsertSliceOp);
1457 for (
auto attr : getMapping()->getValue()) {
1458 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1465 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1466 DeviceMaskingAttrInterface res;
1469 for (
auto attr : getMapping()->getValue()) {
1470 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1479 bool ForallOp::usesLinearMapping() {
1483 return ifaces.front().isLinearMapping();
1486 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1491 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1493 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1497 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1499 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1503 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1509 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1512 assert(tidxArg.getOwner() &&
"unlinked block argument");
1513 auto *containingOp = tidxArg.getOwner()->getParentOp();
1514 return dyn_cast<ForallOp>(containingOp);
1522 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1524 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1528 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1531 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1540 LogicalResult matchAndRewrite(ForallOp op,
1555 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1556 op.setStaticLowerBound(staticLowerBound);
1560 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1561 op.setStaticUpperBound(staticUpperBound);
1564 op.getDynamicStepMutable().assign(dynamicStep);
1565 op.setStaticStep(staticStep);
1567 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1569 {static_cast<int32_t>(dynamicLowerBound.size()),
1570 static_cast<int32_t>(dynamicUpperBound.size()),
1571 static_cast<int32_t>(dynamicStep.size()),
1572 static_cast<int32_t>(op.getNumResults())}));
1654 LogicalResult matchAndRewrite(ForallOp forallOp,
1673 for (
OpResult result : forallOp.getResults()) {
1674 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1675 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1676 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1677 resultToDelete.insert(result);
1679 resultToReplace.push_back(result);
1680 newOuts.push_back(opOperand->
get());
1686 if (resultToDelete.empty())
1694 for (
OpResult result : resultToDelete) {
1695 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1696 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1698 forallOp.getCombiningOps(blockArg);
1699 for (
Operation *combiningOp : combiningOps)
1700 rewriter.
eraseOp(combiningOp);
1705 auto newForallOp = scf::ForallOp::create(
1706 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1707 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1708 forallOp.getMapping(),
1713 Block *loopBody = forallOp.getBody();
1714 Block *newLoopBody = newForallOp.getBody();
1719 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1726 for (
OpResult result : forallOp.getResults()) {
1727 if (resultToDelete.count(result)) {
1728 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1730 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1733 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1737 for (
auto &&[oldResult, newResult] :
1738 llvm::zip(resultToReplace, newForallOp->getResults()))
1744 for (
OpResult oldResult : resultToDelete)
1746 forallOp.getTiedOpOperand(oldResult)->get());
1751 struct ForallOpSingleOrZeroIterationDimsFolder
1755 LogicalResult matchAndRewrite(ForallOp op,
1758 if (op.getMapping().has_value() && !op.getMapping()->empty())
1766 for (
auto [lb, ub, step, iv] :
1767 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1768 op.getMixedStep(), op.getInductionVars())) {
1770 if (numIterations.has_value()) {
1772 if (*numIterations == 0) {
1773 rewriter.
replaceOp(op, op.getOutputs());
1778 if (*numIterations == 1) {
1783 newMixedLowerBounds.push_back(lb);
1784 newMixedUpperBounds.push_back(ub);
1785 newMixedSteps.push_back(step);
1789 if (newMixedLowerBounds.empty()) {
1795 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1797 op,
"no dimensions have 0 or 1 iterations");
1802 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1803 newMixedUpperBounds, newMixedSteps,
1804 op.getOutputs(), std::nullopt,
nullptr);
1805 newOp.getBodyRegion().getBlocks().clear();
1810 newOp.getStaticLowerBoundAttrName(),
1811 newOp.getStaticUpperBoundAttrName(),
1812 newOp.getStaticStepAttrName()};
1813 for (
const auto &namedAttr : op->getAttrs()) {
1814 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1817 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1821 newOp.getRegion().begin(), mapping);
1822 rewriter.
replaceOp(op, newOp.getResults());
1828 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1831 LogicalResult matchAndRewrite(ForallOp op,
1835 for (
auto [lb, ub, step, iv] :
1836 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1837 op.getMixedStep(), op.getInductionVars())) {
1841 if (!numIterations.has_value() || numIterations.value() != 1) {
1852 struct FoldTensorCastOfOutputIntoForallOp
1861 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1863 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1866 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1873 castOp.getSource().getType())) {
1877 tensorCastProducers[en.index()] =
1878 TypeCast{castOp.getSource().getType(), castOp.getType()};
1879 newOutputTensors[en.index()] = castOp.getSource();
1882 if (tensorCastProducers.empty())
1887 auto newForallOp = ForallOp::create(
1888 rewriter, loc, forallOp.getMixedLowerBound(),
1889 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1890 newOutputTensors, forallOp.getMapping(),
1892 auto castBlockArgs =
1893 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1894 for (auto [index, cast] : tensorCastProducers) {
1895 Value &oldTypeBBArg = castBlockArgs[index];
1896 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1897 cast.dstType, oldTypeBBArg);
1902 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1903 ivsBlockArgs.append(castBlockArgs);
1905 bbArgs.front().getParentBlock(), ivsBlockArgs);
1911 auto terminator = newForallOp.getTerminator();
1912 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1913 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1914 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1915 insertSliceOp.getDestMutable().assign(outputBlockArg);
1921 for (
auto &item : tensorCastProducers) {
1922 Value &oldTypeResult = castResults[item.first];
1923 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1926 rewriter.
replaceOp(forallOp, castResults);
1935 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1936 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1937 ForallOpSingleOrZeroIterationDimsFolder,
1938 ForallOpReplaceConstantInductionVar>(context);
1969 scf::ForallOp forallOp =
1970 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1972 return this->emitOpError(
"expected forall op parent");
1975 for (
Operation &op : getRegion().front().getOperations()) {
1976 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1977 return this->emitOpError(
"expected only ")
1978 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1982 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1984 if (!llvm::is_contained(regionOutArgs, dest))
1985 return op.emitOpError(
"may only insert into an output block argument");
2002 std::unique_ptr<Region> region = std::make_unique<Region>();
2006 if (region->empty())
2016 OpResult InParallelOp::getParentResult(int64_t idx) {
2017 return getOperation()->getParentOp()->getResult(idx);
2021 return llvm::to_vector<4>(
2022 llvm::map_range(getYieldingOps(), [](
Operation &op) {
2024 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
2025 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
2030 return getRegion().front().getOperations();
2038 assert(a &&
"expected non-empty operation");
2039 assert(b &&
"expected non-empty operation");
2044 if (ifOp->isProperAncestor(b))
2047 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2048 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2050 ifOp = ifOp->getParentOfType<IfOp>();
2058 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2059 IfOp::Adaptor adaptor,
2061 if (adaptor.getRegions().empty())
2063 Region *r = &adaptor.getThenRegion();
2069 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2072 TypeRange types = yieldOp.getOperandTypes();
2073 llvm::append_range(inferredReturnTypes, types);
2079 return build(builder, result, resultTypes, cond,
false,
2085 bool addElseBlock) {
2086 assert((!addElseBlock || addThenBlock) &&
2087 "must not create else block w/o then block");
2102 bool withElseRegion) {
2103 build(builder, result,
TypeRange{}, cond, withElseRegion);
2115 if (resultTypes.empty())
2116 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2120 if (withElseRegion) {
2122 if (resultTypes.empty())
2123 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2130 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2137 thenBuilder(builder, result.
location);
2143 elseBuilder(builder, result.
location);
2150 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2152 inferredReturnTypes))) {
2153 result.
addTypes(inferredReturnTypes);
2158 if (getNumResults() != 0 && getElseRegion().empty())
2159 return emitOpError(
"must have an else block if defining values");
2197 bool printBlockTerminators =
false;
2199 p <<
" " << getCondition();
2200 if (!getResults().empty()) {
2201 p <<
" -> (" << getResultTypes() <<
")";
2203 printBlockTerminators =
true;
2208 printBlockTerminators);
2211 auto &elseRegion = getElseRegion();
2212 if (!elseRegion.
empty()) {
2216 printBlockTerminators);
2233 Region *elseRegion = &this->getElseRegion();
2234 if (elseRegion->
empty())
2242 FoldAdaptor adaptor(operands, *
this);
2243 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2244 if (!boolAttr || boolAttr.getValue())
2245 regions.emplace_back(&getThenRegion());
2248 if (!boolAttr || !boolAttr.getValue()) {
2249 if (!getElseRegion().empty())
2250 regions.emplace_back(&getElseRegion());
2252 regions.emplace_back(getResults());
2256 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2259 if (getElseRegion().empty())
2262 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2269 getConditionMutable().assign(xorStmt.getLhs());
2273 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2274 getElseRegion().getBlocks());
2275 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2276 getThenRegion().getBlocks(), thenBlock);
2280 void IfOp::getRegionInvocationBounds(
2283 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2286 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2287 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2290 invocationBounds.assign(2, {0, 1});
2306 llvm::transform(usedResults, std::back_inserter(usedOperands),
2311 [&]() { yieldOp->setOperands(usedOperands); });
2318 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2319 [](
OpResult result) { return !result.use_empty(); });
2322 if (usedResults.size() == op.getNumResults())
2327 llvm::transform(usedResults, std::back_inserter(newTypes),
2332 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2338 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2339 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2344 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2361 else if (!op.getElseRegion().empty())
2377 if (op->getNumResults() == 0)
2380 auto cond = op.getCondition();
2381 auto thenYieldArgs = op.thenYield().getOperands();
2382 auto elseYieldArgs = op.elseYield().getOperands();
2385 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2386 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2387 &op.getElseRegion() == falseVal.getParentRegion())
2388 nonHoistable.push_back(trueVal.getType());
2392 if (nonHoistable.size() == op->getNumResults())
2395 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2397 if (replacement.thenBlock())
2398 rewriter.
eraseBlock(replacement.thenBlock());
2399 replacement.getThenRegion().takeBody(op.getThenRegion());
2400 replacement.getElseRegion().takeBody(op.getElseRegion());
2403 assert(thenYieldArgs.size() == results.size());
2404 assert(elseYieldArgs.size() == results.size());
2409 for (
const auto &it :
2411 Value trueVal = std::get<0>(it.value());
2412 Value falseVal = std::get<1>(it.value());
2415 results[it.index()] = replacement.getResult(trueYields.size());
2416 trueYields.push_back(trueVal);
2417 falseYields.push_back(falseVal);
2418 }
else if (trueVal == falseVal)
2419 results[it.index()] = trueVal;
2421 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2422 cond, trueVal, falseVal);
2472 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2478 if (op.getNumResults() == 0)
2482 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2484 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2487 op.getOperation()->getIterator());
2490 for (
auto [trueResult, falseResult, opResult] :
2491 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2493 if (trueResult == falseResult) {
2494 if (!opResult.use_empty()) {
2495 opResult.replaceAllUsesWith(trueResult);
2506 bool trueVal = trueYield.
getValue();
2507 bool falseVal = falseYield.
getValue();
2508 if (!trueVal && falseVal) {
2509 if (!opResult.use_empty()) {
2510 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2511 Value notCond = arith::XOrIOp::create(
2512 rewriter, op.getLoc(), op.getCondition(),
2518 opResult.replaceAllUsesWith(notCond);
2522 if (trueVal && !falseVal) {
2523 if (!opResult.use_empty()) {
2524 opResult.replaceAllUsesWith(op.getCondition());
2559 Block *parent = nextIf->getBlock();
2560 if (nextIf == &parent->
front())
2563 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2571 Block *nextThen =
nullptr;
2572 Block *nextElse =
nullptr;
2573 if (nextIf.getCondition() == prevIf.getCondition()) {
2574 nextThen = nextIf.thenBlock();
2575 if (!nextIf.getElseRegion().empty())
2576 nextElse = nextIf.elseBlock();
2578 if (arith::XOrIOp notv =
2579 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2580 if (notv.getLhs() == prevIf.getCondition() &&
2582 nextElse = nextIf.thenBlock();
2583 if (!nextIf.getElseRegion().empty())
2584 nextThen = nextIf.elseBlock();
2587 if (arith::XOrIOp notv =
2588 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2589 if (notv.getLhs() == nextIf.getCondition() &&
2591 nextElse = nextIf.thenBlock();
2592 if (!nextIf.getElseRegion().empty())
2593 nextThen = nextIf.elseBlock();
2597 if (!nextThen && !nextElse)
2601 if (!prevIf.getElseRegion().empty())
2602 prevElseYielded = prevIf.elseYield().getOperands();
2605 for (
auto it : llvm::zip(prevIf.getResults(),
2606 prevIf.thenYield().getOperands(), prevElseYielded))
2608 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2612 use.
set(std::get<1>(it));
2617 use.
set(std::get<2>(it));
2623 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2625 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2626 prevIf.getCondition(),
false);
2627 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2630 combinedIf.getThenRegion(),
2631 combinedIf.getThenRegion().begin());
2634 YieldOp thenYield = combinedIf.thenYield();
2635 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2636 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2640 llvm::append_range(mergedYields, thenYield2.getOperands());
2641 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2647 combinedIf.getElseRegion(),
2648 combinedIf.getElseRegion().begin());
2651 if (combinedIf.getElseRegion().empty()) {
2653 combinedIf.getElseRegion(),
2654 combinedIf.getElseRegion().
begin());
2656 YieldOp elseYield = combinedIf.elseYield();
2657 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2658 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2663 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2665 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2674 if (pair.index() < prevIf.getNumResults())
2675 prevValues.push_back(pair.value());
2677 nextValues.push_back(pair.value());
2692 if (ifOp.getNumResults())
2694 Block *elseBlock = ifOp.elseBlock();
2695 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2699 newIfOp.getThenRegion().begin());
2726 auto nestedOps = op.thenBlock()->without_terminator();
2728 if (!llvm::hasSingleElement(nestedOps))
2732 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2735 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2739 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2745 llvm::append_range(elseYield, op.elseYield().getOperands());
2759 if (tup.value().getDefiningOp() == nestedIf) {
2760 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2761 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2762 elseYield[tup.index()]) {
2767 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2780 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2783 elseYieldsToUpgradeToSelect.push_back(tup.index());
2787 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2788 nestedIf.getCondition());
2789 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2793 llvm::append_range(results, newIf.getResults());
2796 for (
auto idx : elseYieldsToUpgradeToSelect)
2798 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2799 thenYield[idx], elseYield[idx]);
2801 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2804 if (!elseYield.empty()) {
2807 YieldOp::create(rewriter, loc, elseYield);
2818 results.
add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
2819 RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
2820 ReplaceIfYieldWithConditionOrValue>(context);
2823 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2824 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2825 Block *IfOp::elseBlock() {
2826 Region &r = getElseRegion();
2831 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2837 void ParallelOp::build(
2847 ParallelOp::getOperandSegmentSizeAttr(),
2849 static_cast<int32_t>(upperBounds.size()),
2850 static_cast<int32_t>(steps.size()),
2851 static_cast<int32_t>(initVals.size())}));
2855 unsigned numIVs = steps.size();
2861 if (bodyBuilderFn) {
2863 bodyBuilderFn(builder, result.
location,
2868 if (initVals.empty())
2869 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2872 void ParallelOp::build(
2879 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2882 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2886 wrapper = wrappedBuilderFn;
2888 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2897 if (stepValues.empty())
2899 "needs at least one tuple element for lowerBound, upperBound and step");
2902 for (
Value stepValue : stepValues)
2905 return emitOpError(
"constant step operand must be positive");
2909 Block *body = getBody();
2911 return emitOpError() <<
"expects the same number of induction variables: "
2913 <<
" as bound and step values: " << stepValues.size();
2915 if (!arg.getType().isIndex())
2917 "expects arguments for the induction variable to be of index type");
2920 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2921 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2926 auto resultsSize = getResults().size();
2927 auto reductionsSize = reduceOp.getReductions().size();
2928 auto initValsSize = getInitVals().size();
2929 if (resultsSize != reductionsSize)
2930 return emitOpError() <<
"expects number of results: " << resultsSize
2931 <<
" to be the same as number of reductions: "
2933 if (resultsSize != initValsSize)
2934 return emitOpError() <<
"expects number of results: " << resultsSize
2935 <<
" to be the same as number of initial values: "
2939 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2940 auto resultType = getOperation()->getResult(i).getType();
2941 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2942 if (resultType != reductionOperandType)
2943 return reduceOp.emitOpError()
2944 <<
"expects type of " << i
2945 <<
"-th reduction operand: " << reductionOperandType
2946 <<
" to be the same as the " << i
2947 <<
"-th result type: " << resultType;
2963 OpAsmParser::Delimiter::Paren) ||
2970 OpAsmParser::Delimiter::Paren) ||
2978 OpAsmParser::Delimiter::Paren) ||
2995 for (
auto &iv : ivs)
3002 ParallelOp::getOperandSegmentSizeAttr(),
3004 static_cast<int32_t>(upper.size()),
3005 static_cast<int32_t>(steps.size()),
3006 static_cast<int32_t>(initVals.size())}));
3015 ParallelOp::ensureTerminator(*body, builder, result.
location);
3020 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3021 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3022 if (!getInitVals().empty())
3023 p <<
" init (" << getInitVals() <<
")";
3028 (*this)->getAttrs(),
3029 ParallelOp::getOperandSegmentSizeAttr());
3034 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3038 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3042 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3046 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3051 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3053 return ParallelOp();
3054 assert(ivArg.getOwner() &&
"unlinked block argument");
3055 auto *containingOp = ivArg.getOwner()->getParentOp();
3056 return dyn_cast<ParallelOp>(containingOp);
3061 struct ParallelOpSingleOrZeroIterationDimsFolder
3072 for (
auto [lb, ub, step, iv] :
3073 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3074 op.getInductionVars())) {
3076 if (numIterations.has_value()) {
3078 if (*numIterations == 0) {
3079 rewriter.
replaceOp(op, op.getInitVals());
3084 if (*numIterations == 1) {
3089 newLowerBounds.push_back(lb);
3090 newUpperBounds.push_back(ub);
3091 newSteps.push_back(step);
3094 if (newLowerBounds.size() == op.getLowerBound().size())
3097 if (newLowerBounds.empty()) {
3101 results.reserve(op.getInitVals().size());
3102 for (
auto &bodyOp : op.getBody()->without_terminator())
3103 rewriter.
clone(bodyOp, mapping);
3104 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3105 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3106 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3107 auto initValIndex = results.size();
3108 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3112 rewriter.
clone(reduceBodyOp, mapping);
3115 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3116 results.push_back(result);
3124 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3125 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3131 newOp.getRegion().begin(), mapping);
3132 rewriter.
replaceOp(op, newOp.getResults());
3142 Block &outerBody = *op.getBody();
3146 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3151 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3152 llvm::is_contained(innerOp.getUpperBound(), val) ||
3153 llvm::is_contained(innerOp.getStep(), val))
3157 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3162 Block &innerBody = *innerOp.getBody();
3163 assert(iterVals.size() ==
3171 builder.
clone(op, mapping);
3174 auto concatValues = [](
const auto &first,
const auto &second) {
3176 ret.reserve(first.size() + second.size());
3177 ret.assign(first.begin(), first.end());
3178 ret.append(second.begin(), second.end());
3182 auto newLowerBounds =
3183 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3184 auto newUpperBounds =
3185 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3186 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3200 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3209 void ParallelOp::getSuccessorRegions(
3227 for (
Value v : operands) {
3236 LogicalResult ReduceOp::verifyRegions() {
3239 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3240 auto type = getOperands()[i].getType();
3243 return emitOpError() << i <<
"-th reduction has an empty body";
3246 return arg.getType() != type;
3248 return emitOpError() <<
"expected two block arguments with type " << type
3249 <<
" in the " << i <<
"-th reduction region";
3253 return emitOpError(
"reduction bodies must be terminated with an "
3254 "'scf.reduce.return' op");
3273 Block *reductionBody = getOperation()->getBlock();
3275 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3277 if (expectedResultType != getResult().
getType())
3278 return emitOpError() <<
"must have type " << expectedResultType
3279 <<
" (the type of the reduction inputs)";
3289 ValueRange inits, BodyBuilderFn beforeBuilder,
3290 BodyBuilderFn afterBuilder) {
3298 beforeArgLocs.reserve(inits.size());
3299 for (
Value operand : inits) {
3300 beforeArgLocs.push_back(operand.getLoc());
3305 inits.getTypes(), beforeArgLocs);
3314 resultTypes, afterArgLocs);
3320 ConditionOp WhileOp::getConditionOp() {
3321 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3324 YieldOp WhileOp::getYieldOp() {
3325 return cast<YieldOp>(getAfterBody()->getTerminator());
3328 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3329 return getYieldOp().getResultsMutable();
3333 return getBeforeBody()->getArguments();
3337 return getAfterBody()->getArguments();
3341 return getBeforeArguments();
3345 assert(point == getBefore() &&
3346 "WhileOp is expected to branch only to the first region");
3354 regions.emplace_back(&getBefore(), getBefore().getArguments());
3358 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3359 "there are only two regions in a WhileOp");
3361 if (point == getAfter()) {
3362 regions.emplace_back(&getBefore(), getBefore().getArguments());
3366 regions.emplace_back(getResults());
3367 regions.emplace_back(&getAfter(), getAfter().getArguments());
3371 return {&getBefore(), &getAfter()};
3392 FunctionType functionType;
3397 result.
addTypes(functionType.getResults());
3399 if (functionType.getNumInputs() != operands.size()) {
3401 <<
"expected as many input types as operands " <<
"(expected "
3402 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3412 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3413 regionArgs[i].type = functionType.getInput(i);
3415 return failure(parser.
parseRegion(*before, regionArgs) ||
3435 template <
typename OpTy>
3438 if (left.size() != right.size())
3439 return op.emitOpError(
"expects the same number of ") << message;
3441 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3442 if (left[i] != right[i]) {
3445 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3446 <<
" and " << right[i];
3455 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3457 "expects the 'before' region to terminate with 'scf.condition'");
3458 if (!beforeTerminator)
3461 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3463 "expects the 'after' region to terminate with 'scf.yield'");
3464 return success(afterTerminator !=
nullptr);
3492 auto term = op.getConditionOp();
3496 Value constantTrue =
nullptr;
3498 bool replaced =
false;
3499 for (
auto yieldedAndBlockArgs :
3500 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3501 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3502 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3504 constantTrue = arith::ConstantOp::create(
3505 rewriter, op.getLoc(), term.getCondition().getType(),
3514 return success(replaced);
3566 struct RemoveLoopInvariantArgsFromBeforeBlock
3572 Block &afterBlock = *op.getAfterBody();
3574 ConditionOp condOp = op.getConditionOp();
3579 bool canSimplify =
false;
3580 for (
const auto &it :
3582 auto index =
static_cast<unsigned>(it.index());
3583 auto [initVal, yieldOpArg] = it.value();
3586 if (yieldOpArg == initVal) {
3595 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3596 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3597 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3598 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3611 for (
const auto &it :
3613 auto index =
static_cast<unsigned>(it.index());
3614 auto [initVal, yieldOpArg] = it.value();
3618 if (yieldOpArg == initVal) {
3619 beforeBlockInitValMap.insert({index, initVal});
3627 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3628 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3629 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3630 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3631 beforeBlockInitValMap.insert({index, initVal});
3636 newInitArgs.emplace_back(initVal);
3637 newYieldOpArgs.emplace_back(yieldOpArg);
3638 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3647 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3651 &newWhile.getBefore(), {},
3652 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3654 Block &beforeBlock = *op.getBeforeBody();
3661 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3664 if (beforeBlockInitValMap.count(i) != 0)
3665 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3667 newBeforeBlockArgs[i] = newBeforeBlock.
getArgument(
j++);
3670 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3672 newWhile.getAfter().begin());
3674 rewriter.
replaceOp(op, newWhile.getResults());
3719 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3724 Block &beforeBlock = *op.getBeforeBody();
3725 ConditionOp condOp = op.getConditionOp();
3728 bool canSimplify =
false;
3729 for (
Value condOpArg : condOpArgs) {
3749 auto index =
static_cast<unsigned>(it.index());
3750 Value condOpArg = it.value();
3755 condOpInitValMap.insert({index, condOpArg});
3757 newCondOpArgs.emplace_back(condOpArg);
3758 newAfterBlockType.emplace_back(condOpArg.
getType());
3759 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3770 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3773 Block &newAfterBlock =
3775 newAfterBlockType, newAfterBlockArgLocs);
3777 Block &afterBlock = *op.getAfterBody();
3784 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3785 Value afterBlockArg, result;
3788 if (condOpInitValMap.count(i) != 0) {
3789 afterBlockArg = condOpInitValMap[i];
3790 result = afterBlockArg;
3793 result = newWhile.getResult(
j);
3796 newAfterBlockArgs[i] = afterBlockArg;
3797 newWhileResults[i] = result;
3800 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3802 newWhile.getBefore().begin());
3804 rewriter.
replaceOp(op, newWhileResults);
3840 auto term = op.getConditionOp();
3841 auto afterArgs = op.getAfterArguments();
3842 auto termArgs = term.getArgs();
3849 bool needUpdate =
false;
3850 for (
const auto &it :
3852 auto i =
static_cast<unsigned>(it.index());
3853 Value result = std::get<0>(it.value());
3854 Value afterArg = std::get<1>(it.value());
3855 Value termArg = std::get<2>(it.value());
3859 newResultsIndices.emplace_back(i);
3860 newTermArgs.emplace_back(termArg);
3861 newResultTypes.emplace_back(result.
getType());
3862 newArgLocs.emplace_back(result.
getLoc());
3877 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
3880 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3887 newResults[it.value()] = newWhile.getResult(it.index());
3888 newAfterBlockArgs[it.value()] = newAfterBlock.
getArgument(it.index());
3892 newWhile.getBefore().begin());
3894 Block &afterBlock = *op.getAfterBody();
3895 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3929 using namespace scf;
3930 auto cond = op.getConditionOp();
3931 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3935 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3936 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3937 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3940 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3941 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3945 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3948 if (cmp2.getPredicate() == cmp.getPredicate())
3949 samePredicate =
true;
3950 else if (cmp2.getPredicate() ==
3952 samePredicate =
false;
3970 LogicalResult matchAndRewrite(WhileOp op,
3973 if (!llvm::any_of(op.getBeforeArguments(),
3974 [](
Value arg) { return arg.use_empty(); }))
3977 YieldOp yield = op.getYieldOp();
3982 llvm::BitVector argsToErase;
3984 size_t argsCount = op.getBeforeArguments().size();
3985 newYields.reserve(argsCount);
3986 newInits.reserve(argsCount);
3987 argsToErase.reserve(argsCount);
3988 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3989 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3990 if (beforeArg.use_empty()) {
3991 argsToErase.push_back(
true);
3993 argsToErase.push_back(
false);
3994 newYields.emplace_back(yieldValue);
3995 newInits.emplace_back(initValue);
3999 Block &beforeBlock = *op.getBeforeBody();
4000 Block &afterBlock = *op.getAfterBody();
4006 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4008 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4009 Block &newAfterBlock = *newWhileOp.getAfterBody();
4015 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4020 rewriter.
replaceOp(op, newWhileOp.getResults());
4027 using OpRewritePattern::OpRewritePattern;
4029 LogicalResult matchAndRewrite(WhileOp op,
4031 ConditionOp condOp = op.getConditionOp();
4036 if (argsSet.size() == condOpArgs.size())
4039 llvm::SmallDenseMap<Value, unsigned> argsMap;
4041 argsMap.reserve(condOpArgs.size());
4042 newArgs.reserve(condOpArgs.size());
4043 for (
Value arg : condOpArgs) {
4044 if (!argsMap.count(arg)) {
4045 auto pos =
static_cast<unsigned>(argsMap.size());
4046 argsMap.insert({arg, pos});
4047 newArgs.emplace_back(arg);
4055 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4058 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4059 Block &newAfterBlock = *newWhileOp.getAfterBody();
4064 auto it = argsMap.find(arg);
4065 assert(it != argsMap.end());
4066 auto pos = it->second;
4067 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4068 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4076 Block &beforeBlock = *op.getBeforeBody();
4077 Block &afterBlock = *op.getAfterBody();
4079 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4081 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4089 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4091 if (args1.size() != args2.size())
4092 return std::nullopt;
4096 auto it = llvm::find(args2, arg1);
4097 if (it == args2.end())
4098 return std::nullopt;
4100 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4107 llvm::SmallDenseSet<Value> set;
4108 for (
Value arg : args) {
4109 if (!set.insert(arg).second)
4120 using OpRewritePattern::OpRewritePattern;
4122 LogicalResult matchAndRewrite(WhileOp loop,
4124 auto oldBefore = loop.getBeforeBody();
4125 ConditionOp oldTerm = loop.getConditionOp();
4126 ValueRange beforeArgs = oldBefore->getArguments();
4128 if (beforeArgs == termArgs)
4131 if (hasDuplicates(termArgs))
4134 auto mapping = getArgsMapping(beforeArgs, termArgs);
4145 auto oldAfter = loop.getAfterBody();
4149 newResultTypes[
j] = loop.getResult(i).getType();
4151 auto newLoop = WhileOp::create(
4152 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4154 auto newBefore = newLoop.getBeforeBody();
4155 auto newAfter = newLoop.getAfterBody();
4160 newResults[i] = newLoop.getResult(
j);
4161 newAfterArgs[i] = newAfter->getArgument(
j);
4165 newBefore->getArguments());
4177 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4178 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4179 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4180 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4194 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4197 caseValues.push_back(value);
4206 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4208 p <<
"case " << value <<
' ';
4214 if (getCases().size() != getCaseRegions().size()) {
4215 return emitOpError(
"has ")
4216 << getCaseRegions().size() <<
" case regions but "
4217 << getCases().size() <<
" case values";
4221 for (int64_t value : getCases())
4222 if (!valueSet.insert(value).second)
4223 return emitOpError(
"has duplicate case value: ") << value;
4225 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4227 return emitOpError(
"expected region to end with scf.yield, but got ")
4230 if (yield.getNumOperands() != getNumResults()) {
4231 return (emitOpError(
"expected each region to return ")
4232 << getNumResults() <<
" values, but " << name <<
" returns "
4233 << yield.getNumOperands())
4234 .attachNote(yield.getLoc())
4235 <<
"see yield operation here";
4237 for (
auto [idx, result, operand] :
4240 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
4241 if (result == operand.getType())
4243 return (emitOpError(
"expected result #")
4244 << idx <<
" of each region to be " << result)
4245 .attachNote(yield.getLoc())
4246 << name <<
" returns " << operand.getType() <<
" here";
4260 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4262 Block &scf::IndexSwitchOp::getDefaultBlock() {
4263 return getDefaultRegion().
front();
4266 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4267 assert(idx < getNumCases() &&
"case index out-of-bounds");
4268 return getCaseRegions()[idx].front();
4271 void IndexSwitchOp::getSuccessorRegions(
4275 successors.emplace_back(getResults());
4279 llvm::append_range(successors, getRegions());
4282 void IndexSwitchOp::getEntrySuccessorRegions(
4285 FoldAdaptor adaptor(operands, *
this);
4288 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4290 llvm::append_range(successors, getRegions());
4296 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4297 if (caseValue == arg.getInt()) {
4298 successors.emplace_back(&caseRegion);
4302 successors.emplace_back(&getDefaultRegion());
4305 void IndexSwitchOp::getRegionInvocationBounds(
4307 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4308 if (!operandValue) {
4314 unsigned liveIndex = getNumRegions() - 1;
4315 const auto *it = llvm::find(getCases(), operandValue.getInt());
4316 if (it != getCases().end())
4317 liveIndex = std::distance(getCases().begin(), it);
4318 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4319 bounds.emplace_back(0, i == liveIndex);
4330 if (!maybeCst.has_value())
4332 int64_t cst = *maybeCst;
4333 int64_t caseIdx, e = op.getNumCases();
4334 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4335 if (cst == op.getCases()[caseIdx])
4339 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4340 : op.getDefaultRegion();
4364 #define GET_OP_CLASSES
4365 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.