#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"

#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"

    auto retValOp = dyn_cast<scf::YieldOp>(op);
    for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
      std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));

void SCFDialect::initialize() {
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
  addInterfaces<SCFInlinerInterface>();
  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
                            InParallelOp, ReduceReturnOp>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
                            ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
                            ForallOp, InParallelOp, WhileOp, YieldOp>();
  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();

  builder.create<scf::YieldOp>(loc);

template <typename TerminatorTy>
    StringRef errorMessage) {
    terminatorOperation = &region.front().back();
    if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
  if (terminatorOperation)
    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";

  assert(llvm::hasSingleElement(region) && "expected single-region block");

  if (getRegion().empty())
    return emitOpError("region needs to have at least one block");
  if (getRegion().front().getNumArguments() > 0)
    return emitOpError("region cannot have any arguments");

    if (!llvm::hasSingleElement(op.getRegion()))
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
    Block *prevBlock = op->getBlock();
    rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
        rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
                                      yieldOp.getResults());
    for (auto res : op.getResults())
      blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

void ExecuteRegionOp::getSuccessorRegions(

  assert((point.isParent() || point == getParentOp().getAfter()) &&
         "condition op can only exit the loop or branch to the after"
  return getArgsMutable();

void ConditionOp::getSuccessorRegions(
  FoldAdaptor adaptor(operands, *this);
  WhileOp whileOp = getParentOp();
  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
  if (!boolAttr || boolAttr.getValue())
    regions.emplace_back(&whileOp.getAfter(),
                         whileOp.getAfter().getArguments());
  if (!boolAttr || !boolAttr.getValue())
    regions.emplace_back(whileOp.getResults());

                  BodyBuilderFn bodyBuilder) {
  for (Value v : initArgs)
  for (Value v : initArgs)
  if (initArgs.empty() && !bodyBuilder) {
    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {

  if (getInitArgs().size() != getNumResults())
        "mismatch in number of loop-carried values and defined values");

LogicalResult ForOp::verifyRegions() {
        "expected induction variable to be same type as bounds and step");
  if (getNumRegionIterArgs() != getNumResults())
        "mismatch in number of basic block args and defined values");
  auto initArgs = getInitArgs();
  auto iterArgs = getRegionIterArgs();
  auto opResults = getResults();
  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";

std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

  std::optional<int64_t> tripCount =
  if (!tripCount.has_value() || tripCount != 1)
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
  llvm::append_range(bbArgReplacements, getInitArgs());
                               getOperation()->getIterator(), bbArgReplacements);
                                 StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);

  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
  if (!getInitArgs().empty())
    p << " -> (" << getInitArgs().getTypes() << ')';
    p << " : " << t << ' ';
                !getInitArgs().empty());

  regionArgs.push_back(inductionVariable);
  if (regionArgs.size() != result.types.size() + 1)
        "mismatch in number of loop-carried values and defined values");
  regionArgs.front().type = type;
  for (auto [iterArg, type] :
       llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
  ForOp::ensureTerminator(*body, builder, result.location);
  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
                                             operands, result.types)) {
    Type type = std::get<2>(argOperandType);
    std::get<0>(argOperandType).type = type;

  return getBody()->getArguments().drop_front(getNumInductionVars());
  return getInitArgsMutable();

FailureOr<LoopLikeOpInterface>
ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
                                   bool replaceInitOperandUsesInLoop,
  auto inits = llvm::to_vector(getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
      newYieldValuesFn(rewriter, getLoc(), newIterArgs);
  assert(newInitOperands.size() == newYieldedValues.size() &&
         "expected as many new yield values as new iter operands");
  yieldOp.getResultsMutable().append(newYieldedValues);
      newLoop.getBody()->getArguments().take_front(
          getBody()->getNumArguments()));
  if (replaceInitOperandUsesInLoop) {
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());

  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  assert(ivArg.getOwner() && "unlinked block argument");
  auto *containingOp = ivArg.getOwner()->getParentOp();
  return dyn_cast_or_null<ForOp>(containingOp);

  return getInitArgs();
  for (auto [lb, ub, step] :
       llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
    if (!tripCount.has_value() || *tripCount != 1)

  return getBody()->getArguments().drop_front(getRank());
  return getOutputsMutable();

  scf::InParallelOp terminator = forallOp.getTerminator();
  bbArgReplacements.append(forallOp.getOutputs().begin(),
                           forallOp.getOutputs().end());
                               forallOp->getIterator(), bbArgReplacements);
  results.reserve(forallOp.getResults().size());
  for (auto &yieldingOp : terminator.getYieldingOps()) {
    auto parallelInsertSliceOp =
        cast<tensor::ParallelInsertSliceOp>(yieldingOp);
    Value dst = parallelInsertSliceOp.getDest();
    Value src = parallelInsertSliceOp.getSource();
    if (llvm::isa<TensorType>(src.getType())) {
      results.push_back(rewriter.create<tensor::InsertSliceOp>(
          forallOp.getLoc(), dst.getType(), src, dst,
          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
          parallelInsertSliceOp.getStrides(),
          parallelInsertSliceOp.getStaticOffsets(),
          parallelInsertSliceOp.getStaticSizes(),
          parallelInsertSliceOp.getStaticStrides()));
      llvm_unreachable("unsupported terminator");

  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
    return LoopNest{{}, std::move(results)};

  loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
          currentIterArgs = args;
          currentLoc = nestedLoc;
    loops.push_back(loop);
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
          ? bodyBuilder(builder, currentLoc, ivs,
                        loops.back().getRegionIterArgs())
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
  builder.create<scf::YieldOp>(loc, results);
  llvm::append_range(nestResults, loops.front().getResults());
  return LoopNest{std::move(loops), std::move(nestResults)};
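
// Illustrative usage sketch (not part of the original file): building a loop
// nest with one loop-carried value via scf::buildLoopNest, whose implementation
// appears above. The helper name `buildRunningSum` is hypothetical, and the
// LoopNest member name `results` is assumed from the SCF header.
static Value buildRunningSum(OpBuilder &builder, Location loc, ValueRange lbs,
                             ValueRange ubs, ValueRange steps, Value init) {
  scf::LoopNest nest = scf::buildLoopNest(
      builder, loc, lbs, ubs, steps, ValueRange{init},
      [](OpBuilder &b, Location nestedLoc, ValueRange ivs,
         ValueRange iterArgs) -> scf::ValueVector {
        // Add the innermost induction variable to the carried value and yield
        // the result from the innermost loop.
        Value next =
            b.create<arith::AddIOp>(nestedLoc, iterArgs.front(), ivs.back());
        return {next};
      });
  return nest.results.front();
}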
        bodyBuilder(nestedBuilder, nestedLoc, ivs);

  assert(operand.getOwner() == forOp);
         "expected an iter OpOperand");
         "Expected a different type");
  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
      newIterOperands.push_back(replacement);
    newIterOperands.push_back(opOperand.get());

  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newIterOperands);
  newForOp->setAttrs(forOp->getAttrs());
  Block &newBlock = newForOp.getRegion().front();
  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
                         clonedYieldOp.getOperand(yieldIdx));
  newYieldOperands[yieldIdx] = castOut;
  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
  rewriter.eraseOp(clonedYieldOp);
  newResults[yieldIdx] =
      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

  LogicalResult matchAndRewrite(scf::ForOp forOp,
    bool canonicalize = false;
    int64_t numResults = forOp.getNumResults();
    keepMask.reserve(numResults);
    newBlockTransferArgs.reserve(1 + numResults);
    newBlockTransferArgs.push_back(Value());
    newIterArgs.reserve(forOp.getInitArgs().size());
    newYieldValues.reserve(numResults);
    newResultValues.reserve(numResults);
    for (auto [init, arg, result, yielded] :
         llvm::zip(forOp.getInitArgs(),
                   forOp.getRegionIterArgs(),
                   forOp.getYieldedValues()
      bool forwarded = (arg == yielded) || (init == yielded) ||
                       (arg.use_empty() && result.use_empty());
        keepMask.push_back(false);
        newBlockTransferArgs.push_back(init);
        newResultValues.push_back(init);
      if (auto it = initYieldToArg.find({init, yielded});
          it != initYieldToArg.end()) {
        keepMask.push_back(false);
        auto [sameArg, sameResult] = it->second;
        newBlockTransferArgs.push_back(init);
        newResultValues.push_back(init);
      initYieldToArg.insert({{init, yielded}, {arg, result}});
      keepMask.push_back(true);
      newIterArgs.push_back(init);
      newYieldValues.push_back(yielded);
      newBlockTransferArgs.push_back(Value());
      newResultValues.push_back(Value());

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(collapsedIdx++);
           "unexpected argument size mismatch");
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      rewriter.replaceOp(forOp, newResultValues);
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
    rewriter.replaceOp(forOp, newResultValues);
static std::optional<int64_t> computeConstDiff(Value l, Value u) {
  IntegerAttr clb, cub;
    llvm::APInt lbValue = clb.getValue();
    llvm::APInt ubValue = cub.getValue();
    return (ubValue - lbValue).getSExtValue();
    return diff.getSExtValue();
  return std::nullopt;

  LogicalResult matchAndRewrite(ForOp op,
    if (op.getLowerBound() == op.getUpperBound()) {
      rewriter.replaceOp(op, op.getInitArgs());
    std::optional<int64_t> diff =
        computeConstDiff(op.getLowerBound(), op.getUpperBound());
      rewriter.replaceOp(op, op.getInitArgs());
    std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
    if (!maybeStepValue)
    llvm::APInt stepValue = *maybeStepValue;
    if (stepValue.sge(*diff)) {
      blockArgs.reserve(op.getInitArgs().size() + 1);
      blockArgs.push_back(op.getLowerBound());
      llvm::append_range(blockArgs, op.getInitArgs());
    if (!llvm::hasSingleElement(block))
    if (llvm::any_of(op.getYieldedValues(),
                     [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
    rewriter.replaceOp(op, op.getYieldedValues());

  LogicalResult matchAndRewrite(ForOp op,
    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
      OpOperand &iterOpOperand = std::get<0>(it);
      if (!incomingCast ||
          incomingCast.getSource().getType() == incomingCast.getType())
              incomingCast.getDest().getType(),
              incomingCast.getSource().getType()))
      if (!std::get<1>(it).hasOneUse())
          rewriter, op, iterOpOperand, incomingCast.getSource(),
            return b.create<tensor::CastOp>(loc, type, source);

  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
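
// Illustrative example (not part of the original file) of the simplest rewrite
// SimplifyTrivialLoops performs above: when the lower and upper bound are the
// same value, the loop runs zero times and its results are just the init args.
//
//   %r = scf.for %iv = %lb to %lb step %c1 iter_args(%a = %init) -> (f32) {
//     ...
//     scf.yield %next : f32
//   }
//
// folds to replacing all uses of %r with %init.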
std::optional<APInt> ForOp::getConstantStep() {
    return step.getValue();

std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();

  if (auto constantStep = getConstantStep())
    if (*constantStep == 1)

  unsigned numLoops = getRank();
  if (getNumResults() != getOutputs().size())
    return emitOpError("produces ")
           << getNumResults() << " results, but has only "
           << getOutputs().size() << " outputs";
  auto *body = getBody();
    return emitOpError("region expects ") << numLoops << " arguments";
  for (int64_t i = 0; i < numLoops; ++i)
      return emitOpError("expects ")
             << i << "-th block argument to be an index";
  for (unsigned i = 0; i < getOutputs().size(); ++i)
      return emitOpError("type mismatch between ")
             << i << "-th output and corresponding block argument";
  if (getMapping().has_value() && !getMapping()->empty()) {
    if (getDeviceMappingAttrs().size() != numLoops)
      return emitOpError() << "mapping attribute size must match op rank";
    if (failed(getDeviceMaskingAttr()))
             << " supports at most one device masking attribute";
                                          getStaticLowerBound(),
                                          getDynamicLowerBound())))
                                          getStaticUpperBound(),
                                          getDynamicUpperBound())))
                                          getStaticStep(), getDynamicStep())))

  p << " (" << getInductionVars();
  if (isNormalized()) {
  if (!getRegionOutArgs().empty())
    p << "-> (" << getResultTypes() << ") ";
  p.printRegion(getRegion(),
                getNumResults() > 0);
  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
                                           getStaticLowerBoundAttrName(),
                                           getStaticUpperBoundAttrName(),
                                           getStaticStepAttrName()});

  auto indexType = b.getIndexType();
  unsigned numLoops = ivs.size();
  if (outOperands.size() != result.types.size())
        "mismatch between out operands and types");
  std::unique_ptr<Region> region = std::make_unique<Region>();
  for (auto &iv : ivs) {
    iv.type = b.getIndexType();
    regionArgs.push_back(iv);
    auto &out = it.value();
    out.type = result.types[it.index()];
    regionArgs.push_back(out);
  ForallOp::ensureTerminator(*region, b, result.location);
                       {static_cast<int32_t>(dynamicLbs.size()),
                        static_cast<int32_t>(dynamicUbs.size()),
                        static_cast<int32_t>(dynamicSteps.size()),
                        static_cast<int32_t>(outOperands.size())}));

void ForallOp::build(
    std::optional<ArrayAttr> mapping,
      "operandSegmentSizes",
                       static_cast<int32_t>(dynamicUbs.size()),
                       static_cast<int32_t>(dynamicSteps.size()),
                       static_cast<int32_t>(outputs.size())}));
  if (mapping.has_value()) {
  if (!bodyBuilderFn) {
    ForallOp::ensureTerminator(*bodyRegion, b, result.location);

void ForallOp::build(
    std::optional<ArrayAttr> mapping,
  unsigned numLoops = ubs.size();
  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);

bool ForallOp::isNormalized() {
    return intValue.has_value() && intValue == val;
  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
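
// Illustrative note (not part of the original file): a "normalized" scf.forall
// has all lower bounds equal to 0 and all steps equal to 1, which is what lets
// the custom printer above use the compact form
//
//   scf.forall (%i, %j) in (%dim0, %dim1) shared_outs(%out = %t) -> (tensor<?x?xf32>)
//
// instead of the explicit `= (...) to (...) step (...)` spelling.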
InParallelOp ForallOp::getTerminator() {
  return cast<InParallelOp>(getBody()->getTerminator());

  InParallelOp inParallelOp = getTerminator();
  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
    if (auto parallelInsertSliceOp =
            dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
        parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
      storeOps.push_back(parallelInsertSliceOp);

  for (auto attr : getMapping()->getValue()) {
    auto m = dyn_cast<DeviceMappingAttrInterface>(attr);

FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
  DeviceMaskingAttrInterface res;
  for (auto attr : getMapping()->getValue()) {
    auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);

bool ForallOp::usesLinearMapping() {
  return ifaces.front().isLinearMapping();

std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {

std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);

std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);

std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {

  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
  assert(tidxArg.getOwner() && "unlinked block argument");
  auto *containingOp = tidxArg.getOwner()->getParentOp();
  return dyn_cast<ForallOp>(containingOp);

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
    auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
        forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
        dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });

  LogicalResult matchAndRewrite(ForallOp op,
    op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
    op.setStaticLowerBound(staticLowerBound);
    op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
    op.setStaticUpperBound(staticUpperBound);
    op.getDynamicStepMutable().assign(dynamicStep);
    op.setStaticStep(staticStep);
    op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
                    {static_cast<int32_t>(dynamicLowerBound.size()),
                     static_cast<int32_t>(dynamicUpperBound.size()),
                     static_cast<int32_t>(dynamicStep.size()),
                     static_cast<int32_t>(op.getNumResults())}));

  LogicalResult matchAndRewrite(ForallOp forallOp,
    for (OpResult result : forallOp.getResults()) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
        resultToDelete.insert(result);
        resultToReplace.push_back(result);
        newOuts.push_back(opOperand->get());
    if (resultToDelete.empty())
    for (OpResult result : resultToDelete) {
      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
          forallOp.getCombiningOps(blockArg);
      for (Operation *combiningOp : combiningOps)
        rewriter.eraseOp(combiningOp);
    auto newForallOp = rewriter.create<scf::ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
        forallOp.getMapping(),
    Block *loopBody = forallOp.getBody();
    Block *newLoopBody = newForallOp.getBody();
        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
    for (OpResult result : forallOp.getResults()) {
      if (resultToDelete.count(result)) {
        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
        newBlockArgs.push_back(newSharedOutsArgs[index++]);
    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
    for (auto &&[oldResult, newResult] :
         llvm::zip(resultToReplace, newForallOp->getResults()))
    for (OpResult oldResult : resultToDelete)
          forallOp.getTiedOpOperand(oldResult)->get());
struct ForallOpSingleOrZeroIterationDimsFolder
  LogicalResult matchAndRewrite(ForallOp op,
    if (op.getMapping().has_value() && !op.getMapping()->empty())
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      if (numIterations.has_value()) {
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getOutputs());
        if (*numIterations == 1) {
      newMixedLowerBounds.push_back(lb);
      newMixedUpperBounds.push_back(ub);
      newMixedSteps.push_back(step);
    if (newMixedLowerBounds.empty()) {
    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
          op, "no dimensions have 0 or 1 iterations");
    newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
                                      newMixedUpperBounds, newMixedSteps,
                                      op.getOutputs(), std::nullopt, nullptr);
    newOp.getBodyRegion().getBlocks().clear();
        newOp.getStaticLowerBoundAttrName(),
        newOp.getStaticUpperBoundAttrName(),
        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
      newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
                                newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());

struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
  LogicalResult matchAndRewrite(ForallOp op,
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      if (!numIterations.has_value() || numIterations.value() != 1) {

struct FoldTensorCastOfOutputIntoForallOp
  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
              castOp.getSource().getType())) {
      tensorCastProducers[en.index()] =
          TypeCast{castOp.getSource().getType(), castOp.getType()};
      newOutputTensors[en.index()] = castOp.getSource();
    if (tensorCastProducers.empty())
    auto newForallOp = rewriter.create<ForallOp>(
        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
        forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
          auto castBlockArgs =
              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
          for (auto [index, cast] : tensorCastProducers) {
            Value &oldTypeBBArg = castBlockArgs[index];
            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
                nestedLoc, cast.dstType, oldTypeBBArg);
              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
          ivsBlockArgs.append(castBlockArgs);
                             bbArgs.front().getParentBlock(), ivsBlockArgs);
    auto terminator = newForallOp.getTerminator();
    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
      insertSliceOp.getDestMutable().assign(outputBlockArg);
    for (auto &item : tensorCastProducers) {
      Value &oldTypeResult = castResults[item.first];
      oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
    rewriter.replaceOp(forallOp, castResults);

  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
              ForallOpSingleOrZeroIterationDimsFolder,
              ForallOpReplaceConstantInductionVar>(context);
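
// Illustrative example (not part of the original file) of the
// ForallOpIterArgsFolder pattern registered above: a shared output whose
// result is unused, or which has no combining tensor.parallel_insert_slice in
// the terminator, is dropped from the loop, and any remaining uses of the
// dropped result are replaced by the corresponding `out` operand.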
  scf::ForallOp forallOp =
      dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
    return this->emitOpError("expected forall op parent");
  for (Operation &op : getRegion().front().getOperations()) {
    if (!isa<tensor::ParallelInsertSliceOp>(op)) {
      return this->emitOpError("expected only ")
             << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
    Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
    if (!llvm::is_contained(regionOutArgs, dest))
      return op.emitOpError("may only insert into an output block argument");

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (region->empty())

OpResult InParallelOp::getParentResult(int64_t idx) {
  return getOperation()->getParentOp()->getResult(idx);

  return llvm::to_vector<4>(
      llvm::map_range(getYieldingOps(), [](Operation &op) {
        auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
        return llvm::cast<BlockArgument>(insertSliceOp.getDest());

  return getRegion().front().getOperations();

  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");
    if (ifOp->isProperAncestor(b))
      return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
             static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
    ifOp = ifOp->getParentOfType<IfOp>();

IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                       IfOp::Adaptor adaptor,
  if (adaptor.getRegions().empty())
  Region *r = &adaptor.getThenRegion();
    auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
    TypeRange types = yieldOp.getOperandTypes();
    llvm::append_range(inferredReturnTypes, types);

  return build(builder, result, resultTypes, cond, false,
                  bool addElseBlock) {
  assert((!addElseBlock || addThenBlock) &&
         "must not create else block w/o then block");
                 bool withElseRegion) {
  build(builder, result, TypeRange{}, cond, withElseRegion);
  if (resultTypes.empty())
    IfOp::ensureTerminator(*thenRegion, builder, result.location);
  if (withElseRegion) {
    if (resultTypes.empty())
      IfOp::ensureTerminator(*elseRegion, builder, result.location);
  assert(thenBuilder && "the builder callback for 'then' must be present");
  thenBuilder(builder, result.location);
    elseBuilder(builder, result.location);
  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
                                 inferredReturnTypes))) {
    result.addTypes(inferredReturnTypes);

  if (getNumResults() != 0 && getElseRegion().empty())
    return emitOpError("must have an else block if defining values");
  bool printBlockTerminators = false;
  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
                printBlockTerminators);
  auto &elseRegion = getElseRegion();
  if (!elseRegion.empty()) {
                  printBlockTerminators);

  Region *elseRegion = &this->getElseRegion();
  if (elseRegion->empty())

  FoldAdaptor adaptor(operands, *this);
  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
  if (!boolAttr || boolAttr.getValue())
    regions.emplace_back(&getThenRegion());
  if (!boolAttr || !boolAttr.getValue()) {
    if (!getElseRegion().empty())
      regions.emplace_back(&getElseRegion());
      regions.emplace_back(getResults());

LogicalResult IfOp::fold(FoldAdaptor adaptor,
  if (getElseRegion().empty())
  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  getConditionMutable().assign(xorStmt.getLhs());
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);

void IfOp::getRegionInvocationBounds(
  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
    invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
    invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
    invocationBounds.assign(2, {0, 1});

  llvm::transform(usedResults, std::back_inserter(usedOperands),
                             [&]() { yieldOp->setOperands(usedOperands); });

  LogicalResult matchAndRewrite(IfOp op,
    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
                  [](OpResult result) { return !result.use_empty(); });
    if (usedResults.size() == op.getNumResults())
    llvm::transform(usedResults, std::back_inserter(newTypes),
        rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
  LogicalResult matchAndRewrite(IfOp op,
    else if (!op.getElseRegion().empty())

  LogicalResult matchAndRewrite(IfOp op,
    if (op->getNumResults() == 0)
    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();
    for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    if (nonHoistable.size() == op->getNumResults())
    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());
    for (const auto &it :
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        results[it.index()] = trueVal;
        results[it.index()] = rewriter.create<arith::SelectOp>(
            op.getLoc(), cond, trueVal, falseVal);

  LogicalResult matchAndRewrite(IfOp op,
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;
         llvm::make_early_inc_range(op.getCondition().getUses())) {
          constantTrue = rewriter.create<arith::ConstantOp>(
            [&]() { use.set(constantTrue); });
      } else if (op.getElseRegion().isAncestor(
          constantFalse = rewriter.create<arith::ConstantOp>(
            [&]() { use.set(constantFalse); });

struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
  LogicalResult matchAndRewrite(IfOp op,
    if (op.getNumResults() == 0)
        cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
        cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
        op.getOperation()->getIterator());
    for (auto [trueResult, falseResult, opResult] :
         llvm::zip(trueYield.getResults(), falseYield.getResults(),
      if (trueResult == falseResult) {
        if (!opResult.use_empty()) {
          opResult.replaceAllUsesWith(trueResult);
      bool trueVal = trueYield.getValue();
      bool falseVal = falseYield.getValue();
      if (!trueVal && falseVal) {
        if (!opResult.use_empty()) {
          Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
              op.getLoc(), op.getCondition(),
      if (trueVal && !falseVal) {
        if (!opResult.use_empty()) {
          opResult.replaceAllUsesWith(op.getCondition());
  LogicalResult matchAndRewrite(IfOp nextIf,
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
    if (!nextThen && !nextElse)
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
          use.set(std::get<1>(it));
          use.set(std::get<2>(it));
    llvm::append_range(mergedTypes, nextIf.getResultTypes());
    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());
      if (combinedIf.getElseRegion().empty()) {
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());
        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
        nextValues.push_back(pair.value());
  LogicalResult matchAndRewrite(IfOp ifOp,
    if (ifOp.getNumResults())
    Block *elseBlock = ifOp.elseBlock();
    if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
                                newIfOp.getThenRegion().begin());

  LogicalResult matchAndRewrite(IfOp op,
    auto nestedOps = op.thenBlock()->without_terminator();
    if (!llvm::hasSingleElement(nestedOps))
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
    llvm::append_range(elseYield, op.elseYield().getOperands());
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        elseYieldsToUpgradeToSelect.push_back(tup.index());
    Value newCondition = rewriter.create<arith::AndIOp>(
        loc, op.getCondition(), nestedIf.getCondition());
    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
    llvm::append_range(results, newIf.getResults());
    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] = rewriter.create<arith::SelectOp>(
          op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    if (!elseYield.empty()) {
      rewriter.create<YieldOp>(loc, elseYield);

  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
              RemoveStaticCondition, RemoveUnusedResults,
              ReplaceIfYieldWithConditionOrValue>(context);
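
// Illustrative example (not part of the original file) of the CombineNestedIfs
// rewrite registered above: an scf.if whose then-block contains only another
// scf.if is collapsed by and-ing the two conditions.
//
//   scf.if %a {                  %c = arith.andi %a, %b : i1
//     scf.if %b {         =>     scf.if %c {
//       "some.op"() : () -> ()     "some.op"() : () -> ()
//     }                          }
//   }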
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
Block *IfOp::elseBlock() {
  Region &r = getElseRegion();
YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
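
// Illustrative usage sketch (not part of the original file): creating an
// scf.if through the callback-based build overload shown earlier; the result
// type is inferred from the yielded values. The function and variable names
// are hypothetical.
static Value buildSelectLikeIf(OpBuilder &builder, Location loc, Value cond,
                               Value trueValue, Value falseValue) {
  auto ifOp = builder.create<scf::IfOp>(
      loc, cond,
      /*thenBuilder=*/
      [&](OpBuilder &b, Location l) {
        b.create<scf::YieldOp>(l, ValueRange{trueValue});
      },
      /*elseBuilder=*/
      [&](OpBuilder &b, Location l) {
        b.create<scf::YieldOp>(l, ValueRange{falseValue});
      });
  return ifOp.getResult(0);
}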
void ParallelOp::build(
                      ParallelOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(upperBounds.size()),
                       static_cast<int32_t>(steps.size()),
                       static_cast<int32_t>(initVals.size())}));
  unsigned numIVs = steps.size();
  if (bodyBuilderFn) {
    bodyBuilderFn(builder, result.location,
  if (initVals.empty())
    ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);

void ParallelOp::build(
  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
    wrapper = wrappedBuilderFn;
  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),

  if (stepValues.empty())
        "needs at least one tuple element for lowerBound, upperBound and step");
  for (Value stepValue : stepValues)
      return emitOpError("constant step operand must be positive");
  Block *body = getBody();
    return emitOpError() << "expects the same number of induction variables: "
                         << " as bound and step values: " << stepValues.size();
    if (!arg.getType().isIndex())
          "expects arguments for the induction variable to be of index type");
  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
      *this, getRegion(), "expects body to terminate with 'scf.reduce'");
  auto resultsSize = getResults().size();
  auto reductionsSize = reduceOp.getReductions().size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
    auto resultType = getOperation()->getResult(i).getType();
    auto reductionOperandType = reduceOp.getOperands()[i].getType();
    if (resultType != reductionOperandType)
      return reduceOp.emitOpError()
             << "expects type of " << i
             << "-th reduction operand: " << reductionOperandType
             << " to be the same as the " << i
             << "-th result type: " << resultType;
  for (auto &iv : ivs)
                      ParallelOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(upper.size()),
                       static_cast<int32_t>(steps.size()),
                       static_cast<int32_t>(initVals.size())}));
  ParallelOp::ensureTerminator(*body, builder, result.location);

  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")";
      (*this)->getAttrs(),
      ParallelOp::getOperandSegmentSizeAttr());

std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {

  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
    return ParallelOp();
  assert(ivArg.getOwner() && "unlinked block argument");
  auto *containingOp = ivArg.getOwner()->getParentOp();
  return dyn_cast<ParallelOp>(containingOp);
struct ParallelOpSingleOrZeroIterationDimsFolder
  LogicalResult matchAndRewrite(ParallelOp op,
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                   op.getInductionVars())) {
      if (numIterations.has_value()) {
        if (*numIterations == 0) {
          rewriter.replaceOp(op, op.getInitVals());
        if (*numIterations == 1) {
      newLowerBounds.push_back(lb);
      newUpperBounds.push_back(ub);
      newSteps.push_back(step);
    if (newLowerBounds.size() == op.getLowerBound().size())
    if (newLowerBounds.empty()) {
      results.reserve(op.getInitVals().size());
      for (auto &bodyOp : op.getBody()->without_terminator())
        rewriter.clone(bodyOp, mapping);
      auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
      for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
        Block &reduceBlock = reduceOp.getReductions()[i].front();
        auto initValIndex = results.size();
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
          rewriter.clone(reduceBodyOp, mapping);
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
        results.push_back(result);
        rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
                                    newSteps, op.getInitVals(), nullptr);
                                newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());

  LogicalResult matchAndRewrite(ParallelOp op,
    Block &outerBody = *op.getBody();
    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
      if (llvm::is_contained(innerOp.getLowerBound(), val) ||
          llvm::is_contained(innerOp.getUpperBound(), val) ||
          llvm::is_contained(innerOp.getStep(), val))
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
    Block &innerBody = *innerOp.getBody();
    assert(iterVals.size() ==
      builder.clone(op, mapping);
    auto concatValues = [](const auto &first, const auto &second) {
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

      .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
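
// Illustrative example (not part of the original file) of the
// MergeNestedParallelLoops rewrite registered above: a perfectly nested
// scf.parallel whose inner bounds do not depend on the outer induction
// variables is flattened into a single multi-dimensional scf.parallel.
//
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) { ... }
//   }
// =>
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) { ... }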
void ParallelOp::getSuccessorRegions(
  for (Value v : operands) {

LogicalResult ReduceOp::verifyRegions() {
  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
    auto type = getOperands()[i].getType();
      return emitOpError() << i << "-th reduction has an empty body";
          return arg.getType() != type;
      return emitOpError() << "expected two block arguments with type " << type
                           << " in the " << i << "-th reduction region";
      return emitOpError("reduction bodies must be terminated with an "
                         "'scf.reduce.return' op");

  Block *reductionBody = getOperation()->getBlock();
  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
  if (expectedResultType != getResult().getType())
    return emitOpError() << "must have type " << expectedResultType
                         << " (the type of the reduction inputs)";

                     ValueRange inits, BodyBuilderFn beforeBuilder,
                     BodyBuilderFn afterBuilder) {
  beforeArgLocs.reserve(inits.size());
  for (Value operand : inits) {
    beforeArgLocs.push_back(operand.getLoc());
                                          inits.getTypes(), beforeArgLocs);
                                         resultTypes, afterArgLocs);

ConditionOp WhileOp::getConditionOp() {
  return cast<ConditionOp>(getBeforeBody()->getTerminator());

YieldOp WhileOp::getYieldOp() {
  return cast<YieldOp>(getAfterBody()->getTerminator());

std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
  return getYieldOp().getResultsMutable();

  return getBeforeBody()->getArguments();
  return getAfterBody()->getArguments();
  return getBeforeArguments();

  assert(point == getBefore() &&
         "WhileOp is expected to branch only to the first region");
  regions.emplace_back(&getBefore(), getBefore().getArguments());

  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
         "there are only two regions in a WhileOp");
  if (point == getAfter()) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
  regions.emplace_back(getResults());
  regions.emplace_back(&getAfter(), getAfter().getArguments());

  return {&getBefore(), &getAfter()};
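
// Illustrative IR sketch (not part of the original file) of the region
// structure the accessors above rely on: the "before" region ends in
// scf.condition and forwards its arguments to the "after" region, which ends
// in scf.yield feeding the next iteration.
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = ...
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = ...
//     scf.yield %next : i32
//   }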
  FunctionType functionType;
  result.addTypes(functionType.getResults());
  if (functionType.getNumInputs() != operands.size()) {
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);
  return failure(parser.parseRegion(*before, regionArgs) ||

template <typename OpTy>
  if (left.size() != right.size())
    return op.emitOpError("expects the same number of ") << message;
  for (unsigned i = 0, e = left.size(); i < e; ++i) {
    if (left[i] != right[i]) {
      diag.attachNote() << "for argument " << i << ", found " << left[i]
                        << " and " << right[i];

  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
      "expects the 'before' region to terminate with 'scf.condition'");
  if (!beforeTerminator)
  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
      "expects the 'after' region to terminate with 'scf.yield'");
  return success(afterTerminator != nullptr);

  LogicalResult matchAndRewrite(WhileOp op,
    auto term = op.getConditionOp();
    Value constantTrue = nullptr;
    bool replaced = false;
    for (auto yieldedAndBlockArgs :
         llvm::zip(term.getArgs(), op.getAfterArguments())) {
      if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
        if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
            constantTrue = rewriter.create<arith::ConstantOp>(
                op.getLoc(), term.getCondition().getType(),
    return success(replaced);
struct RemoveLoopInvariantArgsFromBeforeBlock
  LogicalResult matchAndRewrite(WhileOp op,
    Block &afterBlock = *op.getAfterBody();
    ConditionOp condOp = op.getConditionOp();
    bool canSimplify = false;
    for (const auto &it :
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();
      if (yieldOpArg == initVal) {
      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
    for (const auto &it :
      auto index = static_cast<unsigned>(it.index());
      auto [initVal, yieldOpArg] = it.value();
      if (yieldOpArg == initVal) {
        beforeBlockInitValMap.insert({index, initVal});
      auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
          beforeBlockInitValMap.insert({index, initVal});
      newInitArgs.emplace_back(initVal);
      newYieldOpArgs.emplace_back(yieldOpArg);
      newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
        &newWhile.getBefore(), {},
    Block &beforeBlock = *op.getBeforeBody();
    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
      if (beforeBlockInitValMap.count(i) != 0)
        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
                                newWhile.getAfter().begin());
    rewriter.replaceOp(op, newWhile.getResults());
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
  LogicalResult matchAndRewrite(WhileOp op,
    Block &beforeBlock = *op.getBeforeBody();
    ConditionOp condOp = op.getConditionOp();
    bool canSimplify = false;
    for (Value condOpArg : condOpArgs) {
      auto index = static_cast<unsigned>(it.index());
      Value condOpArg = it.value();
        condOpInitValMap.insert({index, condOpArg});
        newCondOpArgs.emplace_back(condOpArg);
        newAfterBlockType.emplace_back(condOpArg.getType());
        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
    Block &newAfterBlock =
                           newAfterBlockType, newAfterBlockArgLocs);
    Block &afterBlock = *op.getAfterBody();
    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
      Value afterBlockArg, result;
      if (condOpInitValMap.count(i) != 0) {
        afterBlockArg = condOpInitValMap[i];
        result = afterBlockArg;
        afterBlockArg = newAfterBlock.getArgument(j);
        result = newWhile.getResult(j);
      newAfterBlockArgs[i] = afterBlockArg;
      newWhileResults[i] = result;
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
                                newWhile.getBefore().begin());
    rewriter.replaceOp(op, newWhileResults);
  LogicalResult matchAndRewrite(WhileOp op,
    auto term = op.getConditionOp();
    auto afterArgs = op.getAfterArguments();
    auto termArgs = term.getArgs();
    bool needUpdate = false;
    for (const auto &it :
      auto i = static_cast<unsigned>(it.index());
      Value result = std::get<0>(it.value());
      Value afterArg = std::get<1>(it.value());
      Value termArg = std::get<2>(it.value());
        newResultsIndices.emplace_back(i);
        newTermArgs.emplace_back(termArg);
        newResultTypes.emplace_back(result.getType());
        newArgLocs.emplace_back(result.getLoc());
        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
        &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
      newResults[it.value()] = newWhile.getResult(it.index());
      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
                                newWhile.getBefore().begin());
    Block &afterBlock = *op.getAfterBody();
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);

  LogicalResult matchAndRewrite(scf::WhileOp op,
    using namespace scf;
    auto cond = op.getConditionOp();
    auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
    for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
      for (size_t opIdx = 0; opIdx < 2; opIdx++) {
        if (std::get<0>(tup) != cmp.getOperand(opIdx))
             llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
          auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
          if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
          if (cmp2.getPredicate() == cmp.getPredicate())
            samePredicate = true;
          else if (cmp2.getPredicate() ==
            samePredicate = false;
  LogicalResult matchAndRewrite(WhileOp op,
    if (!llvm::any_of(op.getBeforeArguments(),
                      [](Value arg) { return arg.use_empty(); }))
    YieldOp yield = op.getYieldOp();
    llvm::BitVector argsToErase;
    size_t argsCount = op.getBeforeArguments().size();
    newYields.reserve(argsCount);
    newInits.reserve(argsCount);
    argsToErase.reserve(argsCount);
    for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
             op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
      if (beforeArg.use_empty()) {
        argsToErase.push_back(true);
        argsToErase.push_back(false);
        newYields.emplace_back(yieldValue);
        newInits.emplace_back(initValue);
    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();
        rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
    rewriter.replaceOp(op, newWhileOp.getResults());

  LogicalResult matchAndRewrite(WhileOp op,
    ConditionOp condOp = op.getConditionOp();
    if (argsSet.size() == condOpArgs.size())
    llvm::SmallDenseMap<Value, unsigned> argsMap;
    argsMap.reserve(condOpArgs.size());
    newArgs.reserve(condOpArgs.size());
    for (Value arg : condOpArgs) {
      if (!argsMap.count(arg)) {
        auto pos = static_cast<unsigned>(argsMap.size());
        argsMap.insert({arg, pos});
        newArgs.emplace_back(arg);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        loc, argsRange.getTypes(), op.getInits(), nullptr,
    Block &newBeforeBlock = *newWhileOp.getBeforeBody();
    Block &newAfterBlock = *newWhileOp.getAfterBody();
      auto it = argsMap.find(arg);
      assert(it != argsMap.end());
      auto pos = it->second;
      afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
      resultsMapping.emplace_back(newWhileOp->getResult(pos));
    Block &beforeBlock = *op.getBeforeBody();
    Block &afterBlock = *op.getAfterBody();
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
                         newBeforeBlock.getArguments());
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
  if (args1.size() != args2.size())
    return std::nullopt;
    auto it = llvm::find(args2, arg1);
    if (it == args2.end())
      return std::nullopt;
    ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);

  llvm::SmallDenseSet<Value> set;
  for (Value arg : args) {
    if (!set.insert(arg).second)

  LogicalResult matchAndRewrite(WhileOp loop,
    auto oldBefore = loop.getBeforeBody();
    ConditionOp oldTerm = loop.getConditionOp();
    ValueRange beforeArgs = oldBefore->getArguments();
    if (beforeArgs == termArgs)
    if (hasDuplicates(termArgs))
    auto mapping = getArgsMapping(beforeArgs, termArgs);
    auto oldAfter = loop.getAfterBody();
      newResultTypes[j] = loop.getResult(i).getType();
    auto newLoop = rewriter.create<WhileOp>(
        loop.getLoc(), newResultTypes, loop.getInits(),
    auto newBefore = newLoop.getBeforeBody();
    auto newAfter = newLoop.getAfterBody();
      newResults[i] = newLoop.getResult(j);
      newAfterArgs[i] = newAfter->getArgument(j);
                         newBefore->getArguments());

  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
              RemoveLoopInvariantValueYielded, WhileConditionTruth,
              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
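
// Illustrative example (not part of the original file) of the
// WhileConditionTruth rewrite registered above: when the condition value is
// also forwarded through scf.condition, the matching "after" block argument is
// known to be true on entry, so its uses are replaced by an arith.constant.
//
//   scf.while (...) : ... {
//     %cond = ...
//     scf.condition(%cond) %cond, ... : i1, ...
//   } do {
//   ^bb0(%known_true: i1, ...):
//     // uses of %known_true become `arith.constant true`
//   }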
4232 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4235 caseValues.push_back(value);
4244 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4246 p <<
"case " << value <<
' ';
LogicalResult scf::IndexSwitchOp::verify() {
  if (getCases().size() != getCaseRegions().size()) {
    return emitOpError("has ")
           << getCaseRegions().size() << " case regions but "
           << getCases().size() << " case values";
  }

  // ...
  for (int64_t value : getCases())
    if (!valueSet.insert(value).second)
      return emitOpError("has duplicate case value: ") << value;

  // Each region (default and cases) is checked by a local helper that receives
  // the region and a display name:
  // ...
    auto yield = dyn_cast<YieldOp>(region.front().back());
    if (!yield)
      return emitOpError("expected region to end with scf.yield, but got ")
             // ...

    if (yield.getNumOperands() != getNumResults()) {
      return (emitOpError("expected each region to return ")
              << getNumResults() << " values, but " << name << " returns "
              << yield.getNumOperands())
                 .attachNote(yield.getLoc())
             << "see yield operation here";
    }

    for (auto [idx, result, operand] :
         llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
                   yield.getOperandTypes())) {
      if (result == operand)
        continue;
      return (emitOpError("expected result #")
              << idx << " of each region to be " << result)
                 .attachNote(yield.getLoc())
             << name << " returns " << operand << " here";
    }
  // ...

  if (failed(verifyRegion(getDefaultRegion(), "default region")))
    return failure();
  // ...
    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
      return failure();
  // ...
}
unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }

Block &scf::IndexSwitchOp::getDefaultBlock() {
  return getDefaultRegion().front();
}

Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
  assert(idx < getNumCases() && "case index out-of-bounds");
  return getCaseRegions()[idx].front();
}
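A short usage sketch for the accessors defined above, assuming a ModuleOp named `module` and llvm/Support/raw_ostream.h; the walk and the diagnostic output are illustrative and not part of the dialect:

// Illustrative: inspect every scf.index_switch in a module via the accessors
// defined above.
static void dumpIndexSwitches(ModuleOp module) {
  module.walk([](scf::IndexSwitchOp switchOp) {
    llvm::errs() << "index_switch with " << switchOp.getNumCases() << " cases\n";
    for (unsigned i = 0, e = switchOp.getNumCases(); i < e; ++i)
      llvm::errs() << "  case " << switchOp.getCases()[i] << " starts with "
                   << switchOp.getCaseBlock(i).front().getName().getStringRef()
                   << "\n";
    llvm::errs() << "  default block has "
                 << switchOp.getDefaultBlock().getOperations().size()
                 << " ops\n";
  });
}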
void IndexSwitchOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
  // All regions branch back to the parent op.
  if (!point.isParent()) {
    successors.emplace_back(getResults());
    return;
  }

  llvm::append_range(successors, getRegions());
}

void IndexSwitchOp::getEntrySuccessorRegions(
    ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &successors) {
  FoldAdaptor adaptor(operands, *this);

  // If a constant was not provided, all regions are possible successors.
  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
  if (!arg) {
    llvm::append_range(successors, getRegions());
    return;
  }

  // Otherwise, the region of the matching case (or the default region when no
  // case matches) is the only successor.
  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
    if (caseValue == arg.getInt()) {
      successors.emplace_back(&caseRegion);
      return;
    }
  }
  successors.emplace_back(&getDefaultRegion());
}
void IndexSwitchOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
  if (!operandValue) {
    // ...
    return;
  }

  unsigned liveIndex = getNumRegions() - 1;
  const auto *it = llvm::find(getCases(), operandValue.getInt());
  if (it != getCases().end())
    liveIndex = std::distance(getCases().begin(), it);
  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
    bounds.emplace_back(0, i == liveIndex);
}
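With a constant switch argument, the bounds computed above mark exactly one region as executing once and every other region as never executing. A hedged sketch of querying this through the RegionBranchOpInterface hook, assuming an scf::IndexSwitchOp `switchOp` and a Builder `b`; the helper is illustrative, not from the source:

// Illustrative: report how often each region of `switchOp` can execute when
// its argument is assumed to fold to the constant 2.
static void queryBoundsForConstantTwo(scf::IndexSwitchOp switchOp, Builder &b) {
  SmallVector<Attribute> constOperands = {b.getIndexAttr(2)};
  SmallVector<InvocationBounds> bounds;
  switchOp.getRegionInvocationBounds(constOperands, bounds);
  // One entry per region: upper bound 1 for the selected case (or the default
  // region when no case value matches), upper bound 0 for all other regions.
  (void)bounds;
}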
    if (!maybeCst.has_value())
      return failure();
    int64_t cst = *maybeCst;
    int64_t caseIdx, e = op.getNumCases();
    for (caseIdx = 0; caseIdx < e; ++caseIdx) {
      if (cst == op.getCases()[caseIdx])
        break;
    }

    // A constant switch argument selects a single case region (or the default
    // region when no case value matches).
    Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
                                             : op.getDefaultRegion();
    // ...
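The fold above only applies when the switch operand is a compile-time constant. A one-line sketch of that guard, with an illustrative helper name, reusing getConstantIntValue, which this file already references:

// Illustrative: the precondition the fold above depends on - a statically
// known switch value.
static std::optional<int64_t> getConstantSwitchValue(scf::IndexSwitchOp op) {
  return getConstantIntValue(op.getArg());
}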
#define GET_OP_CLASSES
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"