26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
56 auto retValOp = dyn_cast<scf::YieldOp>(op);
60 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
71 void SCFDialect::initialize() {
74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
76 addInterfaces<SCFInlinerInterface>();
77 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
78 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
79 InParallelOp, ReduceReturnOp>();
80 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
81 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
82 ForallOp, InParallelOp, WhileOp, YieldOp>();
83 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
88 builder.
create<scf::YieldOp>(loc);
93 template <
typename TerminatorTy>
95 StringRef errorMessage) {
98 terminatorOperation = ®ion.
front().
back();
99 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
103 if (terminatorOperation)
104 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
116 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
162 if (getRegion().empty())
163 return emitOpError(
"region needs to have at least one block");
164 if (getRegion().front().getNumArguments() > 0)
165 return emitOpError(
"region cannot have any arguments");
188 if (!llvm::hasSingleElement(op.getRegion()))
237 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
240 Block *prevBlock = op->getBlock();
244 rewriter.
create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
246 for (
Block &blk : op.getRegion()) {
247 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
249 rewriter.
create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
250 yieldOp.getResults());
258 for (
auto res : op.getResults())
259 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
271 void ExecuteRegionOp::getSuccessorRegions(
289 assert((point.
isParent() || point == getParentOp().getAfter()) &&
290 "condition op can only exit the loop or branch to the after"
293 return getArgsMutable();
296 void ConditionOp::getSuccessorRegions(
298 FoldAdaptor adaptor(operands, *
this);
300 WhileOp whileOp = getParentOp();
304 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
305 if (!boolAttr || boolAttr.getValue())
306 regions.emplace_back(&whileOp.getAfter(),
307 whileOp.getAfter().getArguments());
308 if (!boolAttr || !boolAttr.getValue())
309 regions.emplace_back(whileOp.getResults());
318 BodyBuilderFn bodyBuilder) {
323 for (
Value v : initArgs)
329 for (
Value v : initArgs)
335 if (initArgs.empty() && !bodyBuilder) {
336 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
337 }
else if (bodyBuilder) {
347 if (getInitArgs().size() != getNumResults())
349 "mismatch in number of loop-carried values and defined values");
354 LogicalResult ForOp::verifyRegions() {
359 "expected induction variable to be same type as bounds and step");
361 if (getNumRegionIterArgs() != getNumResults())
363 "mismatch in number of basic block args and defined values");
365 auto initArgs = getInitArgs();
366 auto iterArgs = getRegionIterArgs();
367 auto opResults = getResults();
369 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
371 return emitOpError() <<
"types mismatch between " << i
372 <<
"th iter operand and defined value";
374 return emitOpError() <<
"types mismatch between " << i
375 <<
"th iter region arg and defined value";
382 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
386 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
390 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
394 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
398 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
403 std::optional<int64_t> tripCount =
405 if (!tripCount.has_value() || tripCount != 1)
409 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
416 llvm::append_range(bbArgReplacements, getInitArgs());
420 getOperation()->getIterator(), bbArgReplacements);
436 StringRef prefix =
"") {
437 assert(blocksArgs.size() == initializers.size() &&
438 "expected same length of arguments and initializers");
439 if (initializers.empty())
443 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
444 p << std::get<0>(it) <<
" = " << std::get<1>(it);
450 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
454 if (!getInitArgs().empty())
455 p <<
" -> (" << getInitArgs().getTypes() <<
')';
458 p <<
" : " << t <<
' ';
461 !getInitArgs().empty());
483 regionArgs.push_back(inductionVariable);
493 if (regionArgs.size() != result.
types.size() + 1)
496 "mismatch in number of loop-carried values and defined values");
505 regionArgs.front().type = type;
506 for (
auto [iterArg, type] :
507 llvm::zip_equal(llvm::drop_begin(regionArgs), result.
types))
514 ForOp::ensureTerminator(*body, builder, result.
location);
523 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
524 operands, result.
types)) {
525 Type type = std::get<2>(argOperandType);
526 std::get<0>(argOperandType).type = type;
543 return getBody()->getArguments().drop_front(getNumInductionVars());
547 return getInitArgsMutable();
550 FailureOr<LoopLikeOpInterface>
551 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
553 bool replaceInitOperandUsesInLoop,
558 auto inits = llvm::to_vector(getInitArgs());
559 inits.append(newInitOperands.begin(), newInitOperands.end());
560 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
566 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
568 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
573 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
574 assert(newInitOperands.size() == newYieldedValues.size() &&
575 "expected as many new yield values as new iter operands");
577 yieldOp.getResultsMutable().append(newYieldedValues);
583 newLoop.getBody()->getArguments().take_front(
584 getBody()->getNumArguments()));
586 if (replaceInitOperandUsesInLoop) {
589 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
600 newLoop->getResults().take_front(getNumResults()));
601 return cast<LoopLikeOpInterface>(newLoop.getOperation());
605 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
608 assert(ivArg.getOwner() &&
"unlinked block argument");
609 auto *containingOp = ivArg.getOwner()->getParentOp();
610 return dyn_cast_or_null<ForOp>(containingOp);
614 return getInitArgs();
631 for (
auto [lb, ub, step] :
632 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
634 if (!tripCount.has_value() || *tripCount != 1)
643 return getBody()->getArguments().drop_front(getRank());
647 return getOutputsMutable();
653 scf::InParallelOp terminator = forallOp.getTerminator();
658 bbArgReplacements.append(forallOp.getOutputs().begin(),
659 forallOp.getOutputs().end());
663 forallOp->getIterator(), bbArgReplacements);
668 results.reserve(forallOp.getResults().size());
669 for (
auto &yieldingOp : terminator.getYieldingOps()) {
670 auto parallelInsertSliceOp =
671 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
673 Value dst = parallelInsertSliceOp.getDest();
674 Value src = parallelInsertSliceOp.getSource();
675 if (llvm::isa<TensorType>(src.
getType())) {
676 results.push_back(rewriter.
create<tensor::InsertSliceOp>(
677 forallOp.getLoc(), dst.
getType(), src, dst,
678 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
679 parallelInsertSliceOp.getStrides(),
680 parallelInsertSliceOp.getStaticOffsets(),
681 parallelInsertSliceOp.getStaticSizes(),
682 parallelInsertSliceOp.getStaticStrides()));
684 llvm_unreachable(
"unsupported terminator");
699 assert(lbs.size() == ubs.size() &&
700 "expected the same number of lower and upper bounds");
701 assert(lbs.size() == steps.size() &&
702 "expected the same number of lower bounds and steps");
707 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
709 assert(results.size() == iterArgs.size() &&
710 "loop nest body must return as many values as loop has iteration "
712 return LoopNest{{}, std::move(results)};
720 loops.reserve(lbs.size());
721 ivs.reserve(lbs.size());
724 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
725 auto loop = builder.
create<scf::ForOp>(
726 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
732 currentIterArgs = args;
733 currentLoc = nestedLoc;
739 loops.push_back(loop);
743 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
745 builder.
create<scf::YieldOp>(loc, loops[i + 1].getResults());
752 ? bodyBuilder(builder, currentLoc, ivs,
753 loops.back().getRegionIterArgs())
755 assert(results.size() == iterArgs.size() &&
756 "loop nest body must return as many values as loop has iteration "
759 builder.
create<scf::YieldOp>(loc, results);
763 llvm::append_range(nestResults, loops.front().getResults());
764 return LoopNest{std::move(loops), std::move(nestResults)};
772 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
777 bodyBuilder(nestedBuilder, nestedLoc, ivs);
786 assert(operand.
getOwner() == forOp);
791 "expected an iter OpOperand");
793 "Expected a different type");
795 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
797 newIterOperands.push_back(replacement);
800 newIterOperands.push_back(opOperand.get());
804 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
805 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806 forOp.getStep(), newIterOperands);
807 newForOp->
setAttrs(forOp->getAttrs());
808 Block &newBlock = newForOp.getRegion().
front();
816 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
818 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
819 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
823 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
826 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
829 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
830 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
831 clonedYieldOp.getOperand(yieldIdx));
833 newYieldOperands[yieldIdx] = castOut;
834 rewriter.
create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
835 rewriter.
eraseOp(clonedYieldOp);
840 newResults[yieldIdx] =
841 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
861 LogicalResult matchAndRewrite(scf::ForOp forOp,
863 bool canonicalize =
false;
870 int64_t numResults = forOp.getNumResults();
872 keepMask.reserve(numResults);
875 newBlockTransferArgs.reserve(1 + numResults);
876 newBlockTransferArgs.push_back(
Value());
877 newIterArgs.reserve(forOp.getInitArgs().size());
878 newYieldValues.reserve(numResults);
879 newResultValues.reserve(numResults);
881 for (
auto [init, arg, result, yielded] :
882 llvm::zip(forOp.getInitArgs(),
883 forOp.getRegionIterArgs(),
885 forOp.getYieldedValues()
892 bool forwarded = (arg == yielded) || (init == yielded) ||
893 (arg.use_empty() && result.use_empty());
896 keepMask.push_back(
false);
897 newBlockTransferArgs.push_back(init);
898 newResultValues.push_back(init);
904 if (
auto it = initYieldToArg.find({init, yielded});
905 it != initYieldToArg.end()) {
907 keepMask.push_back(
false);
908 auto [sameArg, sameResult] = it->second;
912 newBlockTransferArgs.push_back(init);
913 newResultValues.push_back(init);
918 initYieldToArg.insert({{init, yielded}, {arg, result}});
919 keepMask.push_back(
true);
920 newIterArgs.push_back(init);
921 newYieldValues.push_back(yielded);
922 newBlockTransferArgs.push_back(
Value());
923 newResultValues.push_back(
Value());
929 scf::ForOp newForOp = rewriter.
create<scf::ForOp>(
930 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
931 forOp.getStep(), newIterArgs);
932 newForOp->
setAttrs(forOp->getAttrs());
933 Block &newBlock = newForOp.getRegion().
front();
937 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
939 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
940 Value &newResultVal = newResultValues[idx];
941 assert((blockTransferArg && newResultVal) ||
942 (!blockTransferArg && !newResultVal));
943 if (!blockTransferArg) {
944 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
945 newResultVal = newForOp.getResult(collapsedIdx++);
951 "unexpected argument size mismatch");
956 if (newIterArgs.empty()) {
957 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
960 rewriter.
replaceOp(forOp, newResultValues);
965 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
969 filteredOperands.reserve(newResultValues.size());
970 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
972 filteredOperands.push_back(mergedTerminator.getOperand(idx));
973 rewriter.
create<scf::YieldOp>(mergedTerminator.getLoc(),
977 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
978 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
979 cloneFilteredTerminator(mergedYieldOp);
980 rewriter.
eraseOp(mergedYieldOp);
981 rewriter.
replaceOp(forOp, newResultValues);
989 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
990 IntegerAttr clb, cub;
992 llvm::APInt lbValue = clb.getValue();
993 llvm::APInt ubValue = cub.getValue();
994 return (ubValue - lbValue).getSExtValue();
1003 return diff.getSExtValue();
1004 return std::nullopt;
1013 LogicalResult matchAndRewrite(ForOp op,
1017 if (op.getLowerBound() == op.getUpperBound()) {
1018 rewriter.
replaceOp(op, op.getInitArgs());
1022 std::optional<int64_t> diff =
1023 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1029 rewriter.
replaceOp(op, op.getInitArgs());
1033 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1034 if (!maybeStepValue)
1039 llvm::APInt stepValue = *maybeStepValue;
1040 if (stepValue.sge(*diff)) {
1042 blockArgs.reserve(op.getInitArgs().size() + 1);
1043 blockArgs.push_back(op.getLowerBound());
1044 llvm::append_range(blockArgs, op.getInitArgs());
1051 if (!llvm::hasSingleElement(block))
1055 if (llvm::any_of(op.getYieldedValues(),
1056 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1058 rewriter.
replaceOp(op, op.getYieldedValues());
1092 LogicalResult matchAndRewrite(ForOp op,
1094 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1095 OpOperand &iterOpOperand = std::get<0>(it);
1097 if (!incomingCast ||
1098 incomingCast.getSource().getType() == incomingCast.getType())
1103 incomingCast.getDest().getType(),
1104 incomingCast.getSource().getType()))
1106 if (!std::get<1>(it).hasOneUse())
1112 rewriter, op, iterOpOperand, incomingCast.getSource(),
1114 return b.create<tensor::CastOp>(loc, type, source);
1126 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1130 std::optional<APInt> ForOp::getConstantStep() {
1133 return step.getValue();
1137 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1138 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1144 if (
auto constantStep = getConstantStep())
1145 if (*constantStep == 1)
1158 unsigned numLoops = getRank();
1160 if (getNumResults() != getOutputs().size())
1161 return emitOpError(
"produces ")
1162 << getNumResults() <<
" results, but has only "
1163 << getOutputs().size() <<
" outputs";
1166 auto *body = getBody();
1168 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1169 for (int64_t i = 0; i < numLoops; ++i)
1171 return emitOpError(
"expects ")
1172 << i <<
"-th block argument to be an index";
1173 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1175 return emitOpError(
"type mismatch between ")
1176 << i <<
"-th output and corresponding block argument";
1177 if (getMapping().has_value() && !getMapping()->empty()) {
1178 if (
static_cast<int64_t
>(getMapping()->size()) != numLoops)
1179 return emitOpError() <<
"mapping attribute size must match op rank";
1180 for (
auto map : getMapping()->getValue()) {
1181 if (!isa<DeviceMappingAttrInterface>(map))
1182 return emitOpError()
1190 getStaticLowerBound(),
1191 getDynamicLowerBound())))
1194 getStaticUpperBound(),
1195 getDynamicUpperBound())))
1198 getStaticStep(), getDynamicStep())))
1206 p <<
" (" << getInductionVars();
1207 if (isNormalized()) {
1228 if (!getRegionOutArgs().empty())
1229 p <<
"-> (" << getResultTypes() <<
") ";
1230 p.printRegion(getRegion(),
1232 getNumResults() > 0);
1233 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1234 getStaticLowerBoundAttrName(),
1235 getStaticUpperBoundAttrName(),
1236 getStaticStepAttrName()});
1241 auto indexType = b.getIndexType();
1261 unsigned numLoops = ivs.size();
1296 if (outOperands.size() != result.
types.size())
1298 "mismatch between out operands and types");
1308 std::unique_ptr<Region> region = std::make_unique<Region>();
1309 for (
auto &iv : ivs) {
1310 iv.type = b.getIndexType();
1311 regionArgs.push_back(iv);
1314 auto &out = it.value();
1315 out.type = result.
types[it.index()];
1316 regionArgs.push_back(out);
1322 ForallOp::ensureTerminator(*region, b, result.
location);
1334 {static_cast<int32_t>(dynamicLbs.size()),
1335 static_cast<int32_t>(dynamicUbs.size()),
1336 static_cast<int32_t>(dynamicSteps.size()),
1337 static_cast<int32_t>(outOperands.size())}));
1342 void ForallOp::build(
1346 std::optional<ArrayAttr> mapping,
1367 "operandSegmentSizes",
1369 static_cast<int32_t>(dynamicUbs.size()),
1370 static_cast<int32_t>(dynamicSteps.size()),
1371 static_cast<int32_t>(outputs.size())}));
1372 if (mapping.has_value()) {
1391 if (!bodyBuilderFn) {
1392 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1399 void ForallOp::build(
1402 std::optional<ArrayAttr> mapping,
1404 unsigned numLoops = ubs.size();
1407 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1411 bool ForallOp::isNormalized() {
1415 return intValue.has_value() && intValue == val;
1418 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1421 InParallelOp ForallOp::getTerminator() {
1422 return cast<InParallelOp>(getBody()->getTerminator());
1427 InParallelOp inParallelOp = getTerminator();
1428 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1429 if (
auto parallelInsertSliceOp =
1430 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1431 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1432 storeOps.push_back(parallelInsertSliceOp);
1438 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1443 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1445 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1449 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1451 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1455 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1461 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1464 assert(tidxArg.getOwner() &&
"unlinked block argument");
1465 auto *containingOp = tidxArg.getOwner()->getParentOp();
1466 return dyn_cast<ForallOp>(containingOp);
1474 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1476 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1480 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1483 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1492 LogicalResult matchAndRewrite(ForallOp op,
1507 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1508 op.setStaticLowerBound(staticLowerBound);
1512 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1513 op.setStaticUpperBound(staticUpperBound);
1516 op.getDynamicStepMutable().assign(dynamicStep);
1517 op.setStaticStep(staticStep);
1519 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1521 {static_cast<int32_t>(dynamicLowerBound.size()),
1522 static_cast<int32_t>(dynamicUpperBound.size()),
1523 static_cast<int32_t>(dynamicStep.size()),
1524 static_cast<int32_t>(op.getNumResults())}));
1606 LogicalResult matchAndRewrite(ForallOp forallOp,
1625 for (
OpResult result : forallOp.getResults()) {
1626 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1627 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1628 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1629 resultToDelete.insert(result);
1631 resultToReplace.push_back(result);
1632 newOuts.push_back(opOperand->
get());
1638 if (resultToDelete.empty())
1646 for (
OpResult result : resultToDelete) {
1647 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1648 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1650 forallOp.getCombiningOps(blockArg);
1651 for (
Operation *combiningOp : combiningOps)
1652 rewriter.
eraseOp(combiningOp);
1657 auto newForallOp = rewriter.
create<scf::ForallOp>(
1658 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1659 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1660 forallOp.getMapping(),
1665 Block *loopBody = forallOp.getBody();
1666 Block *newLoopBody = newForallOp.getBody();
1671 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1678 for (
OpResult result : forallOp.getResults()) {
1679 if (resultToDelete.count(result)) {
1680 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1682 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1685 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1689 for (
auto &&[oldResult, newResult] :
1690 llvm::zip(resultToReplace, newForallOp->getResults()))
1696 for (
OpResult oldResult : resultToDelete)
1698 forallOp.getTiedOpOperand(oldResult)->get());
1703 struct ForallOpSingleOrZeroIterationDimsFolder
1707 LogicalResult matchAndRewrite(ForallOp op,
1710 if (op.getMapping().has_value() && !op.getMapping()->empty())
1718 for (
auto [lb, ub, step, iv] :
1719 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1720 op.getMixedStep(), op.getInductionVars())) {
1722 if (numIterations.has_value()) {
1724 if (*numIterations == 0) {
1725 rewriter.
replaceOp(op, op.getOutputs());
1730 if (*numIterations == 1) {
1735 newMixedLowerBounds.push_back(lb);
1736 newMixedUpperBounds.push_back(ub);
1737 newMixedSteps.push_back(step);
1741 if (newMixedLowerBounds.empty()) {
1747 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1749 op,
"no dimensions have 0 or 1 iterations");
1754 newOp = rewriter.
create<ForallOp>(loc, newMixedLowerBounds,
1755 newMixedUpperBounds, newMixedSteps,
1756 op.getOutputs(), std::nullopt,
nullptr);
1757 newOp.getBodyRegion().getBlocks().clear();
1762 newOp.getStaticLowerBoundAttrName(),
1763 newOp.getStaticUpperBoundAttrName(),
1764 newOp.getStaticStepAttrName()};
1765 for (
const auto &namedAttr : op->getAttrs()) {
1766 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1769 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1773 newOp.getRegion().begin(), mapping);
1774 rewriter.
replaceOp(op, newOp.getResults());
1780 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1783 LogicalResult matchAndRewrite(ForallOp op,
1787 for (
auto [lb, ub, step, iv] :
1788 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1789 op.getMixedStep(), op.getInductionVars())) {
1790 if (iv.getUses().begin() == iv.getUses().end())
1793 if (!numIterations.has_value() || numIterations.value() != 1) {
1804 struct FoldTensorCastOfOutputIntoForallOp
1813 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1815 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1818 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1825 castOp.getSource().getType())) {
1829 tensorCastProducers[en.index()] =
1830 TypeCast{castOp.getSource().getType(), castOp.getType()};
1831 newOutputTensors[en.index()] = castOp.getSource();
1834 if (tensorCastProducers.empty())
1839 auto newForallOp = rewriter.
create<ForallOp>(
1840 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1841 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1843 auto castBlockArgs =
1844 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1845 for (
auto [index, cast] : tensorCastProducers) {
1846 Value &oldTypeBBArg = castBlockArgs[index];
1847 oldTypeBBArg = nestedBuilder.
create<tensor::CastOp>(
1848 nestedLoc, cast.dstType, oldTypeBBArg);
1853 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1854 ivsBlockArgs.append(castBlockArgs);
1856 bbArgs.front().getParentBlock(), ivsBlockArgs);
1862 auto terminator = newForallOp.getTerminator();
1863 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1864 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1865 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1866 insertSliceOp.getDestMutable().assign(outputBlockArg);
1872 for (
auto &item : tensorCastProducers) {
1873 Value &oldTypeResult = castResults[item.first];
1874 oldTypeResult = rewriter.
create<tensor::CastOp>(loc, item.second.dstType,
1877 rewriter.
replaceOp(forallOp, castResults);
1886 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1887 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1888 ForallOpSingleOrZeroIterationDimsFolder,
1889 ForallOpReplaceConstantInductionVar>(context);
1918 scf::ForallOp forallOp =
1919 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1921 return this->emitOpError(
"expected forall op parent");
1924 for (
Operation &op : getRegion().front().getOperations()) {
1925 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1926 return this->emitOpError(
"expected only ")
1927 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1931 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1933 if (!llvm::is_contained(regionOutArgs, dest))
1934 return op.emitOpError(
"may only insert into an output block argument");
1951 std::unique_ptr<Region> region = std::make_unique<Region>();
1955 if (region->empty())
1965 OpResult InParallelOp::getParentResult(int64_t idx) {
1966 return getOperation()->getParentOp()->getResult(idx);
1970 return llvm::to_vector<4>(
1971 llvm::map_range(getYieldingOps(), [](
Operation &op) {
1973 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1974 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1979 return getRegion().front().getOperations();
1987 assert(a &&
"expected non-empty operation");
1988 assert(b &&
"expected non-empty operation");
1993 if (ifOp->isProperAncestor(b))
1996 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1997 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1999 ifOp = ifOp->getParentOfType<IfOp>();
2007 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2008 IfOp::Adaptor adaptor,
2010 if (adaptor.getRegions().empty())
2012 Region *r = &adaptor.getThenRegion();
2015 Block &b = r->front();
2018 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2021 TypeRange types = yieldOp.getOperandTypes();
2022 llvm::append_range(inferredReturnTypes, types);
2028 return build(builder, result, resultTypes, cond,
false,
2034 bool addElseBlock) {
2035 assert((!addElseBlock || addThenBlock) &&
2036 "must not create else block w/o then block");
2051 bool withElseRegion) {
2052 build(builder, result,
TypeRange{}, cond, withElseRegion);
2064 if (resultTypes.empty())
2065 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2069 if (withElseRegion) {
2071 if (resultTypes.empty())
2072 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2079 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2086 thenBuilder(builder, result.
location);
2092 elseBuilder(builder, result.
location);
2099 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2101 inferredReturnTypes))) {
2102 result.
addTypes(inferredReturnTypes);
2107 if (getNumResults() != 0 && getElseRegion().empty())
2108 return emitOpError(
"must have an else block if defining values");
2146 bool printBlockTerminators =
false;
2148 p <<
" " << getCondition();
2149 if (!getResults().empty()) {
2150 p <<
" -> (" << getResultTypes() <<
")";
2152 printBlockTerminators =
true;
2157 printBlockTerminators);
2160 auto &elseRegion = getElseRegion();
2161 if (!elseRegion.
empty()) {
2165 printBlockTerminators);
2182 Region *elseRegion = &this->getElseRegion();
2183 if (elseRegion->
empty())
2191 FoldAdaptor adaptor(operands, *
this);
2192 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2193 if (!boolAttr || boolAttr.getValue())
2194 regions.emplace_back(&getThenRegion());
2197 if (!boolAttr || !boolAttr.getValue()) {
2198 if (!getElseRegion().empty())
2199 regions.emplace_back(&getElseRegion());
2201 regions.emplace_back(getResults());
2205 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2208 if (getElseRegion().empty())
2211 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2218 getConditionMutable().assign(xorStmt.getLhs());
2222 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2223 getElseRegion().getBlocks());
2224 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2225 getThenRegion().getBlocks(), thenBlock);
2229 void IfOp::getRegionInvocationBounds(
2232 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2235 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2236 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2239 invocationBounds.assign(2, {0, 1});
2255 llvm::transform(usedResults, std::back_inserter(usedOperands),
2260 [&]() { yieldOp->setOperands(usedOperands); });
2263 LogicalResult matchAndRewrite(IfOp op,
2267 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2268 [](
OpResult result) { return !result.use_empty(); });
2271 if (usedResults.size() == op.getNumResults())
2276 llvm::transform(usedResults, std::back_inserter(newTypes),
2281 rewriter.
create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2287 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2288 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2293 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2302 LogicalResult matchAndRewrite(IfOp op,
2310 else if (!op.getElseRegion().empty())
2324 LogicalResult matchAndRewrite(IfOp op,
2326 if (op->getNumResults() == 0)
2329 auto cond = op.getCondition();
2330 auto thenYieldArgs = op.thenYield().getOperands();
2331 auto elseYieldArgs = op.elseYield().getOperands();
2334 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2335 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2336 &op.getElseRegion() == falseVal.getParentRegion())
2337 nonHoistable.push_back(trueVal.getType());
2341 if (nonHoistable.size() == op->getNumResults())
2344 IfOp replacement = rewriter.
create<IfOp>(op.getLoc(), nonHoistable, cond,
2346 if (replacement.thenBlock())
2347 rewriter.
eraseBlock(replacement.thenBlock());
2348 replacement.getThenRegion().takeBody(op.getThenRegion());
2349 replacement.getElseRegion().takeBody(op.getElseRegion());
2352 assert(thenYieldArgs.size() == results.size());
2353 assert(elseYieldArgs.size() == results.size());
2358 for (
const auto &it :
2360 Value trueVal = std::get<0>(it.value());
2361 Value falseVal = std::get<1>(it.value());
2364 results[it.index()] = replacement.getResult(trueYields.size());
2365 trueYields.push_back(trueVal);
2366 falseYields.push_back(falseVal);
2367 }
else if (trueVal == falseVal)
2368 results[it.index()] = trueVal;
2370 results[it.index()] = rewriter.
create<arith::SelectOp>(
2371 op.getLoc(), cond, trueVal, falseVal);
2401 LogicalResult matchAndRewrite(IfOp op,
2413 Value constantTrue =
nullptr;
2414 Value constantFalse =
nullptr;
2417 llvm::make_early_inc_range(op.getCondition().getUses())) {
2422 constantTrue = rewriter.
create<arith::ConstantOp>(
2426 [&]() { use.
set(constantTrue); });
2427 }
else if (op.getElseRegion().isAncestor(
2432 constantFalse = rewriter.
create<arith::ConstantOp>(
2436 [&]() { use.
set(constantFalse); });
2480 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2483 LogicalResult matchAndRewrite(IfOp op,
2486 if (op.getNumResults() == 0)
2490 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2492 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2495 op.getOperation()->getIterator());
2498 for (
auto [trueResult, falseResult, opResult] :
2499 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2501 if (trueResult == falseResult) {
2502 if (!opResult.use_empty()) {
2503 opResult.replaceAllUsesWith(trueResult);
2514 bool trueVal = trueYield.
getValue();
2515 bool falseVal = falseYield.
getValue();
2516 if (!trueVal && falseVal) {
2517 if (!opResult.use_empty()) {
2518 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2520 op.getLoc(), op.getCondition(),
2530 if (trueVal && !falseVal) {
2531 if (!opResult.use_empty()) {
2532 opResult.replaceAllUsesWith(op.getCondition());
2565 LogicalResult matchAndRewrite(IfOp nextIf,
2567 Block *parent = nextIf->getBlock();
2568 if (nextIf == &parent->
front())
2571 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2579 Block *nextThen =
nullptr;
2580 Block *nextElse =
nullptr;
2581 if (nextIf.getCondition() == prevIf.getCondition()) {
2582 nextThen = nextIf.thenBlock();
2583 if (!nextIf.getElseRegion().empty())
2584 nextElse = nextIf.elseBlock();
2586 if (arith::XOrIOp notv =
2587 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2588 if (notv.getLhs() == prevIf.getCondition() &&
2590 nextElse = nextIf.thenBlock();
2591 if (!nextIf.getElseRegion().empty())
2592 nextThen = nextIf.elseBlock();
2595 if (arith::XOrIOp notv =
2596 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2597 if (notv.getLhs() == nextIf.getCondition() &&
2599 nextElse = nextIf.thenBlock();
2600 if (!nextIf.getElseRegion().empty())
2601 nextThen = nextIf.elseBlock();
2605 if (!nextThen && !nextElse)
2609 if (!prevIf.getElseRegion().empty())
2610 prevElseYielded = prevIf.elseYield().getOperands();
2613 for (
auto it : llvm::zip(prevIf.getResults(),
2614 prevIf.thenYield().getOperands(), prevElseYielded))
2616 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2620 use.
set(std::get<1>(it));
2625 use.
set(std::get<2>(it));
2631 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2633 IfOp combinedIf = rewriter.
create<IfOp>(
2634 nextIf.getLoc(), mergedTypes, prevIf.getCondition(),
false);
2635 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2638 combinedIf.getThenRegion(),
2639 combinedIf.getThenRegion().begin());
2642 YieldOp thenYield = combinedIf.thenYield();
2643 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2644 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2648 llvm::append_range(mergedYields, thenYield2.getOperands());
2649 rewriter.
create<YieldOp>(thenYield2.getLoc(), mergedYields);
2655 combinedIf.getElseRegion(),
2656 combinedIf.getElseRegion().begin());
2659 if (combinedIf.getElseRegion().empty()) {
2661 combinedIf.getElseRegion(),
2662 combinedIf.getElseRegion().
begin());
2664 YieldOp elseYield = combinedIf.elseYield();
2665 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2666 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2671 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2673 rewriter.
create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2682 if (pair.index() < prevIf.getNumResults())
2683 prevValues.push_back(pair.value());
2685 nextValues.push_back(pair.value());
2697 LogicalResult matchAndRewrite(IfOp ifOp,
2700 if (ifOp.getNumResults())
2702 Block *elseBlock = ifOp.elseBlock();
2703 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2707 newIfOp.getThenRegion().begin());
2732 LogicalResult matchAndRewrite(IfOp op,
2734 auto nestedOps = op.thenBlock()->without_terminator();
2736 if (!llvm::hasSingleElement(nestedOps))
2740 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2743 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2747 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2753 llvm::append_range(elseYield, op.elseYield().getOperands());
2767 if (tup.value().getDefiningOp() == nestedIf) {
2768 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2769 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2770 elseYield[tup.index()]) {
2775 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2788 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2791 elseYieldsToUpgradeToSelect.push_back(tup.index());
2795 Value newCondition = rewriter.
create<arith::AndIOp>(
2796 loc, op.getCondition(), nestedIf.getCondition());
2797 auto newIf = rewriter.
create<IfOp>(loc, op.getResultTypes(), newCondition);
2801 llvm::append_range(results, newIf.getResults());
2804 for (
auto idx : elseYieldsToUpgradeToSelect)
2805 results[idx] = rewriter.
create<arith::SelectOp>(
2806 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2808 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2811 if (!elseYield.empty()) {
2814 rewriter.
create<YieldOp>(loc, elseYield);
2825 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2826 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2827 RemoveStaticCondition, RemoveUnusedResults,
2828 ReplaceIfYieldWithConditionOrValue>(context);
2831 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
2832 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2833 Block *IfOp::elseBlock() {
2834 Region &r = getElseRegion();
2839 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2845 void ParallelOp::build(
2855 ParallelOp::getOperandSegmentSizeAttr(),
2857 static_cast<int32_t>(upperBounds.size()),
2858 static_cast<int32_t>(steps.size()),
2859 static_cast<int32_t>(initVals.size())}));
2863 unsigned numIVs = steps.size();
2869 if (bodyBuilderFn) {
2871 bodyBuilderFn(builder, result.
location,
2876 if (initVals.empty())
2877 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2880 void ParallelOp::build(
2887 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2890 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2894 wrapper = wrappedBuilderFn;
2896 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2905 if (stepValues.empty())
2907 "needs at least one tuple element for lowerBound, upperBound and step");
2910 for (
Value stepValue : stepValues)
2913 return emitOpError(
"constant step operand must be positive");
2917 Block *body = getBody();
2919 return emitOpError() <<
"expects the same number of induction variables: "
2921 <<
" as bound and step values: " << stepValues.size();
2923 if (!arg.getType().isIndex())
2925 "expects arguments for the induction variable to be of index type");
2928 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2929 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2934 auto resultsSize = getResults().size();
2935 auto reductionsSize = reduceOp.getReductions().size();
2936 auto initValsSize = getInitVals().size();
2937 if (resultsSize != reductionsSize)
2938 return emitOpError() <<
"expects number of results: " << resultsSize
2939 <<
" to be the same as number of reductions: "
2941 if (resultsSize != initValsSize)
2942 return emitOpError() <<
"expects number of results: " << resultsSize
2943 <<
" to be the same as number of initial values: "
2947 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2948 auto resultType = getOperation()->getResult(i).getType();
2949 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2950 if (resultType != reductionOperandType)
2951 return reduceOp.emitOpError()
2952 <<
"expects type of " << i
2953 <<
"-th reduction operand: " << reductionOperandType
2954 <<
" to be the same as the " << i
2955 <<
"-th result type: " << resultType;
3003 for (
auto &iv : ivs)
3010 ParallelOp::getOperandSegmentSizeAttr(),
3012 static_cast<int32_t>(upper.size()),
3013 static_cast<int32_t>(steps.size()),
3014 static_cast<int32_t>(initVals.size())}));
3023 ParallelOp::ensureTerminator(*body, builder, result.
location);
3028 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3029 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3030 if (!getInitVals().empty())
3031 p <<
" init (" << getInitVals() <<
")";
3036 (*this)->getAttrs(),
3037 ParallelOp::getOperandSegmentSizeAttr());
3042 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3046 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3050 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3054 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3059 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3061 return ParallelOp();
3062 assert(ivArg.getOwner() &&
"unlinked block argument");
3063 auto *containingOp = ivArg.getOwner()->getParentOp();
3064 return dyn_cast<ParallelOp>(containingOp);
3069 struct ParallelOpSingleOrZeroIterationDimsFolder
3073 LogicalResult matchAndRewrite(ParallelOp op,
3080 for (
auto [lb, ub, step, iv] :
3081 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3082 op.getInductionVars())) {
3084 if (numIterations.has_value()) {
3086 if (*numIterations == 0) {
3087 rewriter.
replaceOp(op, op.getInitVals());
3092 if (*numIterations == 1) {
3097 newLowerBounds.push_back(lb);
3098 newUpperBounds.push_back(ub);
3099 newSteps.push_back(step);
3102 if (newLowerBounds.size() == op.getLowerBound().size())
3105 if (newLowerBounds.empty()) {
3109 results.reserve(op.getInitVals().size());
3110 for (
auto &bodyOp : op.getBody()->without_terminator())
3111 rewriter.
clone(bodyOp, mapping);
3112 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3113 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3114 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3115 auto initValIndex = results.size();
3116 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3120 rewriter.
clone(reduceBodyOp, mapping);
3123 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3124 results.push_back(result);
3132 rewriter.
create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3133 newSteps, op.getInitVals(),
nullptr);
3139 newOp.getRegion().begin(), mapping);
3140 rewriter.
replaceOp(op, newOp.getResults());
3148 LogicalResult matchAndRewrite(ParallelOp op,
3150 Block &outerBody = *op.getBody();
3154 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3159 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3160 llvm::is_contained(innerOp.getUpperBound(), val) ||
3161 llvm::is_contained(innerOp.getStep(), val))
3165 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3170 Block &innerBody = *innerOp.getBody();
3171 assert(iterVals.size() ==
3179 builder.
clone(op, mapping);
3182 auto concatValues = [](
const auto &first,
const auto &second) {
3184 ret.reserve(first.size() + second.size());
3185 ret.assign(first.begin(), first.end());
3186 ret.append(second.begin(), second.end());
3190 auto newLowerBounds =
3191 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3192 auto newUpperBounds =
3193 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3194 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3197 newSteps, std::nullopt,
3208 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3217 void ParallelOp::getSuccessorRegions(
3235 for (
Value v : operands) {
3244 LogicalResult ReduceOp::verifyRegions() {
3247 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3248 auto type = getOperands()[i].getType();
3251 return emitOpError() << i <<
"-th reduction has an empty body";
3254 return arg.getType() != type;
3256 return emitOpError() <<
"expected two block arguments with type " << type
3257 <<
" in the " << i <<
"-th reduction region";
3261 return emitOpError(
"reduction bodies must be terminated with an "
3262 "'scf.reduce.return' op");
3281 Block *reductionBody = getOperation()->getBlock();
3283 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3285 if (expectedResultType != getResult().
getType())
3286 return emitOpError() <<
"must have type " << expectedResultType
3287 <<
" (the type of the reduction inputs)";
3297 ValueRange inits, BodyBuilderFn beforeBuilder,
3298 BodyBuilderFn afterBuilder) {
3306 beforeArgLocs.reserve(inits.size());
3307 for (
Value operand : inits) {
3308 beforeArgLocs.push_back(operand.getLoc());
3313 inits.getTypes(), beforeArgLocs);
3322 resultTypes, afterArgLocs);
3328 ConditionOp WhileOp::getConditionOp() {
3329 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3332 YieldOp WhileOp::getYieldOp() {
3333 return cast<YieldOp>(getAfterBody()->getTerminator());
3336 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3337 return getYieldOp().getResultsMutable();
3341 return getBeforeBody()->getArguments();
3345 return getAfterBody()->getArguments();
3349 return getBeforeArguments();
3353 assert(point == getBefore() &&
3354 "WhileOp is expected to branch only to the first region");
3362 regions.emplace_back(&getBefore(), getBefore().getArguments());
3366 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3367 "there are only two regions in a WhileOp");
3369 if (point == getAfter()) {
3370 regions.emplace_back(&getBefore(), getBefore().getArguments());
3374 regions.emplace_back(getResults());
3375 regions.emplace_back(&getAfter(), getAfter().getArguments());
3379 return {&getBefore(), &getAfter()};
3400 FunctionType functionType;
3405 result.
addTypes(functionType.getResults());
3407 if (functionType.getNumInputs() != operands.size()) {
3409 <<
"expected as many input types as operands "
3410 <<
"(expected " << operands.size() <<
" got "
3411 << functionType.getNumInputs() <<
")";
3421 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3422 regionArgs[i].type = functionType.getInput(i);
3424 return failure(parser.
parseRegion(*before, regionArgs) ||
3444 template <
typename OpTy>
3447 if (left.size() != right.size())
3448 return op.emitOpError(
"expects the same number of ") << message;
3450 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3451 if (left[i] != right[i]) {
3454 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3455 <<
" and " << right[i];
3464 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3466 "expects the 'before' region to terminate with 'scf.condition'");
3467 if (!beforeTerminator)
3470 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3472 "expects the 'after' region to terminate with 'scf.yield'");
3473 return success(afterTerminator !=
nullptr);
3499 LogicalResult matchAndRewrite(WhileOp op,
3501 auto term = op.getConditionOp();
3505 Value constantTrue =
nullptr;
3507 bool replaced =
false;
3508 for (
auto yieldedAndBlockArgs :
3509 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3510 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3511 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3513 constantTrue = rewriter.
create<arith::ConstantOp>(
3514 op.getLoc(), term.getCondition().getType(),
3523 return success(replaced);
3575 struct RemoveLoopInvariantArgsFromBeforeBlock
3579 LogicalResult matchAndRewrite(WhileOp op,
3581 Block &afterBlock = *op.getAfterBody();
3583 ConditionOp condOp = op.getConditionOp();
3588 bool canSimplify =
false;
3589 for (
const auto &it :
3591 auto index =
static_cast<unsigned>(it.index());
3592 auto [initVal, yieldOpArg] = it.value();
3595 if (yieldOpArg == initVal) {
3604 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3605 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3606 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3607 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3620 for (
const auto &it :
3622 auto index =
static_cast<unsigned>(it.index());
3623 auto [initVal, yieldOpArg] = it.value();
3627 if (yieldOpArg == initVal) {
3628 beforeBlockInitValMap.insert({index, initVal});
3636 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3637 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3638 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3639 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3640 beforeBlockInitValMap.insert({index, initVal});
3645 newInitArgs.emplace_back(initVal);
3646 newYieldOpArgs.emplace_back(yieldOpArg);
3647 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3657 rewriter.
create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3660 &newWhile.getBefore(), {},
3663 Block &beforeBlock = *op.getBeforeBody();
3670 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3673 if (beforeBlockInitValMap.count(i) != 0)
3674 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3676 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(
j++);
3679 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3681 newWhile.getAfter().begin());
3683 rewriter.
replaceOp(op, newWhile.getResults());
3728 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3731 LogicalResult matchAndRewrite(WhileOp op,
3733 Block &beforeBlock = *op.getBeforeBody();
3734 ConditionOp condOp = op.getConditionOp();
3737 bool canSimplify =
false;
3738 for (
Value condOpArg : condOpArgs) {
3758 auto index =
static_cast<unsigned>(it.index());
3759 Value condOpArg = it.value();
3764 condOpInitValMap.insert({index, condOpArg});
3766 newCondOpArgs.emplace_back(condOpArg);
3767 newAfterBlockType.emplace_back(condOpArg.
getType());
3768 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3779 auto newWhile = rewriter.
create<WhileOp>(op.getLoc(), newAfterBlockType,
3782 Block &newAfterBlock =
3784 newAfterBlockType, newAfterBlockArgLocs);
3786 Block &afterBlock = *op.getAfterBody();
3793 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3794 Value afterBlockArg, result;
3797 if (condOpInitValMap.count(i) != 0) {
3798 afterBlockArg = condOpInitValMap[i];
3799 result = afterBlockArg;
3801 afterBlockArg = newAfterBlock.getArgument(
j);
3802 result = newWhile.getResult(
j);
3805 newAfterBlockArgs[i] = afterBlockArg;
3806 newWhileResults[i] = result;
3809 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3811 newWhile.getBefore().begin());
3813 rewriter.
replaceOp(op, newWhileResults);
3847 LogicalResult matchAndRewrite(WhileOp op,
3849 auto term = op.getConditionOp();
3850 auto afterArgs = op.getAfterArguments();
3851 auto termArgs = term.getArgs();
3858 bool needUpdate =
false;
3859 for (
const auto &it :
3861 auto i =
static_cast<unsigned>(it.index());
3862 Value result = std::get<0>(it.value());
3863 Value afterArg = std::get<1>(it.value());
3864 Value termArg = std::get<2>(it.value());
3868 newResultsIndices.emplace_back(i);
3869 newTermArgs.emplace_back(termArg);
3870 newResultTypes.emplace_back(result.
getType());
3871 newArgLocs.emplace_back(result.
getLoc());
3886 rewriter.
create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3889 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3896 newResults[it.value()] = newWhile.getResult(it.index());
3897 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3901 newWhile.getBefore().begin());
3903 Block &afterBlock = *op.getAfterBody();
3904 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3936 LogicalResult matchAndRewrite(scf::WhileOp op,
3938 using namespace scf;
3939 auto cond = op.getConditionOp();
3940 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3944 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3945 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3946 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3949 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3950 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3954 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3957 if (cmp2.getPredicate() == cmp.getPredicate())
3958 samePredicate =
true;
3959 else if (cmp2.getPredicate() ==
3961 samePredicate =
false;
3979 LogicalResult matchAndRewrite(WhileOp op,
3982 if (!llvm::any_of(op.getBeforeArguments(),
3983 [](
Value arg) { return arg.use_empty(); }))
3986 YieldOp yield = op.getYieldOp();
3991 llvm::BitVector argsToErase;
3993 size_t argsCount = op.getBeforeArguments().size();
3994 newYields.reserve(argsCount);
3995 newInits.reserve(argsCount);
3996 argsToErase.reserve(argsCount);
3997 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3998 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3999 if (beforeArg.use_empty()) {
4000 argsToErase.push_back(
true);
4002 argsToErase.push_back(
false);
4003 newYields.emplace_back(yieldValue);
4004 newInits.emplace_back(initValue);
4008 Block &beforeBlock = *op.getBeforeBody();
4009 Block &afterBlock = *op.getAfterBody();
4015 rewriter.
create<WhileOp>(loc, op.getResultTypes(), newInits,
4017 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4018 Block &newAfterBlock = *newWhileOp.getAfterBody();
4024 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4025 newBeforeBlock.getArguments());
4029 rewriter.
replaceOp(op, newWhileOp.getResults());
4038 LogicalResult matchAndRewrite(WhileOp op,
4040 ConditionOp condOp = op.getConditionOp();
4045 if (argsSet.size() == condOpArgs.size())
4048 llvm::SmallDenseMap<Value, unsigned> argsMap;
4050 argsMap.reserve(condOpArgs.size());
4051 newArgs.reserve(condOpArgs.size());
4052 for (
Value arg : condOpArgs) {
4053 if (!argsMap.count(arg)) {
4054 auto pos =
static_cast<unsigned>(argsMap.size());
4055 argsMap.insert({arg, pos});
4056 newArgs.emplace_back(arg);
4063 auto newWhileOp = rewriter.
create<scf::WhileOp>(
4064 loc, argsRange.getTypes(), op.getInits(),
nullptr,
4066 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4067 Block &newAfterBlock = *newWhileOp.getAfterBody();
4072 auto it = argsMap.find(arg);
4073 assert(it != argsMap.end());
4074 auto pos = it->second;
4075 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4076 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4084 Block &beforeBlock = *op.getBeforeBody();
4085 Block &afterBlock = *op.getAfterBody();
4087 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4088 newBeforeBlock.getArguments());
4089 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4097 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4099 if (args1.size() != args2.size())
4100 return std::nullopt;
4104 auto it = llvm::find(args2, arg1);
4105 if (it == args2.end())
4106 return std::nullopt;
4108 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4115 llvm::SmallDenseSet<Value> set;
4116 for (
Value arg : args) {
4117 if (!set.insert(arg).second)
4130 LogicalResult matchAndRewrite(WhileOp loop,
4132 auto oldBefore = loop.getBeforeBody();
4133 ConditionOp oldTerm = loop.getConditionOp();
4134 ValueRange beforeArgs = oldBefore->getArguments();
4136 if (beforeArgs == termArgs)
4139 if (hasDuplicates(termArgs))
4142 auto mapping = getArgsMapping(beforeArgs, termArgs);
4153 auto oldAfter = loop.getAfterBody();
4157 newResultTypes[
j] = loop.getResult(i).getType();
4159 auto newLoop = rewriter.
create<WhileOp>(
4160 loop.getLoc(), newResultTypes, loop.getInits(),
4162 auto newBefore = newLoop.getBeforeBody();
4163 auto newAfter = newLoop.getAfterBody();
4168 newResults[i] = newLoop.getResult(
j);
4169 newAfterArgs[i] = newAfter->getArgument(
j);
4173 newBefore->getArguments());
4185 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4186 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4187 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4188 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4202 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4205 caseValues.push_back(value);
4214 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4216 p <<
"case " << value <<
' ';
4222 if (getCases().size() != getCaseRegions().size()) {
4223 return emitOpError(
"has ")
4224 << getCaseRegions().size() <<
" case regions but "
4225 << getCases().size() <<
" case values";
4229 for (int64_t value : getCases())
4230 if (!valueSet.insert(value).second)
4231 return emitOpError(
"has duplicate case value: ") << value;
4233 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4235 return emitOpError(
"expected region to end with scf.yield, but got ")
4238 if (yield.getNumOperands() != getNumResults()) {
4239 return (emitOpError(
"expected each region to return ")
4240 << getNumResults() <<
" values, but " << name <<
" returns "
4241 << yield.getNumOperands())
4242 .attachNote(yield.getLoc())
4243 <<
"see yield operation here";
4245 for (
auto [idx, result, operand] :
4246 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4247 yield.getOperandTypes())) {
4248 if (result == operand)
4250 return (emitOpError(
"expected result #")
4251 << idx <<
" of each region to be " << result)
4252 .attachNote(yield.getLoc())
4253 << name <<
" returns " << operand <<
" here";
4258 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4261 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
4267 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4269 Block &scf::IndexSwitchOp::getDefaultBlock() {
4270 return getDefaultRegion().
front();
4273 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4274 assert(idx < getNumCases() &&
"case index out-of-bounds");
4275 return getCaseRegions()[idx].front();
4278 void IndexSwitchOp::getSuccessorRegions(
4282 successors.emplace_back(getResults());
4286 llvm::append_range(successors, getRegions());
4289 void IndexSwitchOp::getEntrySuccessorRegions(
4292 FoldAdaptor adaptor(operands, *
this);
4295 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4297 llvm::append_range(successors, getRegions());
4303 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4304 if (caseValue == arg.getInt()) {
4305 successors.emplace_back(&caseRegion);
4309 successors.emplace_back(&getDefaultRegion());
4312 void IndexSwitchOp::getRegionInvocationBounds(
4314 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4315 if (!operandValue) {
4321 unsigned liveIndex = getNumRegions() - 1;
4322 const auto *it = llvm::find(getCases(), operandValue.getInt());
4323 if (it != getCases().end())
4324 liveIndex = std::distance(getCases().begin(), it);
4325 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4326 bounds.emplace_back(0, i == liveIndex);
4337 if (!maybeCst.has_value())
4339 int64_t cst = *maybeCst;
4340 int64_t caseIdx, e = op.getNumCases();
4341 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4342 if (cst == op.getCases()[caseIdx])
4346 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4347 : op.getDefaultRegion();
4348 Block &source = r.front();
4371 #define GET_OP_CLASSES
4372 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of a AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.