26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
55 auto retValOp = dyn_cast<scf::YieldOp>(op);
59 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
70 void SCFDialect::initialize() {
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 addInterfaces<SCFInlinerInterface>();
76 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
77 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78 InParallelOp, ReduceReturnOp>();
79 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81 ForallOp, InParallelOp, WhileOp, YieldOp>();
82 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
87 scf::YieldOp::create(builder, loc);
92 template <
typename TerminatorTy>
94 StringRef errorMessage) {
97 terminatorOperation = ®ion.
front().
back();
98 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
102 if (terminatorOperation)
103 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
115 assert(region.
hasOneBlock() &&
"expected single-block region");
161 if (getRegion().empty())
162 return emitOpError(
"region needs to have at least one block");
163 if (getRegion().front().getNumArguments() > 0)
164 return emitOpError(
"region cannot have any arguments");
187 if (!op.getRegion().hasOneBlock())
236 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
239 Block *prevBlock = op->getBlock();
243 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
245 for (
Block &blk : op.getRegion()) {
246 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
248 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
249 yieldOp.getResults());
257 for (
auto res : op.getResults())
258 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
270 void ExecuteRegionOp::getSuccessorRegions(
288 assert((point.
isParent() || point == getParentOp().getAfter()) &&
289 "condition op can only exit the loop or branch to the after"
292 return getArgsMutable();
295 void ConditionOp::getSuccessorRegions(
297 FoldAdaptor adaptor(operands, *
this);
299 WhileOp whileOp = getParentOp();
303 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
304 if (!boolAttr || boolAttr.getValue())
305 regions.emplace_back(&whileOp.getAfter(),
306 whileOp.getAfter().getArguments());
307 if (!boolAttr || !boolAttr.getValue())
308 regions.emplace_back(whileOp.getResults());
317 BodyBuilderFn bodyBuilder) {
322 for (
Value v : initArgs)
328 for (
Value v : initArgs)
334 if (initArgs.empty() && !bodyBuilder) {
335 ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
336 }
else if (bodyBuilder) {
346 if (getInitArgs().size() != getNumResults())
348 "mismatch in number of loop-carried values and defined values");
353 LogicalResult ForOp::verifyRegions() {
358 "expected induction variable to be same type as bounds and step");
360 if (getNumRegionIterArgs() != getNumResults())
362 "mismatch in number of basic block args and defined values");
364 auto initArgs = getInitArgs();
365 auto iterArgs = getRegionIterArgs();
366 auto opResults = getResults();
368 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
370 return emitOpError() <<
"types mismatch between " << i
371 <<
"th iter operand and defined value";
373 return emitOpError() <<
"types mismatch between " << i
374 <<
"th iter region arg and defined value";
381 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
385 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
389 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
393 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
/// Returns the results of this ForOp as the loop's results. The returned
/// optional is always engaged: it simply wraps the range produced by
/// `getResults()`.
397 std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
402 std::optional<int64_t> tripCount =
404 if (!tripCount.has_value() || tripCount != 1)
408 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
415 llvm::append_range(bbArgReplacements, getInitArgs());
419 getOperation()->getIterator(), bbArgReplacements);
435 StringRef prefix =
"") {
436 assert(blocksArgs.size() == initializers.size() &&
437 "expected same length of arguments and initializers");
438 if (initializers.empty())
442 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
443 p << std::get<0>(it) <<
" = " << std::get<1>(it);
449 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
453 if (!getInitArgs().empty())
454 p <<
" -> (" << getInitArgs().getTypes() <<
')';
457 p <<
" : " << t <<
' ';
460 !getInitArgs().empty());
482 regionArgs.push_back(inductionVariable);
492 if (regionArgs.size() != result.
types.size() + 1)
495 "mismatch in number of loop-carried values and defined values");
504 regionArgs.front().type = type;
505 for (
auto [iterArg, type] :
506 llvm::zip_equal(llvm::drop_begin(regionArgs), result.
types))
513 ForOp::ensureTerminator(*body, builder, result.
location);
522 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
523 operands, result.
types)) {
524 Type type = std::get<2>(argOperandType);
525 std::get<0>(argOperandType).type = type;
542 return getBody()->getArguments().drop_front(getNumInductionVars());
546 return getInitArgsMutable();
549 FailureOr<LoopLikeOpInterface>
550 ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
552 bool replaceInitOperandUsesInLoop,
557 auto inits = llvm::to_vector(getInitArgs());
558 inits.append(newInitOperands.begin(), newInitOperands.end());
559 scf::ForOp newLoop = scf::ForOp::create(
565 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
567 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
572 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
573 assert(newInitOperands.size() == newYieldedValues.size() &&
574 "expected as many new yield values as new iter operands");
576 yieldOp.getResultsMutable().append(newYieldedValues);
582 newLoop.getBody()->getArguments().take_front(
583 getBody()->getNumArguments()));
585 if (replaceInitOperandUsesInLoop) {
588 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
599 newLoop->getResults().take_front(getNumResults()));
600 return cast<LoopLikeOpInterface>(newLoop.getOperation());
604 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
607 assert(ivArg.getOwner() &&
"unlinked block argument");
608 auto *containingOp = ivArg.getOwner()->getParentOp();
609 return dyn_cast_or_null<ForOp>(containingOp);
613 return getInitArgs();
630 for (
auto [lb, ub, step] :
631 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
633 if (!tripCount.has_value() || *tripCount != 1)
642 return getBody()->getArguments().drop_front(getRank());
646 return getOutputsMutable();
652 scf::InParallelOp terminator = forallOp.getTerminator();
657 bbArgReplacements.append(forallOp.getOutputs().begin(),
658 forallOp.getOutputs().end());
662 forallOp->getIterator(), bbArgReplacements);
667 results.reserve(forallOp.getResults().size());
668 for (
auto &yieldingOp : terminator.getYieldingOps()) {
669 auto parallelInsertSliceOp =
670 cast<tensor::ParallelInsertSliceOp>(yieldingOp);
672 Value dst = parallelInsertSliceOp.getDest();
673 Value src = parallelInsertSliceOp.getSource();
674 if (llvm::isa<TensorType>(src.
getType())) {
675 results.push_back(tensor::InsertSliceOp::create(
676 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
677 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
678 parallelInsertSliceOp.getStrides(),
679 parallelInsertSliceOp.getStaticOffsets(),
680 parallelInsertSliceOp.getStaticSizes(),
681 parallelInsertSliceOp.getStaticStrides()));
683 llvm_unreachable(
"unsupported terminator");
698 assert(lbs.size() == ubs.size() &&
699 "expected the same number of lower and upper bounds");
700 assert(lbs.size() == steps.size() &&
701 "expected the same number of lower bounds and steps");
706 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
708 assert(results.size() == iterArgs.size() &&
709 "loop nest body must return as many values as loop has iteration "
711 return LoopNest{{}, std::move(results)};
719 loops.reserve(lbs.size());
720 ivs.reserve(lbs.size());
723 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
724 auto loop = scf::ForOp::create(
725 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
731 currentIterArgs = args;
732 currentLoc = nestedLoc;
738 loops.push_back(loop);
742 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
744 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
751 ? bodyBuilder(builder, currentLoc, ivs,
752 loops.back().getRegionIterArgs())
754 assert(results.size() == iterArgs.size() &&
755 "loop nest body must return as many values as loop has iteration "
758 scf::YieldOp::create(builder, loc, results);
762 llvm::append_range(nestResults, loops.front().getResults());
763 return LoopNest{std::move(loops), std::move(nestResults)};
776 bodyBuilder(nestedBuilder, nestedLoc, ivs);
785 assert(operand.
getOwner() == forOp);
790 "expected an iter OpOperand");
792 "Expected a different type");
794 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
796 newIterOperands.push_back(replacement);
799 newIterOperands.push_back(opOperand.get());
803 scf::ForOp newForOp = scf::ForOp::create(
804 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
805 forOp.getStep(), newIterOperands);
806 newForOp->setAttrs(forOp->getAttrs());
807 Block &newBlock = newForOp.getRegion().
front();
815 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
817 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
818 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
822 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
825 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
828 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
829 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
830 clonedYieldOp.getOperand(yieldIdx));
832 newYieldOperands[yieldIdx] = castOut;
833 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
834 rewriter.
eraseOp(clonedYieldOp);
839 newResults[yieldIdx] =
840 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
860 LogicalResult matchAndRewrite(scf::ForOp forOp,
862 bool canonicalize =
false;
869 int64_t numResults = forOp.getNumResults();
871 keepMask.reserve(numResults);
874 newBlockTransferArgs.reserve(1 + numResults);
875 newBlockTransferArgs.push_back(
Value());
876 newIterArgs.reserve(forOp.getInitArgs().size());
877 newYieldValues.reserve(numResults);
878 newResultValues.reserve(numResults);
880 for (
auto [init, arg, result, yielded] :
881 llvm::zip(forOp.getInitArgs(),
882 forOp.getRegionIterArgs(),
884 forOp.getYieldedValues()
891 bool forwarded = (arg == yielded) || (init == yielded) ||
892 (arg.use_empty() && result.use_empty());
895 keepMask.push_back(
false);
896 newBlockTransferArgs.push_back(init);
897 newResultValues.push_back(init);
903 if (
auto it = initYieldToArg.find({init, yielded});
904 it != initYieldToArg.end()) {
906 keepMask.push_back(
false);
907 auto [sameArg, sameResult] = it->second;
911 newBlockTransferArgs.push_back(init);
912 newResultValues.push_back(init);
917 initYieldToArg.insert({{init, yielded}, {arg, result}});
918 keepMask.push_back(
true);
919 newIterArgs.push_back(init);
920 newYieldValues.push_back(yielded);
921 newBlockTransferArgs.push_back(
Value());
922 newResultValues.push_back(
Value());
928 scf::ForOp newForOp =
929 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
930 forOp.getUpperBound(), forOp.getStep(), newIterArgs);
931 newForOp->setAttrs(forOp->getAttrs());
932 Block &newBlock = newForOp.getRegion().
front();
936 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
938 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
939 Value &newResultVal = newResultValues[idx];
940 assert((blockTransferArg && newResultVal) ||
941 (!blockTransferArg && !newResultVal));
942 if (!blockTransferArg) {
943 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
944 newResultVal = newForOp.getResult(collapsedIdx++);
950 "unexpected argument size mismatch");
955 if (newIterArgs.empty()) {
956 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
959 rewriter.
replaceOp(forOp, newResultValues);
964 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
968 filteredOperands.reserve(newResultValues.size());
969 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
971 filteredOperands.push_back(mergedTerminator.getOperand(idx));
972 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
976 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
977 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
978 cloneFilteredTerminator(mergedYieldOp);
979 rewriter.
eraseOp(mergedYieldOp);
980 rewriter.
replaceOp(forOp, newResultValues);
988 static std::optional<int64_t> computeConstDiff(
Value l,
Value u) {
989 IntegerAttr clb, cub;
991 llvm::APInt lbValue = clb.getValue();
992 llvm::APInt ubValue = cub.getValue();
993 return (ubValue - lbValue).getSExtValue();
1002 return diff.getSExtValue();
1003 return std::nullopt;
1012 LogicalResult matchAndRewrite(ForOp op,
1016 if (op.getLowerBound() == op.getUpperBound()) {
1017 rewriter.
replaceOp(op, op.getInitArgs());
1021 std::optional<int64_t> diff =
1022 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1028 rewriter.
replaceOp(op, op.getInitArgs());
1032 std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1033 if (!maybeStepValue)
1038 llvm::APInt stepValue = *maybeStepValue;
1039 if (stepValue.sge(*diff)) {
1041 blockArgs.reserve(op.getInitArgs().size() + 1);
1042 blockArgs.push_back(op.getLowerBound());
1043 llvm::append_range(blockArgs, op.getInitArgs());
1050 if (!llvm::hasSingleElement(block))
1054 if (llvm::any_of(op.getYieldedValues(),
1055 [&](
Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1057 rewriter.
replaceOp(op, op.getYieldedValues());
1091 LogicalResult matchAndRewrite(ForOp op,
1093 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1094 OpOperand &iterOpOperand = std::get<0>(it);
1096 if (!incomingCast ||
1097 incomingCast.getSource().getType() == incomingCast.getType())
1102 incomingCast.getDest().getType(),
1103 incomingCast.getSource().getType()))
1105 if (!std::get<1>(it).hasOneUse())
1111 rewriter, op, iterOpOperand, incomingCast.getSource(),
1113 return tensor::CastOp::create(b, loc, type, source);
1125 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1129 std::optional<APInt> ForOp::getConstantStep() {
1132 return step.getValue();
1136 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1137 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1143 if (
auto constantStep = getConstantStep())
1144 if (*constantStep == 1)
1157 unsigned numLoops = getRank();
1159 if (getNumResults() != getOutputs().size())
1160 return emitOpError(
"produces ")
1161 << getNumResults() <<
" results, but has only "
1162 << getOutputs().size() <<
" outputs";
1165 auto *body = getBody();
1167 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1168 for (int64_t i = 0; i < numLoops; ++i)
1170 return emitOpError(
"expects ")
1171 << i <<
"-th block argument to be an index";
1172 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1174 return emitOpError(
"type mismatch between ")
1175 << i <<
"-th output and corresponding block argument";
1176 if (getMapping().has_value() && !getMapping()->empty()) {
1177 if (getDeviceMappingAttrs().size() != numLoops)
1178 return emitOpError() <<
"mapping attribute size must match op rank";
1179 if (failed(getDeviceMaskingAttr()))
1181 <<
" supports at most one device masking attribute";
1187 getStaticLowerBound(),
1188 getDynamicLowerBound())))
1191 getStaticUpperBound(),
1192 getDynamicUpperBound())))
1195 getStaticStep(), getDynamicStep())))
1203 p <<
" (" << getInductionVars();
1204 if (isNormalized()) {
1225 if (!getRegionOutArgs().empty())
1226 p <<
"-> (" << getResultTypes() <<
") ";
1227 p.printRegion(getRegion(),
1229 getNumResults() > 0);
1230 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1231 getStaticLowerBoundAttrName(),
1232 getStaticUpperBoundAttrName(),
1233 getStaticStepAttrName()});
1238 auto indexType = b.getIndexType();
1258 unsigned numLoops = ivs.size();
1293 if (outOperands.size() != result.
types.size())
1295 "mismatch between out operands and types");
1305 std::unique_ptr<Region> region = std::make_unique<Region>();
1306 for (
auto &iv : ivs) {
1307 iv.type = b.getIndexType();
1308 regionArgs.push_back(iv);
1311 auto &out = it.value();
1312 out.type = result.
types[it.index()];
1313 regionArgs.push_back(out);
1319 ForallOp::ensureTerminator(*region, b, result.
location);
1331 {static_cast<int32_t>(dynamicLbs.size()),
1332 static_cast<int32_t>(dynamicUbs.size()),
1333 static_cast<int32_t>(dynamicSteps.size()),
1334 static_cast<int32_t>(outOperands.size())}));
1339 void ForallOp::build(
1343 std::optional<ArrayAttr> mapping,
1364 "operandSegmentSizes",
1366 static_cast<int32_t>(dynamicUbs.size()),
1367 static_cast<int32_t>(dynamicSteps.size()),
1368 static_cast<int32_t>(outputs.size())}));
1369 if (mapping.has_value()) {
1388 if (!bodyBuilderFn) {
1389 ForallOp::ensureTerminator(*bodyRegion, b, result.
location);
1396 void ForallOp::build(
1399 std::optional<ArrayAttr> mapping,
1401 unsigned numLoops = ubs.size();
1404 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1408 bool ForallOp::isNormalized() {
1412 return intValue.has_value() && intValue == val;
1415 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1418 InParallelOp ForallOp::getTerminator() {
1419 return cast<InParallelOp>(getBody()->getTerminator());
1424 InParallelOp inParallelOp = getTerminator();
1425 for (
Operation &yieldOp : inParallelOp.getYieldingOps()) {
1426 if (
auto parallelInsertSliceOp =
1427 dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1428 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1429 storeOps.push_back(parallelInsertSliceOp);
1439 for (
auto attr : getMapping()->getValue()) {
1440 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1447 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1448 DeviceMaskingAttrInterface res;
1451 for (
auto attr : getMapping()->getValue()) {
1452 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1461 bool ForallOp::usesLinearMapping() {
1465 return ifaces.front().isLinearMapping();
1468 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1473 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1475 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1479 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1481 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1485 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1491 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1494 assert(tidxArg.getOwner() &&
"unlinked block argument");
1495 auto *containingOp = tidxArg.getOwner()->getParentOp();
1496 return dyn_cast<ForallOp>(containingOp);
1504 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1506 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1510 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1513 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1522 LogicalResult matchAndRewrite(ForallOp op,
1537 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1538 op.setStaticLowerBound(staticLowerBound);
1542 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1543 op.setStaticUpperBound(staticUpperBound);
1546 op.getDynamicStepMutable().assign(dynamicStep);
1547 op.setStaticStep(staticStep);
1549 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1551 {static_cast<int32_t>(dynamicLowerBound.size()),
1552 static_cast<int32_t>(dynamicUpperBound.size()),
1553 static_cast<int32_t>(dynamicStep.size()),
1554 static_cast<int32_t>(op.getNumResults())}));
1636 LogicalResult matchAndRewrite(ForallOp forallOp,
1655 for (
OpResult result : forallOp.getResults()) {
1656 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1657 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1658 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1659 resultToDelete.insert(result);
1661 resultToReplace.push_back(result);
1662 newOuts.push_back(opOperand->
get());
1668 if (resultToDelete.empty())
1676 for (
OpResult result : resultToDelete) {
1677 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1678 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1680 forallOp.getCombiningOps(blockArg);
1681 for (
Operation *combiningOp : combiningOps)
1682 rewriter.
eraseOp(combiningOp);
1687 auto newForallOp = scf::ForallOp::create(
1688 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1689 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1690 forallOp.getMapping(),
1695 Block *loopBody = forallOp.getBody();
1696 Block *newLoopBody = newForallOp.getBody();
1701 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1708 for (
OpResult result : forallOp.getResults()) {
1709 if (resultToDelete.count(result)) {
1710 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1712 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1715 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1719 for (
auto &&[oldResult, newResult] :
1720 llvm::zip(resultToReplace, newForallOp->getResults()))
1726 for (
OpResult oldResult : resultToDelete)
1728 forallOp.getTiedOpOperand(oldResult)->get());
1733 struct ForallOpSingleOrZeroIterationDimsFolder
1737 LogicalResult matchAndRewrite(ForallOp op,
1740 if (op.getMapping().has_value() && !op.getMapping()->empty())
1748 for (
auto [lb, ub, step, iv] :
1749 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1750 op.getMixedStep(), op.getInductionVars())) {
1752 if (numIterations.has_value()) {
1754 if (*numIterations == 0) {
1755 rewriter.
replaceOp(op, op.getOutputs());
1760 if (*numIterations == 1) {
1765 newMixedLowerBounds.push_back(lb);
1766 newMixedUpperBounds.push_back(ub);
1767 newMixedSteps.push_back(step);
1771 if (newMixedLowerBounds.empty()) {
1777 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1779 op,
"no dimensions have 0 or 1 iterations");
1784 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1785 newMixedUpperBounds, newMixedSteps,
1786 op.getOutputs(), std::nullopt,
nullptr);
1787 newOp.getBodyRegion().getBlocks().clear();
1792 newOp.getStaticLowerBoundAttrName(),
1793 newOp.getStaticUpperBoundAttrName(),
1794 newOp.getStaticStepAttrName()};
1795 for (
const auto &namedAttr : op->getAttrs()) {
1796 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1799 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1803 newOp.getRegion().begin(), mapping);
1804 rewriter.
replaceOp(op, newOp.getResults());
1810 struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1813 LogicalResult matchAndRewrite(ForallOp op,
1817 for (
auto [lb, ub, step, iv] :
1818 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1819 op.getMixedStep(), op.getInductionVars())) {
1823 if (!numIterations.has_value() || numIterations.value() != 1) {
1834 struct FoldTensorCastOfOutputIntoForallOp
1843 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1845 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1848 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1855 castOp.getSource().getType())) {
1859 tensorCastProducers[en.index()] =
1860 TypeCast{castOp.getSource().getType(), castOp.getType()};
1861 newOutputTensors[en.index()] = castOp.getSource();
1864 if (tensorCastProducers.empty())
1869 auto newForallOp = ForallOp::create(
1870 rewriter, loc, forallOp.getMixedLowerBound(),
1871 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1872 newOutputTensors, forallOp.getMapping(),
1874 auto castBlockArgs =
1875 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1876 for (auto [index, cast] : tensorCastProducers) {
1877 Value &oldTypeBBArg = castBlockArgs[index];
1878 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1879 cast.dstType, oldTypeBBArg);
1884 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1885 ivsBlockArgs.append(castBlockArgs);
1887 bbArgs.front().getParentBlock(), ivsBlockArgs);
1893 auto terminator = newForallOp.getTerminator();
1894 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1895 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1896 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1897 insertSliceOp.getDestMutable().assign(outputBlockArg);
1903 for (
auto &item : tensorCastProducers) {
1904 Value &oldTypeResult = castResults[item.first];
1905 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1908 rewriter.
replaceOp(forallOp, castResults);
1917 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1918 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1919 ForallOpSingleOrZeroIterationDimsFolder,
1920 ForallOpReplaceConstantInductionVar>(context);
1951 scf::ForallOp forallOp =
1952 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1954 return this->emitOpError(
"expected forall op parent");
1957 for (
Operation &op : getRegion().front().getOperations()) {
1958 if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1959 return this->emitOpError(
"expected only ")
1960 << tensor::ParallelInsertSliceOp::getOperationName() <<
" ops";
1964 Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1966 if (!llvm::is_contained(regionOutArgs, dest))
1967 return op.emitOpError(
"may only insert into an output block argument");
1984 std::unique_ptr<Region> region = std::make_unique<Region>();
1988 if (region->empty())
1998 OpResult InParallelOp::getParentResult(int64_t idx) {
1999 return getOperation()->getParentOp()->getResult(idx);
2003 return llvm::to_vector<4>(
2004 llvm::map_range(getYieldingOps(), [](
Operation &op) {
2006 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
2007 return llvm::cast<BlockArgument>(insertSliceOp.getDest());
2012 return getRegion().front().getOperations();
2020 assert(a &&
"expected non-empty operation");
2021 assert(b &&
"expected non-empty operation");
2026 if (ifOp->isProperAncestor(b))
2029 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2030 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2032 ifOp = ifOp->getParentOfType<IfOp>();
2040 IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2041 IfOp::Adaptor adaptor,
2043 if (adaptor.getRegions().empty())
2045 Region *r = &adaptor.getThenRegion();
2051 auto yieldOp = llvm::dyn_cast<YieldOp>(b.
back());
2054 TypeRange types = yieldOp.getOperandTypes();
2055 llvm::append_range(inferredReturnTypes, types);
2061 return build(builder, result, resultTypes, cond,
false,
2067 bool addElseBlock) {
2068 assert((!addElseBlock || addThenBlock) &&
2069 "must not create else block w/o then block");
2084 bool withElseRegion) {
2085 build(builder, result,
TypeRange{}, cond, withElseRegion);
2097 if (resultTypes.empty())
2098 IfOp::ensureTerminator(*thenRegion, builder, result.
location);
2102 if (withElseRegion) {
2104 if (resultTypes.empty())
2105 IfOp::ensureTerminator(*elseRegion, builder, result.
location);
2112 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2119 thenBuilder(builder, result.
location);
2125 elseBuilder(builder, result.
location);
2132 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.
operands, attrDict,
2134 inferredReturnTypes))) {
2135 result.
addTypes(inferredReturnTypes);
2140 if (getNumResults() != 0 && getElseRegion().empty())
2141 return emitOpError(
"must have an else block if defining values");
2179 bool printBlockTerminators =
false;
2181 p <<
" " << getCondition();
2182 if (!getResults().empty()) {
2183 p <<
" -> (" << getResultTypes() <<
")";
2185 printBlockTerminators =
true;
2190 printBlockTerminators);
2193 auto &elseRegion = getElseRegion();
2194 if (!elseRegion.
empty()) {
2198 printBlockTerminators);
2215 Region *elseRegion = &this->getElseRegion();
2216 if (elseRegion->
empty())
2224 FoldAdaptor adaptor(operands, *
this);
2225 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2226 if (!boolAttr || boolAttr.getValue())
2227 regions.emplace_back(&getThenRegion());
2230 if (!boolAttr || !boolAttr.getValue()) {
2231 if (!getElseRegion().empty())
2232 regions.emplace_back(&getElseRegion());
2234 regions.emplace_back(getResults());
2238 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2241 if (getElseRegion().empty())
2244 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2251 getConditionMutable().assign(xorStmt.getLhs());
2255 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2256 getElseRegion().getBlocks());
2257 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2258 getThenRegion().getBlocks(), thenBlock);
2262 void IfOp::getRegionInvocationBounds(
2265 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2268 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2269 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2272 invocationBounds.assign(2, {0, 1});
2288 llvm::transform(usedResults, std::back_inserter(usedOperands),
2293 [&]() { yieldOp->setOperands(usedOperands); });
2300 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2301 [](
OpResult result) { return !result.use_empty(); });
2304 if (usedResults.size() == op.getNumResults())
2309 llvm::transform(usedResults, std::back_inserter(newTypes),
2314 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2320 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2321 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2326 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2343 else if (!op.getElseRegion().empty())
2359 if (op->getNumResults() == 0)
2362 auto cond = op.getCondition();
2363 auto thenYieldArgs = op.thenYield().getOperands();
2364 auto elseYieldArgs = op.elseYield().getOperands();
2367 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2368 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2369 &op.getElseRegion() == falseVal.getParentRegion())
2370 nonHoistable.push_back(trueVal.getType());
2374 if (nonHoistable.size() == op->getNumResults())
2377 IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2379 if (replacement.thenBlock())
2380 rewriter.
eraseBlock(replacement.thenBlock());
2381 replacement.getThenRegion().takeBody(op.getThenRegion());
2382 replacement.getElseRegion().takeBody(op.getElseRegion());
2385 assert(thenYieldArgs.size() == results.size());
2386 assert(elseYieldArgs.size() == results.size());
2391 for (
const auto &it :
2393 Value trueVal = std::get<0>(it.value());
2394 Value falseVal = std::get<1>(it.value());
2397 results[it.index()] = replacement.getResult(trueYields.size());
2398 trueYields.push_back(trueVal);
2399 falseYields.push_back(falseVal);
2400 }
else if (trueVal == falseVal)
2401 results[it.index()] = trueVal;
2403 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2404 cond, trueVal, falseVal);
2454 struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2460 if (op.getNumResults() == 0)
2464 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2466 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2469 op.getOperation()->getIterator());
2472 for (
auto [trueResult, falseResult, opResult] :
2473 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2475 if (trueResult == falseResult) {
2476 if (!opResult.use_empty()) {
2477 opResult.replaceAllUsesWith(trueResult);
2488 bool trueVal = trueYield.
getValue();
2489 bool falseVal = falseYield.
getValue();
2490 if (!trueVal && falseVal) {
2491 if (!opResult.use_empty()) {
2492 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2493 Value notCond = arith::XOrIOp::create(
2494 rewriter, op.getLoc(), op.getCondition(),
2500 opResult.replaceAllUsesWith(notCond);
2504 if (trueVal && !falseVal) {
2505 if (!opResult.use_empty()) {
2506 opResult.replaceAllUsesWith(op.getCondition());
2541 Block *parent = nextIf->getBlock();
2542 if (nextIf == &parent->
front())
2545 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2553 Block *nextThen =
nullptr;
2554 Block *nextElse =
nullptr;
2555 if (nextIf.getCondition() == prevIf.getCondition()) {
2556 nextThen = nextIf.thenBlock();
2557 if (!nextIf.getElseRegion().empty())
2558 nextElse = nextIf.elseBlock();
2560 if (arith::XOrIOp notv =
2561 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2562 if (notv.getLhs() == prevIf.getCondition() &&
2564 nextElse = nextIf.thenBlock();
2565 if (!nextIf.getElseRegion().empty())
2566 nextThen = nextIf.elseBlock();
2569 if (arith::XOrIOp notv =
2570 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2571 if (notv.getLhs() == nextIf.getCondition() &&
2573 nextElse = nextIf.thenBlock();
2574 if (!nextIf.getElseRegion().empty())
2575 nextThen = nextIf.elseBlock();
2579 if (!nextThen && !nextElse)
2583 if (!prevIf.getElseRegion().empty())
2584 prevElseYielded = prevIf.elseYield().getOperands();
2587 for (
auto it : llvm::zip(prevIf.getResults(),
2588 prevIf.thenYield().getOperands(), prevElseYielded))
2590 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2594 use.
set(std::get<1>(it));
2599 use.
set(std::get<2>(it));
2605 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2607 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2608 prevIf.getCondition(),
false);
2609 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2612 combinedIf.getThenRegion(),
2613 combinedIf.getThenRegion().begin());
2616 YieldOp thenYield = combinedIf.thenYield();
2617 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2618 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2622 llvm::append_range(mergedYields, thenYield2.getOperands());
2623 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2629 combinedIf.getElseRegion(),
2630 combinedIf.getElseRegion().begin());
2633 if (combinedIf.getElseRegion().empty()) {
2635 combinedIf.getElseRegion(),
2636 combinedIf.getElseRegion().
begin());
2638 YieldOp elseYield = combinedIf.elseYield();
2639 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2640 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2645 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2647 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2656 if (pair.index() < prevIf.getNumResults())
2657 prevValues.push_back(pair.value());
2659 nextValues.push_back(pair.value());
2674 if (ifOp.getNumResults())
2676 Block *elseBlock = ifOp.elseBlock();
2677 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2681 newIfOp.getThenRegion().begin());
2708 auto nestedOps = op.thenBlock()->without_terminator();
2710 if (!llvm::hasSingleElement(nestedOps))
2714 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2717 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2721 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2727 llvm::append_range(elseYield, op.elseYield().getOperands());
2741 if (tup.value().getDefiningOp() == nestedIf) {
2742 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2743 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2744 elseYield[tup.index()]) {
2749 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2762 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2765 elseYieldsToUpgradeToSelect.push_back(tup.index());
2769 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2770 nestedIf.getCondition());
2771 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2775 llvm::append_range(results, newIf.getResults());
2778 for (
auto idx : elseYieldsToUpgradeToSelect)
2780 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2781 thenYield[idx], elseYield[idx]);
2783 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2786 if (!elseYield.empty()) {
2789 YieldOp::create(rewriter, loc, elseYield);
2800 results.
add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
2801 RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
2802 ReplaceIfYieldWithConditionOrValue>(context);
// Returns a pointer to the last block of this op's 'then' region
// (the address of getThenRegion().back()).
2805 Block *IfOp::thenBlock() {
return &getThenRegion().
back(); }
// Returns the last operation of the block returned by thenBlock(), cast to
// YieldOp. Uses cast<>, not dyn_cast<>, so the last operation is expected to
// be a YieldOp.
2806 YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2807 Block *IfOp::elseBlock() {
2808 Region &r = getElseRegion();
// Returns the last operation of the block returned by elseBlock(), cast to
// YieldOp.
// NOTE(review): elseBlock() is dereferenced here without a null check, while
// other code in this file tests its result for null before use; presumably
// callers must ensure the 'else' region is non-empty first -- confirm.
2813 YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2819 void ParallelOp::build(
2829 ParallelOp::getOperandSegmentSizeAttr(),
2831 static_cast<int32_t>(upperBounds.size()),
2832 static_cast<int32_t>(steps.size()),
2833 static_cast<int32_t>(initVals.size())}));
2837 unsigned numIVs = steps.size();
2843 if (bodyBuilderFn) {
2845 bodyBuilderFn(builder, result.
location,
2850 if (initVals.empty())
2851 ParallelOp::ensureTerminator(*bodyRegion, builder, result.
location);
2854 void ParallelOp::build(
2861 auto wrappedBuilderFn = [&bodyBuilderFn](
OpBuilder &nestedBuilder,
2864 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2868 wrapper = wrappedBuilderFn;
2870 build(builder, result, lowerBounds, upperBounds, steps,
ValueRange(),
2879 if (stepValues.empty())
2881 "needs at least one tuple element for lowerBound, upperBound and step");
2884 for (
Value stepValue : stepValues)
2887 return emitOpError(
"constant step operand must be positive");
2891 Block *body = getBody();
2893 return emitOpError() <<
"expects the same number of induction variables: "
2895 <<
" as bound and step values: " << stepValues.size();
2897 if (!arg.getType().isIndex())
2899 "expects arguments for the induction variable to be of index type");
2902 auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2903 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2908 auto resultsSize = getResults().size();
2909 auto reductionsSize = reduceOp.getReductions().size();
2910 auto initValsSize = getInitVals().size();
2911 if (resultsSize != reductionsSize)
2912 return emitOpError() <<
"expects number of results: " << resultsSize
2913 <<
" to be the same as number of reductions: "
2915 if (resultsSize != initValsSize)
2916 return emitOpError() <<
"expects number of results: " << resultsSize
2917 <<
" to be the same as number of initial values: "
2921 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2922 auto resultType = getOperation()->getResult(i).getType();
2923 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2924 if (resultType != reductionOperandType)
2925 return reduceOp.emitOpError()
2926 <<
"expects type of " << i
2927 <<
"-th reduction operand: " << reductionOperandType
2928 <<
" to be the same as the " << i
2929 <<
"-th result type: " << resultType;
2945 OpAsmParser::Delimiter::Paren) ||
2952 OpAsmParser::Delimiter::Paren) ||
2960 OpAsmParser::Delimiter::Paren) ||
2977 for (
auto &iv : ivs)
2984 ParallelOp::getOperandSegmentSizeAttr(),
2986 static_cast<int32_t>(upper.size()),
2987 static_cast<int32_t>(steps.size()),
2988 static_cast<int32_t>(initVals.size())}));
2997 ParallelOp::ensureTerminator(*body, builder, result.
location);
3002 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3003 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3004 if (!getInitVals().empty())
3005 p <<
" init (" << getInitVals() <<
")";
3010 (*this)->getAttrs(),
3011 ParallelOp::getOperandSegmentSizeAttr());
3016 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3020 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3024 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3028 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3033 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3035 return ParallelOp();
3036 assert(ivArg.getOwner() &&
"unlinked block argument");
3037 auto *containingOp = ivArg.getOwner()->getParentOp();
3038 return dyn_cast<ParallelOp>(containingOp);
3043 struct ParallelOpSingleOrZeroIterationDimsFolder
3054 for (
auto [lb, ub, step, iv] :
3055 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3056 op.getInductionVars())) {
3058 if (numIterations.has_value()) {
3060 if (*numIterations == 0) {
3061 rewriter.
replaceOp(op, op.getInitVals());
3066 if (*numIterations == 1) {
3071 newLowerBounds.push_back(lb);
3072 newUpperBounds.push_back(ub);
3073 newSteps.push_back(step);
3076 if (newLowerBounds.size() == op.getLowerBound().size())
3079 if (newLowerBounds.empty()) {
3083 results.reserve(op.getInitVals().size());
3084 for (
auto &bodyOp : op.getBody()->without_terminator())
3085 rewriter.
clone(bodyOp, mapping);
3086 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3087 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3088 Block &reduceBlock = reduceOp.getReductions()[i].
front();
3089 auto initValIndex = results.size();
3090 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3094 rewriter.
clone(reduceBodyOp, mapping);
3097 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3098 results.push_back(result);
3106 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3107 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3113 newOp.getRegion().begin(), mapping);
3114 rewriter.
replaceOp(op, newOp.getResults());
3124 Block &outerBody = *op.getBody();
3128 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3133 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3134 llvm::is_contained(innerOp.getUpperBound(), val) ||
3135 llvm::is_contained(innerOp.getStep(), val))
3139 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3144 Block &innerBody = *innerOp.getBody();
3145 assert(iterVals.size() ==
3153 builder.
clone(op, mapping);
3156 auto concatValues = [](
const auto &first,
const auto &second) {
3158 ret.reserve(first.size() + second.size());
3159 ret.assign(first.begin(), first.end());
3160 ret.append(second.begin(), second.end());
3164 auto newLowerBounds =
3165 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3166 auto newUpperBounds =
3167 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3168 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3182 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3191 void ParallelOp::getSuccessorRegions(
3209 for (
Value v : operands) {
3218 LogicalResult ReduceOp::verifyRegions() {
3221 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3222 auto type = getOperands()[i].getType();
3225 return emitOpError() << i <<
"-th reduction has an empty body";
3228 return arg.getType() != type;
3230 return emitOpError() <<
"expected two block arguments with type " << type
3231 <<
" in the " << i <<
"-th reduction region";
3235 return emitOpError(
"reduction bodies must be terminated with an "
3236 "'scf.reduce.return' op");
3255 Block *reductionBody = getOperation()->getBlock();
3257 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3259 if (expectedResultType != getResult().
getType())
3260 return emitOpError() <<
"must have type " << expectedResultType
3261 <<
" (the type of the reduction inputs)";
3271 ValueRange inits, BodyBuilderFn beforeBuilder,
3272 BodyBuilderFn afterBuilder) {
3280 beforeArgLocs.reserve(inits.size());
3281 for (
Value operand : inits) {
3282 beforeArgLocs.push_back(operand.getLoc());
3287 inits.getTypes(), beforeArgLocs);
3296 resultTypes, afterArgLocs);
3302 ConditionOp WhileOp::getConditionOp() {
3303 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3306 YieldOp WhileOp::getYieldOp() {
3307 return cast<YieldOp>(getAfterBody()->getTerminator());
3310 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3311 return getYieldOp().getResultsMutable();
3315 return getBeforeBody()->getArguments();
3319 return getAfterBody()->getArguments();
3323 return getBeforeArguments();
3327 assert(point == getBefore() &&
3328 "WhileOp is expected to branch only to the first region");
3336 regions.emplace_back(&getBefore(), getBefore().getArguments());
3340 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3341 "there are only two regions in a WhileOp");
3343 if (point == getAfter()) {
3344 regions.emplace_back(&getBefore(), getBefore().getArguments());
3348 regions.emplace_back(getResults());
3349 regions.emplace_back(&getAfter(), getAfter().getArguments());
3353 return {&getBefore(), &getAfter()};
3374 FunctionType functionType;
3379 result.
addTypes(functionType.getResults());
3381 if (functionType.getNumInputs() != operands.size()) {
3383 <<
"expected as many input types as operands "
3384 <<
"(expected " << operands.size() <<
" got "
3385 << functionType.getNumInputs() <<
")";
3395 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3396 regionArgs[i].type = functionType.getInput(i);
3398 return failure(parser.
parseRegion(*before, regionArgs) ||
3418 template <
typename OpTy>
3421 if (left.size() != right.size())
3422 return op.emitOpError(
"expects the same number of ") << message;
3424 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3425 if (left[i] != right[i]) {
3428 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3429 <<
" and " << right[i];
3438 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3440 "expects the 'before' region to terminate with 'scf.condition'");
3441 if (!beforeTerminator)
3444 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3446 "expects the 'after' region to terminate with 'scf.yield'");
3447 return success(afterTerminator !=
nullptr);
3475 auto term = op.getConditionOp();
3479 Value constantTrue =
nullptr;
3481 bool replaced =
false;
3482 for (
auto yieldedAndBlockArgs :
3483 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3484 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3485 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3487 constantTrue = arith::ConstantOp::create(
3488 rewriter, op.getLoc(), term.getCondition().getType(),
3497 return success(replaced);
3549 struct RemoveLoopInvariantArgsFromBeforeBlock
3555 Block &afterBlock = *op.getAfterBody();
3557 ConditionOp condOp = op.getConditionOp();
3562 bool canSimplify =
false;
3563 for (
const auto &it :
3565 auto index =
static_cast<unsigned>(it.index());
3566 auto [initVal, yieldOpArg] = it.value();
3569 if (yieldOpArg == initVal) {
3578 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3579 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3580 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3581 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3594 for (
const auto &it :
3596 auto index =
static_cast<unsigned>(it.index());
3597 auto [initVal, yieldOpArg] = it.value();
3601 if (yieldOpArg == initVal) {
3602 beforeBlockInitValMap.insert({index, initVal});
3610 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3611 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3612 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3613 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3614 beforeBlockInitValMap.insert({index, initVal});
3619 newInitArgs.emplace_back(initVal);
3620 newYieldOpArgs.emplace_back(yieldOpArg);
3621 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3630 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3634 &newWhile.getBefore(), {},
3635 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3637 Block &beforeBlock = *op.getBeforeBody();
3644 for (
unsigned i = 0,
j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3647 if (beforeBlockInitValMap.count(i) != 0)
3648 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3650 newBeforeBlockArgs[i] = newBeforeBlock.
getArgument(
j++);
3653 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3655 newWhile.getAfter().begin());
3657 rewriter.
replaceOp(op, newWhile.getResults());
3702 struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3707 Block &beforeBlock = *op.getBeforeBody();
3708 ConditionOp condOp = op.getConditionOp();
3711 bool canSimplify =
false;
3712 for (
Value condOpArg : condOpArgs) {
3732 auto index =
static_cast<unsigned>(it.index());
3733 Value condOpArg = it.value();
3738 condOpInitValMap.insert({index, condOpArg});
3740 newCondOpArgs.emplace_back(condOpArg);
3741 newAfterBlockType.emplace_back(condOpArg.
getType());
3742 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3753 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3756 Block &newAfterBlock =
3758 newAfterBlockType, newAfterBlockArgLocs);
3760 Block &afterBlock = *op.getAfterBody();
3767 for (
unsigned i = 0,
j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
3768 Value afterBlockArg, result;
3771 if (condOpInitValMap.count(i) != 0) {
3772 afterBlockArg = condOpInitValMap[i];
3773 result = afterBlockArg;
3776 result = newWhile.getResult(
j);
3779 newAfterBlockArgs[i] = afterBlockArg;
3780 newWhileResults[i] = result;
3783 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3785 newWhile.getBefore().begin());
3787 rewriter.
replaceOp(op, newWhileResults);
3823 auto term = op.getConditionOp();
3824 auto afterArgs = op.getAfterArguments();
3825 auto termArgs = term.getArgs();
3832 bool needUpdate =
false;
3833 for (
const auto &it :
3835 auto i =
static_cast<unsigned>(it.index());
3836 Value result = std::get<0>(it.value());
3837 Value afterArg = std::get<1>(it.value());
3838 Value termArg = std::get<2>(it.value());
3842 newResultsIndices.emplace_back(i);
3843 newTermArgs.emplace_back(termArg);
3844 newResultTypes.emplace_back(result.
getType());
3845 newArgLocs.emplace_back(result.
getLoc());
3860 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
3863 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3870 newResults[it.value()] = newWhile.getResult(it.index());
3871 newAfterBlockArgs[it.value()] = newAfterBlock.
getArgument(it.index());
3875 newWhile.getBefore().begin());
3877 Block &afterBlock = *op.getAfterBody();
3878 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3912 using namespace scf;
3913 auto cond = op.getConditionOp();
3914 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3918 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3919 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3920 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3923 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3924 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3928 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3931 if (cmp2.getPredicate() == cmp.getPredicate())
3932 samePredicate =
true;
3933 else if (cmp2.getPredicate() ==
3935 samePredicate =
false;
3953 LogicalResult matchAndRewrite(WhileOp op,
3956 if (!llvm::any_of(op.getBeforeArguments(),
3957 [](
Value arg) { return arg.use_empty(); }))
3960 YieldOp yield = op.getYieldOp();
3965 llvm::BitVector argsToErase;
3967 size_t argsCount = op.getBeforeArguments().size();
3968 newYields.reserve(argsCount);
3969 newInits.reserve(argsCount);
3970 argsToErase.reserve(argsCount);
3971 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3972 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3973 if (beforeArg.use_empty()) {
3974 argsToErase.push_back(
true);
3976 argsToErase.push_back(
false);
3977 newYields.emplace_back(yieldValue);
3978 newInits.emplace_back(initValue);
3982 Block &beforeBlock = *op.getBeforeBody();
3983 Block &afterBlock = *op.getAfterBody();
3989 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
3991 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3992 Block &newAfterBlock = *newWhileOp.getAfterBody();
3998 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4003 rewriter.
replaceOp(op, newWhileOp.getResults());
4010 using OpRewritePattern::OpRewritePattern;
4012 LogicalResult matchAndRewrite(WhileOp op,
4014 ConditionOp condOp = op.getConditionOp();
4019 if (argsSet.size() == condOpArgs.size())
4022 llvm::SmallDenseMap<Value, unsigned> argsMap;
4024 argsMap.reserve(condOpArgs.size());
4025 newArgs.reserve(condOpArgs.size());
4026 for (
Value arg : condOpArgs) {
4027 if (!argsMap.count(arg)) {
4028 auto pos =
static_cast<unsigned>(argsMap.size());
4029 argsMap.insert({arg, pos});
4030 newArgs.emplace_back(arg);
4038 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4041 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4042 Block &newAfterBlock = *newWhileOp.getAfterBody();
4047 auto it = argsMap.find(arg);
4048 assert(it != argsMap.end());
4049 auto pos = it->second;
4050 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4051 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4059 Block &beforeBlock = *op.getBeforeBody();
4060 Block &afterBlock = *op.getAfterBody();
4062 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4064 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4072 static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4074 if (args1.size() != args2.size())
4075 return std::nullopt;
4079 auto it = llvm::find(args2, arg1);
4080 if (it == args2.end())
4081 return std::nullopt;
4083 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4090 llvm::SmallDenseSet<Value> set;
4091 for (
Value arg : args) {
4092 if (!set.insert(arg).second)
4103 using OpRewritePattern::OpRewritePattern;
4105 LogicalResult matchAndRewrite(WhileOp loop,
4107 auto oldBefore = loop.getBeforeBody();
4108 ConditionOp oldTerm = loop.getConditionOp();
4109 ValueRange beforeArgs = oldBefore->getArguments();
4111 if (beforeArgs == termArgs)
4114 if (hasDuplicates(termArgs))
4117 auto mapping = getArgsMapping(beforeArgs, termArgs);
4128 auto oldAfter = loop.getAfterBody();
4132 newResultTypes[
j] = loop.getResult(i).getType();
4134 auto newLoop = WhileOp::create(
4135 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4137 auto newBefore = newLoop.getBeforeBody();
4138 auto newAfter = newLoop.getAfterBody();
4143 newResults[i] = newLoop.getResult(
j);
4144 newAfterArgs[i] = newAfter->getArgument(
j);
4148 newBefore->getArguments());
4160 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4161 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4162 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4163 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4177 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4180 caseValues.push_back(value);
4189 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4191 p <<
"case " << value <<
' ';
4197 if (getCases().size() != getCaseRegions().size()) {
4198 return emitOpError(
"has ")
4199 << getCaseRegions().size() <<
" case regions but "
4200 << getCases().size() <<
" case values";
4204 for (int64_t value : getCases())
4205 if (!valueSet.insert(value).second)
4206 return emitOpError(
"has duplicate case value: ") << value;
4208 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4210 return emitOpError(
"expected region to end with scf.yield, but got ")
4213 if (yield.getNumOperands() != getNumResults()) {
4214 return (emitOpError(
"expected each region to return ")
4215 << getNumResults() <<
" values, but " << name <<
" returns "
4216 << yield.getNumOperands())
4217 .attachNote(yield.getLoc())
4218 <<
"see yield operation here";
4220 for (
auto [idx, result, operand] :
4221 llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4222 yield.getOperandTypes())) {
4223 if (result == operand)
4225 return (emitOpError(
"expected result #")
4226 << idx <<
" of each region to be " << result)
4227 .attachNote(yield.getLoc())
4228 << name <<
" returns " << operand <<
" here";
4233 if (failed(
verifyRegion(getDefaultRegion(),
"default region")))
4236 if (failed(
verifyRegion(caseRegion,
"case region #" + Twine(idx))))
// Returns the number of case values of this switch, i.e. the size of
// getCases(), converted to unsigned.
4242 unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4244 Block &scf::IndexSwitchOp::getDefaultBlock() {
4245 return getDefaultRegion().
front();
4248 Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4249 assert(idx < getNumCases() &&
"case index out-of-bounds");
4250 return getCaseRegions()[idx].front();
4253 void IndexSwitchOp::getSuccessorRegions(
4257 successors.emplace_back(getResults());
4261 llvm::append_range(successors, getRegions());
4264 void IndexSwitchOp::getEntrySuccessorRegions(
4267 FoldAdaptor adaptor(operands, *
this);
4270 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4272 llvm::append_range(successors, getRegions());
4278 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4279 if (caseValue == arg.getInt()) {
4280 successors.emplace_back(&caseRegion);
4284 successors.emplace_back(&getDefaultRegion());
4287 void IndexSwitchOp::getRegionInvocationBounds(
4289 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4290 if (!operandValue) {
4296 unsigned liveIndex = getNumRegions() - 1;
4297 const auto *it = llvm::find(getCases(), operandValue.getInt());
4298 if (it != getCases().end())
4299 liveIndex = std::distance(getCases().begin(), it);
4300 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4301 bounds.emplace_back(0, i == liveIndex);
4312 if (!maybeCst.has_value())
4314 int64_t cst = *maybeCst;
4315 int64_t caseIdx, e = op.getNumCases();
4316 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4317 if (cst == op.getCases()[caseIdx])
4321 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4322 : op.getDefaultRegion();
4346 #define GET_OP_CLASSES
4347 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the given index.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the block terminator to replace operation results.
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e. have the same number of entries and pairwise-equal types.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.