30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
/// Inliner hooks for the SCF dialect; an instance is registered through
/// `addInterfaces<SCFInlinerInterface>()` in `SCFDialect::initialize()`.
47struct SCFInlinerInterface :
public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
52 IRMapping &valueMapping)
const final {
57 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
// Terminator hook: `op` is treated as an `scf.yield`, and each entry of
// `valuesToRepl` has its uses replaced by the matching yield operand.
62 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
// Walk (value to replace, yielded operand) pairs in lock-step.
67 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
/// Dialect set-up: registers the SCF inliner interface and declares the
/// interfaces that are promised (not implemented here) for SCF ops.
78void SCFDialect::initialize() {
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
// The inliner interface is the only interface attached directly.
83 addInterfaces<SCFInlinerInterface>();
// Promised on the dialect itself.
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
// Promised per op: buffer deallocation, bufferization and value bounds.
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
95 scf::YieldOp::create(builder, loc);
/// Helper templated on the expected terminator type `TerminatorTy`. It looks
/// up the last operation of the region's first block and checks it against
/// `TerminatorTy`; `errorMessage` is supplied by the caller.
100template <
typename TerminatorTy>
102 StringRef errorMessage) {
103 Operation *terminatorOperation =
nullptr;
// Address of the last operation in the first block of `region`. The listing
// showed a mis-encoded `&reg` sequence here; the address-of operator is
// restored.
105 terminatorOperation = &region.
front().
back();
// dyn_cast_or_null tolerates terminatorOperation still being null.
106 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
// When some operation was found, attach a note pointing at its location.
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
118 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
139 assert(region.
hasOneBlock() &&
"expected single-block region");
159ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
/// Verifier: the region must hold at least one block, and its first block
/// must not declare any arguments.
187LogicalResult ExecuteRegionOp::verify() {
188 if (getRegion().empty())
189 return emitOpError(
"region needs to have at least one block");
// Only the entry (front) block is inspected for arguments.
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError(
"region cannot have any arguments");
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
263 if (op.getNoInline())
265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
268 Block *prevBlock = op->getBlock();
272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
274 for (
Block &blk : op.getRegion()) {
275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278 yieldOp.getResults());
286 for (
auto res : op.getResults())
287 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
304 if (op.getNumResults() == 0)
308 for (
Block &block : op.getRegion()) {
309 if (
auto yield = dyn_cast<scf::YieldOp>(block.
getTerminator()))
310 yieldOps.push_back(yield.getOperation());
313 if (yieldOps.empty())
317 auto yieldOpsOperands = yieldOps[0]->getOperands();
318 for (
auto *yieldOp : yieldOps) {
319 if (yieldOp->getOperands() != yieldOpsOperands)
327 for (
auto [
index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
328 if (isValueFromInsideRegion(yieldedValue, op)) {
329 internalValues.push_back(yieldedValue);
330 opResultsToKeep.push_back(op.getResult(
index));
332 externalValues.push_back(yieldedValue);
333 opResultsToReplaceWithExternalValues.push_back(op.getResult(
index));
337 if (externalValues.empty())
343 for (
Value value : internalValues)
344 resultTypes.push_back(value.getType());
346 ExecuteRegionOp::create(rewriter, op.getLoc(),
TypeRange(resultTypes));
347 newOp->setAttrs(op->getAttrs());
351 newOp.getRegion().end());
355 for (
auto *yieldOp : yieldOps) {
372 bool isValueFromInsideRegion(
Value value,
373 ExecuteRegionOp executeRegionOp)
const {
376 return &executeRegionOp.getRegion() == defOp->getParentRegion();
380 return &executeRegionOp.getRegion() == blockArg.getParentRegion();
392void ExecuteRegionOp::getSuccessorRegions(
412 "condition op can only exit the loop or branch to the after"
415 return getArgsMutable();
/// Collects the successors reachable from this `scf.condition`, using the
/// constant-folded condition operand when it is available.
418void ConditionOp::getSuccessorRegions(
420 FoldAdaptor adaptor(operands, *
this);
422 WhileOp whileOp = getParentOp();
// boolAttr is null when the condition is not a known BoolAttr; in that case
// both successors below are recorded.
426 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
// Unknown or true: the "after" region of the enclosing while op.
427 if (!boolAttr || boolAttr.getValue())
428 regions.emplace_back(&whileOp.getAfter(),
429 whileOp.getAfter().getArguments());
// Unknown or false: the enclosing while op itself, with its results.
430 if (!boolAttr || !boolAttr.getValue())
431 regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
440 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
444 result.addAttribute(getUnsignedCmpAttrName(
result.name),
447 result.addOperands(initArgs);
448 for (
Value v : initArgs)
449 result.addTypes(v.getType());
454 for (
Value v : initArgs)
460 if (initArgs.empty() && !bodyBuilder) {
461 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
462 }
else if (bodyBuilder) {
/// Verifier: the number of init args must equal the number of op results.
470LogicalResult ForOp::verify() {
472 if (getInitArgs().size() != getNumResults())
474 "mismatch in number of loop-carried values and defined values");
/// Region verifier: checks the induction variable type, the number of region
/// iter args, and the types of init args / region iter args against results.
479LogicalResult ForOp::verifyRegions() {
484 "expected induction variable to be same type as bounds and step");
// One region iter arg is expected per op result.
486 if (getNumRegionIterArgs() != getNumResults())
488 "mismatch in number of basic block args and defined values");
490 auto initArgs = getInitArgs();
491 auto iterArgs = getRegionIterArgs();
492 auto opResults = getResults();
// Visit (init arg, region iter arg, result) triples in lock-step; the two
// diagnostics below report a type mismatch at position `i`.
494 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
496 return emitOpError() <<
"types mismatch between " << i
497 <<
"th iter operand and defined value";
499 return emitOpError() <<
"types mismatch between " << i
500 <<
"th iter region arg and defined value";
507std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
511std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
515std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
519std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
/// Returns the results of this op wrapped in an optional.
523std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
/// Promotion driven by the static trip count obtained from
/// getStaticTripCount().
527LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
528 std::optional<APInt> tripCount = getStaticTripCount();
529 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
// Guard: the trip count is unknown, or the loop runs more than once.
532 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
// A trip count of zero is handled in its own branch.
535 if (*tripCount == 0) {
542 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
// The init args are appended to the block-argument replacement list that is
// used together with this op's iterator position below.
549 llvm::append_range(bbArgReplacements, getInitArgs());
553 getOperation()->getIterator(), bbArgReplacements);
569 StringRef prefix =
"") {
570 assert(blocksArgs.size() == initializers.size() &&
571 "expected same length of arguments and initializers");
572 if (initializers.empty())
576 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
577 p << std::get<0>(it) <<
" = " << std::get<1>(it);
583 if (getUnsignedCmp())
586 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
590 if (!getInitArgs().empty())
591 p <<
" -> (" << getInitArgs().getTypes() <<
')';
594 p <<
" : " << t <<
' ';
597 !getInitArgs().empty());
599 getUnsignedCmpAttrName().strref());
610 result.addAttribute(getUnsignedCmpAttrName(
result.name),
624 regionArgs.push_back(inductionVariable);
634 if (regionArgs.size() !=
result.types.size() + 1)
637 "mismatch in number of loop-carried values and defined values");
646 regionArgs.front().type = type;
647 for (
auto [iterArg, type] :
648 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
655 ForOp::ensureTerminator(*body, builder,
result.location);
664 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
665 operands,
result.types)) {
666 Type type = std::get<2>(argOperandType);
667 std::get<0>(argOperandType).type = type;
684 return getBody()->getArguments().drop_front(getNumInductionVars());
688 return getInitArgsMutable();
/// Creates a new `scf.for` that carries the existing init args plus
/// `newInitOperands`, extends the yield with the values produced by
/// `newYieldValuesFn`, and returns the new loop as a LoopLikeOpInterface.
691FailureOr<LoopLikeOpInterface>
692ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
694 bool replaceInitOperandUsesInLoop,
// Init list of the new loop: current init args followed by the new operands.
699 auto inits = llvm::to_vector(getInitArgs());
700 inits.append(newInitOperands.begin(), newInitOperands.end());
701 scf::ForOp newLoop = scf::ForOp::create(
// Terminator of this (the original) loop body.
707 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
// The trailing block arguments of the new body correspond to the new inits.
709 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
714 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
715 assert(newInitOperands.size() == newYieldedValues.size() &&
716 "expected as many new yield values as new iter operands");
// The additional yielded values are appended to the existing yield operands.
718 yieldOp.getResultsMutable().append(newYieldedValues);
// The leading block arguments of the new body line up with this body's
// arguments.
724 newLoop.getBody()->getArguments().take_front(
725 getBody()->getNumArguments()));
// Optional step: pairs each new init operand with its new iter arg.
727 if (replaceInitOperandUsesInLoop) {
730 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
// Only the first getNumResults() results of the new loop correspond to this
// op's results.
741 newLoop->getResults().take_front(getNumResults()));
742 return cast<LoopLikeOpInterface>(newLoop.getOperation());
// Looks up the ForOp that owns `val` when `val` is a block argument: the
// parent op of the argument's owning block is returned if it is a ForOp.
746 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
749 assert(ivArg.getOwner() &&
"unlinked block argument");
750 auto *containingOp = ivArg.getOwner()->getParentOp();
// dyn_cast_or_null yields a null ForOp when there is no parent op or when the
// parent is some other operation.
751 return dyn_cast_or_null<ForOp>(containingOp);
755 return getInitArgs();
/// Promotion for `scf.forall`: every (lower bound, upper bound, step) triple
/// is inspected, and a dimension whose trip count is unknown or different
/// from one triggers the guard below.
771LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
772 for (
auto [lb, ub, step] :
773 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
// Guard: this dimension does not have a known trip count of exactly one.
776 if (!tripCount.has_value() || *tripCount != 1)
785 return getBody()->getArguments().drop_front(getRank());
/// The mutable init operands of this op are its output operands.
788MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
789 return getOutputsMutable();
795 scf::InParallelOp terminator = forallOp.getTerminator();
800 bbArgReplacements.append(forallOp.getOutputs().begin(),
801 forallOp.getOutputs().end());
805 forallOp->getIterator(), bbArgReplacements);
810 results.reserve(forallOp.getResults().size());
811 for (
auto &yieldingOp : terminator.getYieldingOps()) {
812 auto parallelInsertSliceOp =
813 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
814 if (!parallelInsertSliceOp)
817 Value dst = parallelInsertSliceOp.getDest();
818 Value src = parallelInsertSliceOp.getSource();
819 if (llvm::isa<TensorType>(src.
getType())) {
820 results.push_back(tensor::InsertSliceOp::create(
821 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
822 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
823 parallelInsertSliceOp.getStrides(),
824 parallelInsertSliceOp.getStaticOffsets(),
825 parallelInsertSliceOp.getStaticSizes(),
826 parallelInsertSliceOp.getStaticStrides()));
828 llvm_unreachable(
"unsupported terminator");
843 assert(lbs.size() == ubs.size() &&
844 "expected the same number of lower and upper bounds");
845 assert(lbs.size() == steps.size() &&
846 "expected the same number of lower bounds and steps");
851 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
853 assert(results.size() == iterArgs.size() &&
854 "loop nest body must return as many values as loop has iteration "
856 return LoopNest{{}, std::move(results)};
864 loops.reserve(lbs.size());
865 ivs.reserve(lbs.size());
868 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
869 auto loop = scf::ForOp::create(
870 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
876 currentIterArgs = args;
877 currentLoc = nestedLoc;
883 loops.push_back(loop);
887 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
889 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
896 ? bodyBuilder(builder, currentLoc, ivs,
897 loops.back().getRegionIterArgs())
899 assert(results.size() == iterArgs.size() &&
900 "loop nest body must return as many values as loop has iteration "
903 scf::YieldOp::create(builder, loc, results);
907 llvm::append_range(nestResults, loops.front().getResults());
908 return LoopNest{std::move(loops), std::move(nestResults)};
921 bodyBuilder(nestedBuilder, nestedLoc, ivs);
930 assert(operand.
getOwner() == forOp);
935 "expected an iter OpOperand");
937 "Expected a different type");
939 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
944 newIterOperands.push_back(opOperand.get());
948 scf::ForOp newForOp = scf::ForOp::create(
949 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
950 forOp.getStep(), newIterOperands,
nullptr,
951 forOp.getUnsignedCmp());
952 newForOp->setAttrs(forOp->getAttrs());
953 Block &newBlock = newForOp.getRegion().
front();
961 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
963 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
964 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
968 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
971 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
974 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
975 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
976 clonedYieldOp.getOperand(yieldIdx));
978 newYieldOperands[yieldIdx] = castOut;
979 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
980 rewriter.
eraseOp(clonedYieldOp);
985 newResults[yieldIdx] =
986 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
// Rewrites an scf.for so that iter args which are merely forwarded, unused,
// or duplicates of an earlier (init, yielded) pair are dropped. Dropped
// entries are replaced by their init value; kept entries are re-created on a
// new loop.
1006 LogicalResult matchAndRewrite(scf::ForOp forOp,
1008 bool canonicalize =
false;
1015 int64_t numResults = forOp.getNumResults();
1017 keepMask.reserve(numResults);
// Slot 0 is a placeholder for the induction variable; it is filled in from
// the new block once the new loop exists.
1020 newBlockTransferArgs.reserve(1 + numResults);
1021 newBlockTransferArgs.push_back(
Value());
1022 newIterArgs.reserve(forOp.getInitArgs().size());
1023 newYieldValues.reserve(numResults);
1024 newResultValues.reserve(numResults);
1026 for (
auto [init, arg,
result, yielded] :
1027 llvm::zip(forOp.getInitArgs(),
1028 forOp.getRegionIterArgs(),
1030 forOp.getYieldedValues()
// An entry is "forwarded" when the iter arg or the init value is yielded
// unchanged, or when neither the iter arg nor the result has any use.
1037 bool forwarded = (arg == yielded) || (init == yielded) ||
1038 (arg.use_empty() &&
result.use_empty());
// Dropped entry: the init value stands in for both the block argument and
// the op result.
1040 canonicalize =
true;
1041 keepMask.push_back(
false);
1042 newBlockTransferArgs.push_back(init);
1043 newResultValues.push_back(init);
// Duplicate detection: an identical (init, yielded) pair seen earlier means
// this entry is dropped as well.
1049 if (
auto it = initYieldToArg.find({init, yielded});
1050 it != initYieldToArg.end()) {
1051 canonicalize =
true;
1052 keepMask.push_back(
false);
1053 auto [sameArg, sameResult] = it->second;
1057 newBlockTransferArgs.push_back(init);
1058 newResultValues.push_back(init);
// Kept entry: remember the pair, and push null placeholders that are
// resolved against the new loop further down.
1063 initYieldToArg.insert({{init, yielded}, {arg,
result}});
1064 keepMask.push_back(
true);
1065 newIterArgs.push_back(init);
1066 newYieldValues.push_back(yielded);
1067 newBlockTransferArgs.push_back(Value());
1068 newResultValues.push_back(Value());
// New loop with the same bounds, step and unsignedCmp flag, carrying only the
// kept init values; the original attributes are copied over.
1074 scf::ForOp newForOp =
1075 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
1076 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
1077 nullptr, forOp.getUnsignedCmp());
1078 newForOp->setAttrs(forOp->getAttrs());
1079 Block &newBlock = newForOp.getRegion().front();
// Resolve the placeholders: slot 0 becomes the new block's first argument,
// and each null entry takes the next region iter arg / result of the new
// loop (collapsedIdx counts kept entries only).
1082 newBlockTransferArgs[0] = newBlock.
getArgument(0);
1083 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
1085 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
1086 Value &newResultVal = newResultValues[idx];
1087 assert((blockTransferArg && newResultVal) ||
1088 (!blockTransferArg && !newResultVal));
1089 if (!blockTransferArg) {
1090 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
1091 newResultVal = newForOp.getResult(collapsedIdx++);
1095 Block &oldBlock = forOp.getRegion().front();
1097 "unexpected argument size mismatch");
// Special case: no iter arg survives.
1102 if (newIterArgs.empty()) {
1103 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1106 rewriter.
replaceOp(forOp, newResultValues);
// Helper that creates a replacement yield from selected operands of the
// merged terminator. NOTE(review): the selection condition is not visible
// here; presumably it keeps operands whose keepMask entry is set - confirm.
1111 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1112 OpBuilder::InsertionGuard g(rewriter);
1114 SmallVector<Value, 4> filteredOperands;
1115 filteredOperands.reserve(newResultValues.size());
1116 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1118 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1119 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
// General case: move the old body into the new block, swap the merged yield
// for the filtered one, then replace the old op.
1123 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1124 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1125 cloneFilteredTerminator(mergedYieldOp);
1126 rewriter.
eraseOp(mergedYieldOp);
1127 rewriter.
replaceOp(forOp, newResultValues);
/// Rewrite pattern on ForOp that simplifies loops based on their static trip
/// count (zero or one iteration) and loops whose yielded values are all
/// defined outside of the loop.
1135struct SimplifyTrivialLoops :
public OpRewritePattern<ForOp> {
1136 using OpRewritePattern<ForOp>::OpRewritePattern;
1138 LogicalResult matchAndRewrite(ForOp op,
1139 PatternRewriter &rewriter)
const override {
1140 std::optional<APInt> tripCount = op.getStaticTripCount();
1141 if (!tripCount.has_value())
1143 "can't compute constant trip count");
// Zero iterations: the results are simply the init args.
1145 if (tripCount->isZero()) {
1146 LDBG() <<
"SimplifyTrivialLoops tripCount is 0 for loop "
1147 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1148 rewriter.
replaceOp(op, op.getInitArgs());
// Exactly one iteration: the body's block arguments are bound to the lower
// bound (for the induction variable) followed by the init args.
1152 if (tripCount->getSExtValue() == 1) {
1153 LDBG() <<
"SimplifyTrivialLoops tripCount is 1 for loop "
1154 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1155 SmallVector<Value, 4> blockArgs;
1156 blockArgs.reserve(op.getInitArgs().size() + 1);
1157 blockArgs.push_back(op.getLowerBound());
1158 llvm::append_range(blockArgs, op.getInitArgs());
// Remaining case: checks that the block holds a single operation and that
// every yielded value is defined outside of the loop.
1165 if (!llvm::hasSingleElement(block))
1169 if (llvm::any_of(op.getYieldedValues(),
1170 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1172 LDBG() <<
"SimplifyTrivialLoops empty body loop allows replacement with "
1173 "yield operands for loop "
1174 << OpWithFlags(op, OpPrintingFlags().skipRegions());
// The yielded values directly replace the op's results.
1175 rewriter.
replaceOp(op, op.getYieldedValues());
/// Rewrite pattern on ForOp that looks at each (init operand, result) pair
/// for an incoming tensor cast and rewrites the loop on the cast's source,
/// using tensor::CastOp to build the casts that are needed.
1206struct ForOpTensorCastFolder :
public OpRewritePattern<ForOp> {
1207 using OpRewritePattern<ForOp>::OpRewritePattern;
1209 LogicalResult matchAndRewrite(ForOp op,
1210 PatternRewriter &rewriter)
const override {
1211 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1212 OpOperand &iterOpOperand = std::get<0>(it);
// Guard: no incoming cast, or a cast that does not change the type.
1214 if (!incomingCast ||
1215 incomingCast.getSource().getType() == incomingCast.getType())
1220 incomingCast.getDest().getType(),
1221 incomingCast.getSource().getType()))
// Guard: the matching loop result does not have exactly one use.
1223 if (!std::get<1>(it).hasOneUse())
// The callback materializes a tensor.cast to the requested type.
1229 rewriter, op, iterOpOperand, incomingCast.getSource(),
1230 [](OpBuilder &
b, Location loc, Type type, Value source) {
1231 return tensor::CastOp::create(b, loc, type, source);
/// Registers the ForOp canonicalization patterns: ForOpIterArgsFolder,
/// SimplifyTrivialLoops and ForOpTensorCastFolder.
1241void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1242 MLIRContext *context) {
1243 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1247std::optional<APInt> ForOp::getConstantStep() {
1250 return step.getValue();
/// Returns the mutable operands of the `scf.yield` that terminates the body.
1254std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1255 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1261 if (
auto constantStep = getConstantStep())
1262 if (*constantStep == 1)
1270std::optional<APInt> ForOp::getStaticTripCount() {
/// Verifier for `scf.forall`: checks result/output counts, the body block
/// arguments, the optional mapping attribute, and the static/dynamic
/// lower bound, upper bound and step lists.
1279LogicalResult ForallOp::verify() {
1280 unsigned numLoops = getRank();
// One result is expected per output operand.
1282 if (getNumResults() != getOutputs().size())
1284 << getNumResults() <<
" results, but has only "
1285 << getOutputs().size() <<
" outputs";
1288 auto *body = getBody();
1290 return emitOpError(
"region expects ") << numLoops <<
" arguments";
// The first numLoops block arguments are checked to be of index type.
1291 for (int64_t i = 0; i < numLoops; ++i)
1294 << i <<
"-th block argument to be an index";
// Each output is checked against its corresponding block argument.
1295 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1298 << i <<
"-th output and corresponding block argument";
// Mapping checks only apply when a non-empty mapping attribute is present.
1299 if (getMapping().has_value() && !getMapping()->empty()) {
1300 if (getDeviceMappingAttrs().size() != numLoops)
1301 return emitOpError() <<
"mapping attribute size must match op rank";
1302 if (
failed(getDeviceMaskingAttr()))
1304 <<
" supports at most one device masking attribute";
// The static/dynamic pairs for lower bounds, upper bounds and steps are
// checked in turn.
1308 Operation *op = getOperation();
1310 getStaticLowerBound(),
1311 getDynamicLowerBound())))
1314 getStaticUpperBound(),
1315 getDynamicUpperBound())))
1318 getStaticStep(), getDynamicStep())))
/// Custom printer for `scf.forall`: prints the induction variables, an
/// optional result type list, the region, and the remaining attributes.
1324void ForallOp::print(OpAsmPrinter &p) {
1325 Operation *op = getOperation();
1326 p <<
" (" << getInductionVars();
// The printed form depends on whether the loop is normalized.
1327 if (isNormalized()) {
// Result types are printed only when there are region out args.
1348 if (!getRegionOutArgs().empty())
1349 p <<
"-> (" << getResultTypes() <<
") ";
1350 p.printRegion(getRegion(),
1352 getNumResults() > 0);
// These attributes are excluded from the generic attribute dictionary.
1353 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1354 getStaticLowerBoundAttrName(),
1355 getStaticUpperBoundAttrName(),
1356 getStaticStepAttrName()});
/// Custom parser for `scf.forall`: collects induction variables, bounds,
/// steps and outputs, builds the region arguments, and records the static
/// bound attributes together with the operand segment sizes.
1359ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1361 auto indexType =
b.getIndexType();
1366 SmallVector<OpAsmParser::Argument, 4> ivs;
1371 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
// Static lower bounds of all zeros and static steps of all ones, one entry
// per induction variable.
1381 unsigned numLoops = ivs.size();
1382 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1383 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1412 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1413 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
// The number of out operands has to match the number of result types.
1416 if (outOperands.size() !=
result.types.size())
1418 "mismatch between out operands and types");
// Region arguments: the induction variables (typed as index) first, then the
// out args typed with the corresponding result type.
1427 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1428 std::unique_ptr<Region> region = std::make_unique<Region>();
1429 for (
auto &iv : ivs) {
1430 iv.type =
b.getIndexType();
1431 regionArgs.push_back(iv);
1433 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1434 auto &out = it.value();
1435 out.type =
result.types[it.index()];
1436 regionArgs.push_back(out);
// The terminator is ensured on the region before it is handed to the result.
1442 ForallOp::ensureTerminator(*region,
b,
result.location);
1443 result.addRegion(std::move(region));
// Static bounds/steps and the operand segment sizes (lbs, ubs, steps, outs).
1449 result.addAttribute(
"staticLowerBound", staticLbs);
1450 result.addAttribute(
"staticUpperBound", staticUbs);
1451 result.addAttribute(
"staticStep", staticSteps);
1452 result.addAttribute(
"operandSegmentSizes",
1454 {static_cast<int32_t>(dynamicLbs.size()),
1455 static_cast<int32_t>(dynamicUbs.size()),
1456 static_cast<int32_t>(dynamicSteps.size()),
1457 static_cast<int32_t>(outOperands.size())}));
/// Builder taking explicit lower bounds, upper bounds and steps. Dynamic
/// values become operands, static values become attributes, and a body block
/// is created for the op.
1462void ForallOp::build(
1463 mlir::OpBuilder &
b, mlir::OperationState &
result,
1464 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1465 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1466 std::optional<ArrayAttr> mapping,
1468 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1469 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
// Operand order: dynamic lower bounds, dynamic upper bounds, dynamic steps,
// then the outputs.
1474 result.addOperands(dynamicLbs);
1475 result.addOperands(dynamicUbs);
1476 result.addOperands(dynamicSteps);
1477 result.addOperands(outputs);
1480 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1481 b.getDenseI64ArrayAttr(staticLbs));
1482 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1483 b.getDenseI64ArrayAttr(staticUbs));
1484 result.addAttribute(getStaticStepAttrName(
result.name),
1485 b.getDenseI64ArrayAttr(staticSteps));
// Segment sizes follow the operand order used above.
1487 "operandSegmentSizes",
1488 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1489 static_cast<int32_t>(dynamicUbs.size()),
1490 static_cast<int32_t>(dynamicSteps.size()),
1491 static_cast<int32_t>(outputs.size())}));
// The mapping attribute is only attached when a mapping was supplied.
1492 if (mapping.has_value()) {
1493 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
// Body construction happens under an insertion guard on `b`.
1497 Region *bodyRegion =
result.addRegion();
1498 OpBuilder::InsertionGuard g(
b);
1499 b.createBlock(bodyRegion);
1504 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1505 SmallVector<Location>(staticLbs.size(),
result.location));
1508 SmallVector<Location>(outputs.size(),
result.location));
1510 b.setInsertionPointToStart(&bodyBlock);
// Without a body builder callback, only the terminator is ensured.
1511 if (!bodyBuilderFn) {
1512 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
/// Convenience builder taking only upper bounds: lower bounds default to the
/// index attribute 0 and steps to the index attribute 1, then the builder
/// with explicit lbs/ubs/steps is invoked.
1519void ForallOp::build(
1520 mlir::OpBuilder &
b, mlir::OperationState &
result,
1521 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1522 std::optional<ArrayAttr> mapping,
1524 unsigned numLoops = ubs.size();
1525 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
1526 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1527 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
/// True when every lower bound is the constant 0 and every step is the
/// constant 1.
1531bool ForallOp::isNormalized() {
// allEqual: every entry must be a known integer equal to `val`.
1532 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1533 return llvm::all_of(results, [&](OpFoldResult ofr) {
1535 return intValue.has_value() && intValue == val;
1538 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
/// Returns the body terminator as an InParallelOp; `cast` requires the
/// terminator to be of that type.
1541InParallelOp ForallOp::getTerminator() {
1542 return cast<InParallelOp>(getBody()->getTerminator());
/// Collects the users of `bbArg` that implement
/// ParallelCombiningOpInterface.
1545SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1546 SmallVector<Operation *> storeOps;
1547 for (Operation *user : bbArg.
getUsers()) {
// Users that do not implement the interface are ignored.
1548 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1549 storeOps.push_back(parallelOp);
/// Walks the entries of the mapping attribute and inspects each one as a
/// DeviceMappingAttrInterface, accumulating into `res`.
1555SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1556 SmallVector<DeviceMappingAttrInterface> res;
1559 for (
auto attr : getMapping()->getValue()) {
// `m` is null for entries that do not implement the interface.
1560 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
/// Walks the entries of the mapping attribute and inspects each one as a
/// DeviceMaskingAttrInterface; the result is reported through a FailureOr.
1567FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1568 DeviceMaskingAttrInterface res;
1571 for (
auto attr : getMapping()->getValue()) {
// `m` is null for entries that do not implement the interface.
1572 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
/// Reports whether the first device mapping attribute is a linear mapping.
1581bool ForallOp::usesLinearMapping() {
1582 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
// Only the front element is consulted.
1585 return ifaces.front().isLinearMapping();
/// The induction variables are the first getRank() arguments of the body.
1588std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1589 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1593std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1595 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
1599std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1601 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1605std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
// Looks up the ForallOp that owns `val` when `val` is a block argument: the
// parent op of the argument's owning block is returned if it is a ForallOp.
1611 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1614 assert(tidxArg.getOwner() &&
"unlinked block argument");
1615 auto *containingOp = tidxArg.getOwner()->getParentOp();
// getParentOp() can be null when the owning block is not attached to an
// operation, and plain dyn_cast does not accept a null pointer. Use the
// null-tolerant form, as the ForOp owner lookup in this file does, so a null
// ForallOp is returned instead.
1616 return dyn_cast_or_null<ForallOp>(containingOp);
// Rewrites a tensor.dim whose source is a result of an scf.forall so that it
// reads from `sharedOut` instead; the source operand is updated in place.
1624 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1626 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
// The forall result is mapped back to its tied operand.
1630 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1633 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
/// Rewrite pattern on ForallOp that re-splits the mixed lower bounds, upper
/// bounds and steps into dynamic operands and static attributes, and updates
/// the operand segment sizes to match.
1638class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1640 using OpRewritePattern<ForallOp>::OpRewritePattern;
1642 LogicalResult matchAndRewrite(ForallOp op,
1643 PatternRewriter &rewriter)
const override {
// Local copies of the mixed (static-or-dynamic) control operands.
1644 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1645 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1646 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1653 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1654 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
// Each dynamic operand list and its static attribute are written together.
1657 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1658 op.setStaticLowerBound(staticLowerBound);
1662 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1663 op.setStaticUpperBound(staticUpperBound);
1666 op.getDynamicStepMutable().assign(dynamicStep);
1667 op.setStaticStep(staticStep);
// Segment sizes: dynamic lbs, dynamic ubs, dynamic steps, then one entry per
// op result for the outputs segment.
1669 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1671 {static_cast<int32_t>(dynamicLowerBound.size()),
1672 static_cast<int32_t>(dynamicUpperBound.size()),
1673 static_cast<int32_t>(dynamicStep.size()),
1674 static_cast<int32_t>(op.getNumResults())}));
/// Rewrite pattern on ForallOp that drops results (with their tied output
/// operand and block argument) which are unused or have no combining op.
1753struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1754 using OpRewritePattern<ForallOp>::OpRewritePattern;
1756 LogicalResult matchAndRewrite(ForallOp forallOp,
1757 PatternRewriter &rewriter)
const final {
1768 SmallVector<Value> resultsToDelete;
1769 SmallVector<Value> outsToDelete;
1770 SmallVector<BlockArgument> blockArgsToDelete;
1771 SmallVector<Value> newOuts;
1772 BitVector resultIndicesToDelete(forallOp.getNumResults(),
false);
1773 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
// Classify each result via its tied operand and tied block argument.
1775 for (OpResult
result : forallOp.getResults()) {
1776 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1777 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
// Deletable: the result has no uses, or nothing combines into its block
// argument.
1778 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1779 resultsToDelete.push_back(
result);
1780 outsToDelete.push_back(opOperand->
get());
1781 blockArgsToDelete.push_back(blockArg);
1782 resultIndicesToDelete[
result.getResultNumber()] =
true;
// Otherwise the output operand is kept for the new op.
1785 newOuts.push_back(opOperand->
get());
1791 if (resultsToDelete.empty())
// Combining ops that target a deleted block argument are erased.
1796 for (
auto blockArg : blockArgsToDelete) {
1797 SmallVector<Operation *> combiningOps =
1798 forallOp.getCombiningOps(blockArg);
1799 for (Operation *combiningOp : combiningOps)
1800 rewriter.
eraseOp(combiningOp);
// zip_equal requires the three deletion lists to have the same length.
1802 for (
auto [blockArg,
result, out] :
1803 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1809 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
// The surviving outputs are assigned to the new op.
1814 auto newForallOp = cast<scf::ForallOp>(
1816 newForallOp.getOutputsMutable().assign(newOuts);
/// Rewrite pattern on ForallOp that examines the iteration count of every
/// dimension: a zero-iteration dimension replaces the op by its outputs, and
/// dimensions that are kept are collected into the bounds of a new ForallOp.
1822struct ForallOpSingleOrZeroIterationDimsFolder
1823 :
public OpRewritePattern<ForallOp> {
1824 using OpRewritePattern<ForallOp>::OpRewritePattern;
1826 LogicalResult matchAndRewrite(ForallOp op,
1827 PatternRewriter &rewriter)
const override {
// Guard on ops that carry a non-empty mapping attribute.
1829 if (op.getMapping().has_value() && !op.getMapping()->empty())
1831 Location loc = op.getLoc();
1834 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1837 for (
auto [lb, ub, step, iv] :
1838 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1839 op.getMixedStep(), op.getInductionVars())) {
1840 auto numIterations =
1842 if (numIterations.has_value()) {
// Zero iterations in any dimension: the outputs replace the op.
1844 if (*numIterations == 0) {
1845 rewriter.
replaceOp(op, op.getOutputs());
// A single-iteration dimension is handled in its own branch.
1850 if (*numIterations == 1) {
// Bounds and step recorded for the new op.
1855 newMixedLowerBounds.push_back(lb);
1856 newMixedUpperBounds.push_back(ub);
1857 newMixedSteps.push_back(step);
// Special case: no dimension was recorded.
1861 if (newMixedLowerBounds.empty()) {
// Nothing to do when every dimension was recorded unchanged.
1867 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1869 op,
"no dimensions have 0 or 1 iterations");
// New op over the recorded dimensions, created without mapping or body
// builder; its freshly created body blocks are cleared.
1874 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1875 newMixedUpperBounds, newMixedSteps,
1876 op.getOutputs(), std::nullopt,
nullptr);
1877 newOp.getBodyRegion().getBlocks().clear();
// Attributes that describe bounds/segments are not copied from the old op;
// all other attributes are.
1881 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1882 newOp.getStaticLowerBoundAttrName(),
1883 newOp.getStaticUpperBoundAttrName(),
1884 newOp.getStaticStepAttrName()};
1885 for (
const auto &namedAttr : op->getAttrs()) {
1886 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1889 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1893 newOp.getRegion().begin(), mapping);
1894 rewriter.
replaceOp(op, newOp.getResults());
/// Rewrite pattern on ForallOp that inspects each (lower bound, upper bound,
/// step, induction variable) tuple and singles out dimensions with exactly
/// one iteration.
1900struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1901 using OpRewritePattern<ForallOp>::OpRewritePattern;
1903 LogicalResult matchAndRewrite(ForallOp op,
1904 PatternRewriter &rewriter)
const override {
1905 Location loc = op.getLoc();
1907 for (
auto [lb, ub, step, iv] :
1908 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1909 op.getMixedStep(), op.getInductionVars())) {
1912 auto numIterations =
// Guard: iteration count unknown or different from one.
1914 if (!numIterations.has_value() || numIterations.value() != 1) {
/// Rewrite pattern on scf::ForallOp that folds tensor.cast producers of the
/// output operands into the loop: the new loop takes the cast sources as
/// outputs, and casts are re-created on the block arguments and on the
/// results so existing users keep seeing the original types.
1925struct FoldTensorCastOfOutputIntoForallOp
1926 :
public OpRewritePattern<scf::ForallOp> {
1927 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1934 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1935 PatternRewriter &rewriter)
const final {
// Maps output index -> (source type, cast result type) for folded casts.
1936 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1937 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1938 for (
auto en : llvm::enumerate(newOutputTensors)) {
1939 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1946 castOp.getSource().getType())) {
// Record the cast and use its source as the new output.
1950 tensorCastProducers[en.index()] =
1951 TypeCast{castOp.getSource().getType(), castOp.getType()};
1952 newOutputTensors[en.index()] = castOp.getSource();
1955 if (tensorCastProducers.empty())
// New loop with the same bounds, steps and mapping, but the new outputs.
1959 Location loc = forallOp.getLoc();
1960 auto newForallOp = ForallOp::create(
1961 rewriter, loc, forallOp.getMixedLowerBound(),
1962 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1963 newOutputTensors, forallOp.getMapping(),
1964 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
// The trailing block arguments are the outputs; those with a folded cast are
// cast to the recorded destination type inside the body.
1965 auto castBlockArgs =
1966 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1967 for (auto [index, cast] : tensorCastProducers) {
1968 Value &oldTypeBBArg = castBlockArgs[index];
1969 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1970 cast.dstType, oldTypeBBArg);
// Induction variables first, then the (possibly cast) output arguments.
1974 SmallVector<Value> ivsBlockArgs =
1975 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1976 ivsBlockArgs.append(castBlockArgs);
1978 bbArgs.front().getParentBlock(), ivsBlockArgs);
// Yielding ops that implement ParallelCombiningOpInterface get the new
// loop's region iter arg assigned as their updated destination.
1984 auto terminator = newForallOp.getTerminator();
1985 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1986 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1987 if (
auto parallelCombingingOp =
1988 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1989 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
// Results whose output cast was folded are cast to the recorded destination
// type before they replace the old op's results.
1995 SmallVector<Value> castResults = newForallOp.getResults();
1996 for (
auto &item : tensorCastProducers) {
1997 Value &oldTypeResult = castResults[item.first];
1998 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
2001 rewriter.
replaceOp(forallOp, castResults);
/// Adds the six ForallOp rewrite patterns listed below to `results`, each
/// constructed with `context`.
2008void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
2009 MLIRContext *context) {
2010 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
2011 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
2012 ForallOpSingleOrZeroIterationDimsFolder,
2013 ForallOpReplaceConstantInductionVar>(context);
/// Appends region successors of the forall op to the output vector.
/// NOTE(review): `®ions` below looks like a mis-encoded `&regions` introduced
/// by the listing extraction; the conditions selecting between the two
/// successors are on lines missing from this listing.
2021void ForallOp::getSuccessorRegions(RegionBranchPoint point,
2022 SmallVectorImpl<RegionSuccessor> ®ions) {
// Successor: the body region, receiving the region iter args.
2027 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// Successor: the forall operation itself, with its results.
2030 RegionSuccessor(getOperation(), getOperation()->getResults()));
/// Builds an InParallelOp: adds one region to `result` and creates a block
/// inside it.
2038void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
// Insertion guard on `b`, so the createBlock call below does not leave the
// builder's insertion point changed.
2039 OpBuilder::InsertionGuard g(
b);
2040 Region *bodyRegion =
result.addRegion();
2041 b.createBlock(bodyRegion);
/// Verifies the op against its parent and the contents of its region.
/// NOTE(review): the condition guarding the first diagnostic and the final
/// return are not visible in this listing.
2044LogicalResult InParallelOp::verify() {
// The parent operation is looked up as an scf::ForallOp.
2045 scf::ForallOp forallOp =
2046 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
2048 return this->
emitOpError(
"expected forall op parent");
// Every operation in the region's first block must implement
// ParallelCombiningOpInterface.
2050 for (Operation &op : getRegion().front().getOperations()) {
2051 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
2052 if (!parallelCombiningOp) {
2053 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
// Each updated destination must be one of the parent's region out args.
2058 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
2059 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
2060 for (OpOperand &dest : dests) {
2061 if (!llvm::is_contained(regionOutArgs, dest.get()))
2062 return op.emitOpError(
"may only insert into an output block argument");
2069void InParallelOp::print(OpAsmPrinter &p) {
/// Custom parser for InParallelOp; the parsed region is moved into `result`.
/// NOTE(review): the lines that obtain `builder` and parse the region are not
/// visible in this listing.
2077ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2080 SmallVector<OpAsmParser::Argument, 8> regionOperands;
2081 std::unique_ptr<Region> region = std::make_unique<Region>();
// An empty region gets a block created in it before being attached.
2085 if (region->empty())
2086 OpBuilder(builder.
getContext()).createBlock(region.get());
2087 result.addRegion(std::move(region));
/// Returns result number `idx` of this op's parent operation.
2095OpResult InParallelOp::getParentResult(int64_t idx) {
2096 return getOperation()->getParentOp()->getResult(idx);
/// Collects the updated destinations of the yielding ops, each cast to
/// BlockArgument, in iteration order.
/// NOTE(review): the statement executed when a yielding op does not implement
/// ParallelCombiningOpInterface is not visible in this listing.
2099SmallVector<BlockArgument> InParallelOp::getDests() {
2100 SmallVector<BlockArgument> updatedDests;
2101 for (Operation &yieldingOp : getYieldingOps()) {
2102 auto parallelCombiningOp =
2103 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2104 if (!parallelCombiningOp)
// `cast` asserts that every updated destination is a block argument.
2106 for (OpOperand &updatedOperand :
2107 parallelCombiningOp.getUpdatedDestinations())
2108 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2110 return updatedDests;
/// Returns the operations of the first block of this op's region.
2113llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2114 return getRegion().front().getOperations();
2122 assert(a &&
"expected non-empty operation");
2123 assert(
b &&
"expected non-empty operation");
2128 if (ifOp->isProperAncestor(
b))
2131 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2132 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
2134 ifOp = ifOp->getParentOfType<IfOp>();
// Infers the result types from the operand types of a YieldOp and appends
// them to `inferredReturnTypes`.
// NOTE(review): the remaining parameters, the definition of `b`, and all
// return statements are on lines missing from this listing.
2142IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2143 IfOp::Adaptor adaptor,
2145 if (adaptor.getRegions().empty())
2147 Region *r = &adaptor.getThenRegion();
// The last operation of `b` is inspected as a YieldOp.
2153 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
2156 TypeRange types = yieldOp.getOperandTypes();
2157 llvm::append_range(inferredReturnTypes, types);
2163 return build(builder,
result, resultTypes, cond,
false,
/// Builds an IfOp with explicit result types, the condition operand, and two
/// regions added to `result`.
/// NOTE(review): the lines that use `addThenBlock` / `addElseBlock` to create
/// blocks are not visible in this listing.
2167void IfOp::build(OpBuilder &builder, OperationState &
result,
2168 TypeRange resultTypes, Value cond,
bool addThenBlock,
2169 bool addElseBlock) {
// Requesting an else block without a then block is rejected by this assert.
2170 assert((!addElseBlock || addThenBlock) &&
2171 "must not create else block w/o then block");
2172 result.addTypes(resultTypes);
2173 result.addOperands(cond);
2176 OpBuilder::InsertionGuard guard(builder);
2177 Region *thenRegion =
result.addRegion();
2180 Region *elseRegion =
result.addRegion();
2185void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2186 bool withElseRegion) {
/// Builds an IfOp with explicit result types and the condition operand.
/// Both regions are always added to `result`; a terminator is ensured only
/// when `resultTypes` is empty.
/// NOTE(review): the block-creation lines are not visible in this listing.
2190void IfOp::build(OpBuilder &builder, OperationState &
result,
2191 TypeRange resultTypes, Value cond,
bool withElseRegion) {
2192 result.addTypes(resultTypes);
2193 result.addOperands(cond);
2196 OpBuilder::InsertionGuard guard(builder);
2197 Region *thenRegion =
result.addRegion();
// With no results, make sure the 'then' region has a terminator.
2199 if (resultTypes.empty())
2200 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
// The else region is added even when `withElseRegion` is false.
2203 Region *elseRegion =
result.addRegion();
2204 if (withElseRegion) {
2206 if (resultTypes.empty())
2207 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
/// Builds an IfOp whose regions are populated by callbacks and whose result
/// types are obtained from inferReturnTypes.
/// NOTE(review): the `thenBuilder` parameter declaration, block creation and
/// the guard around the `elseBuilder` call are on lines missing from this
/// listing.
2211void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2213 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
2214 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2215 result.addOperands(cond);
2218 OpBuilder::InsertionGuard guard(builder);
2219 Region *thenRegion =
result.addRegion();
2221 thenBuilder(builder,
result.location);
2224 Region *elseRegion =
result.addRegion();
2227 elseBuilder(builder,
result.location);
// Result types are added only when inference succeeds.
2231 SmallVector<Type> inferredReturnTypes;
2233 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
2234 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2236 inferredReturnTypes))) {
2237 result.addTypes(inferredReturnTypes);
/// Verifier: an IfOp that has results but an empty else region is an error.
/// NOTE(review): the success return is not visible in this listing.
2241LogicalResult IfOp::verify() {
2242 if (getNumResults() != 0 && getElseRegion().empty())
2243 return emitOpError(
"must have an else block if defining values");
/// Custom parser for IfOp. Both regions are added to `result` before anything
/// else visible here happens.
/// NOTE(review): the actual parsing calls are on lines missing from this
/// listing.
2247ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
// Reserve space for the two regions, then add them in then/else order.
2249 result.regions.reserve(2);
2250 Region *thenRegion =
result.addRegion();
2251 Region *elseRegion =
result.addRegion();
2254 OpAsmParser::UnresolvedOperand cond;
/// Custom printer for IfOp: prints the condition and, when the op has
/// results, the result types.
/// NOTE(review): the region-printing calls themselves are only partially
/// visible in this listing.
2280void IfOp::print(OpAsmPrinter &p) {
2281 bool printBlockTerminators =
false;
2283 p <<
" " << getCondition();
// When the op has results, print their types and set the flag that is passed
// on as `printBlockTerminators` below.
2284 if (!getResults().empty()) {
2285 p <<
" -> (" << getResultTypes() <<
")";
2287 printBlockTerminators =
true;
2292 printBlockTerminators);
// Branch taken only when the else region is non-empty.
2295 auto &elseRegion = getElseRegion();
2296 if (!elseRegion.
empty()) {
2300 printBlockTerminators);
/// Appends region successors of the if op to the output vector.
/// NOTE(review): `®ions` below looks like a mis-encoded `&regions` introduced
/// by the listing extraction; the conditions and early returns that select
/// among the successors are on lines missing from this listing.
2306void IfOp::getSuccessorRegions(RegionBranchPoint point,
2307 SmallVectorImpl<RegionSuccessor> ®ions) {
// Successor: the if operation itself, with its results.
2311 regions.push_back(RegionSuccessor(getOperation(), getResults()));
// Successor: the 'then' region.
2315 regions.push_back(RegionSuccessor(&getThenRegion()));
// An empty else region leads to the operation itself as successor.
2318 Region *elseRegion = &this->getElseRegion();
2319 if (elseRegion->
empty())
2321 RegionSuccessor(getOperation(), getOperation()->getResults()));
2323 regions.push_back(RegionSuccessor(elseRegion));
/// Appends the regions that may be entered given (possibly constant)
/// operands.
/// NOTE(review): `®ions` below looks like a mis-encoded `&regions` introduced
/// by the listing extraction; an `else` between the last two emplace_back
/// calls may be on a line missing from this listing -- confirm.
2326void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2327 SmallVectorImpl<RegionSuccessor> ®ions) {
2328 FoldAdaptor adaptor(operands, *
this);
// `boolAttr` is null when the condition is not a known BoolAttr.
2329 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
// Unknown or true condition: the 'then' region is a successor.
2330 if (!boolAttr || boolAttr.getValue())
2331 regions.emplace_back(&getThenRegion());
// Unknown or false condition: the else region if it is non-empty.
2334 if (!boolAttr || !boolAttr.getValue()) {
2335 if (!getElseRegion().empty())
2336 regions.emplace_back(&getElseRegion());
2338 regions.emplace_back(getOperation(), getResults());
/// Folder for IfOp. The visible statements rewrite an if whose condition is
/// defined by arith::XOrIOp: the condition becomes the xor's lhs and the
/// then/else bodies are swapped.
/// NOTE(review): the checks on the xor (e.g. on its rhs) and all return
/// statements are on lines missing from this listing.
2342LogicalResult IfOp::fold(FoldAdaptor adaptor,
2343 SmallVectorImpl<OpFoldResult> &results) {
2345 if (getElseRegion().empty())
2348 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2355 getConditionMutable().assign(xorStmt.getLhs());
// Remember the current 'then' block before the regions are spliced.
2356 Block *thenBlock = &getThenRegion().front();
// Move all else blocks to the front of the 'then' region ...
2359 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2360 getElseRegion().getBlocks());
// ... then move the remembered 'then' block into the else region.
2361 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2362 getThenRegion().getBlocks(), thenBlock);
/// Fills `invocationBounds` with two entries, one per emplace_back/assign
/// slot below.
/// NOTE(review): the `else` separating the constant and non-constant cases is
/// presumably on a line missing from this listing -- confirm.
2366void IfOp::getRegionInvocationBounds(
2367 ArrayRef<Attribute> operands,
2368 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2369 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
// Known condition: first entry is (0, 1) when true and (0, 0) when false;
// the second entry is the opposite.
2372 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2373 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
// Two entries of {0, 1}.
2376 invocationBounds.assign(2, {0, 1});
/// Rewrite pattern on IfOp that erases a subset of the op's results together
/// with the matching operands of both yields.
/// NOTE(review): the condition deciding which indices are marked in `toErase`
/// is on a line missing from this listing; going by the struct name it is
/// presumably "result has no uses" -- confirm.
2382struct RemoveUnusedResults :
public OpRewritePattern<IfOp> {
2383 using OpRewritePattern<IfOp>::OpRewritePattern;
2385 LogicalResult matchAndRewrite(IfOp op,
2386 PatternRewriter &rewriter)
const override {
// One bit per result, initially all false.
2388 BitVector toErase(op.getNumResults(),
false);
2389 for (
auto [idx,
result] : llvm::enumerate(op.getResults()))
2391 toErase[idx] =
true;
// Erase the marked results, then the same operand positions in each yield.
2396 auto newOp = cast<scf::IfOp>(rewriter.
eraseOpResults(op, toErase));
2400 newOp.thenYield()->eraseOperands(toErase);
2403 newOp.elseYield()->eraseOperands(toErase);
/// Rewrite pattern on IfOp.
/// NOTE(review): almost the entire body is missing from this listing; only a
/// branch on a non-empty else region is visible. Going by the struct name it
/// presumably handles a constant condition -- confirm.
2410struct RemoveStaticCondition :
public OpRewritePattern<IfOp> {
2411 using OpRewritePattern<IfOp>::OpRewritePattern;
2413 LogicalResult matchAndRewrite(IfOp op,
2414 PatternRewriter &rewriter)
const override {
2421 else if (!op.getElseRegion().empty())
/// Rewrite pattern on IfOp that replaces results with the yielded value or an
/// arith::SelectOp where possible, keeping a smaller replacement IfOp for the
/// results collected as `nonHoistable`.
/// NOTE(review): the failure returns, part of the replacement IfOp creation,
/// the per-result branch condition and the final replacement are on lines
/// missing from this listing.
2432struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2433 using OpRewritePattern<IfOp>::OpRewritePattern;
2435 LogicalResult matchAndRewrite(IfOp op,
2436 PatternRewriter &rewriter)
const override {
2437 if (op->getNumResults() == 0)
2440 auto cond = op.getCondition();
2441 auto thenYieldArgs = op.thenYield().getOperands();
2442 auto elseYieldArgs = op.elseYield().getOperands();
// A result is non-hoistable when either yielded value has the corresponding
// region of `op` as its parent region.
2444 SmallVector<Type> nonHoistable;
2445 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2446 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2447 &op.getElseRegion() == falseVal.getParentRegion())
2448 nonHoistable.push_back(trueVal.getType());
// Every result is non-hoistable.
2452 if (nonHoistable.size() == op->getNumResults())
// The replacement IfOp takes the non-hoistable result types and both bodies
// of `op`.
2455 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2459 replacement.getThenRegion().takeBody(op.getThenRegion());
2460 replacement.getElseRegion().takeBody(op.getElseRegion());
2462 SmallVector<Value> results(op->getNumResults());
2463 assert(thenYieldArgs.size() == results.size());
2464 assert(elseYieldArgs.size() == results.size());
2466 SmallVector<Value> trueYields;
2467 SmallVector<Value> falseYields;
2469 for (
const auto &it :
2470 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2471 Value trueVal = std::get<0>(it.value());
2472 Value falseVal = std::get<1>(it.value());
// Kept result: map to the next result of `replacement` and keep yielding it.
2475 results[it.index()] =
replacement.getResult(trueYields.size());
2476 trueYields.push_back(trueVal);
2477 falseYields.push_back(falseVal);
2478 }
// Both branches yield the same value: use it directly.
else if (trueVal == falseVal)
2479 results[it.index()] = trueVal;
// A select on the if condition picks between the two yielded values.
2481 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2482 cond, trueVal, falseVal);
/// Rewrite pattern on IfOp that rewrites uses of the if condition to the
/// values `constantTrue` / `constantFalse`, depending on the `Parent`
/// classification of the use.
/// NOTE(review): the `cache` parameter declaration, the step that advances
/// `toCheck`, and much of matchAndRewrite are on lines missing from this
/// listing.
2509struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2510 using OpRewritePattern<IfOp>::OpRewritePattern;
// Classification of a region relative to an IfOp.
2513 enum class Parent { Then, Else,
None };
// Classifies `toCheck` relative to `op`, stopping at `endRegion`. Every
// region visited on the way is recorded in `cache` with the final answer.
2518 static Parent getParentType(Region *toCheck, IfOp op,
2520 Region *endRegion) {
2521 SmallVector<Region *> seen;
2522 while (toCheck != endRegion) {
// Cached answer available: return it immediately.
2523 auto found = cache.find(toCheck);
2524 if (found != cache.end())
2525 return found->second;
2526 seen.push_back(toCheck);
2527 if (&op.getThenRegion() == toCheck) {
2528 for (Region *region : seen)
2529 cache[region] = Parent::Then;
2530 return Parent::Then;
2532 if (&op.getElseRegion() == toCheck) {
2533 for (Region *region : seen)
2534 cache[region] = Parent::Else;
2535 return Parent::Else;
// Loop ended without matching either region of `op`.
2540 for (Region *region : seen)
2541 cache[region] = Parent::None;
2542 return Parent::None;
2545 LogicalResult matchAndRewrite(IfOp op,
2546 PatternRewriter &rewriter)
const override {
// Both start out null and are assigned a created arith::ConstantOp below.
2557 Value constantTrue =
nullptr;
2558 Value constantFalse =
nullptr;
// Early-inc iteration: uses are rewritten while the use list is walked.
2561 for (OpOperand &use :
2562 llvm::make_early_inc_range(op.getCondition().getUses())) {
2565 case Parent::Then: {
2569 constantTrue = arith::ConstantOp::create(
2573 [&]() { use.set(constantTrue); });
2576 case Parent::Else: {
2580 constantFalse = arith::ConstantOp::create(
2584 [&]() { use.set(constantFalse); });
/// Rewrite pattern on IfOp that replaces uses of a result with the commonly
/// yielded value, with the condition, or with a negation of the condition.
/// NOTE(review): the declarations of the outer `trueYield` / `falseYield`,
/// the constant matching that fills the BoolAttrs, and the return statements
/// are on lines missing from this listing.
2632struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2633 using OpRewritePattern<IfOp>::OpRewritePattern;
2635 LogicalResult matchAndRewrite(IfOp op,
2636 PatternRewriter &rewriter)
const override {
2638 if (op.getNumResults() == 0)
2642 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2644 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2647 op.getOperation()->getIterator());
2650 for (
auto [trueResult, falseResult, opResult] :
2651 llvm::zip(trueYield.getResults(), falseYield.getResults(),
// Both branches yield the same value: forward it to all users.
2653 if (trueResult == falseResult) {
2654 if (!opResult.use_empty()) {
2655 opResult.replaceAllUsesWith(trueResult);
// These BoolAttrs shadow the yield ops of the same name declared above.
2661 BoolAttr trueYield, falseYield;
2666 bool trueVal = trueYield.
getValue();
2667 bool falseVal = falseYield.
getValue();
// Yields false/true: the result is replaced by an xor built on the condition.
2668 if (!trueVal && falseVal) {
2669 if (!opResult.use_empty()) {
2670 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2671 Value notCond = arith::XOrIOp::create(
2672 rewriter, op.getLoc(), op.getCondition(),
2678 opResult.replaceAllUsesWith(notCond);
// Yields true/false: the result equals the condition itself.
2682 if (trueVal && !falseVal) {
2683 if (!opResult.use_empty()) {
2684 opResult.replaceAllUsesWith(op.getCondition());
/// Rewrite pattern on IfOp that merges `nextIf` with the IfOp immediately
/// preceding it in the same block when their conditions are equal or one is
/// an arith::XOrIOp of the other.
/// NOTE(review): many lines (failure returns, the second halves of the xor
/// checks, the use-rewriting conditions, block inlining calls and the final
/// replacements) are missing from this listing.
2714struct CombineIfs :
public OpRewritePattern<IfOp> {
2715 using OpRewritePattern<IfOp>::OpRewritePattern;
2717 LogicalResult matchAndRewrite(IfOp nextIf,
2718 PatternRewriter &rewriter)
const override {
// `nextIf` is compared against the first op of its block.
2719 Block *parent = nextIf->getBlock();
2720 if (nextIf == &parent->
front())
2723 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
// Blocks of `nextIf` to be merged into the combined then / else branches.
2731 Block *nextThen =
nullptr;
2732 Block *nextElse =
nullptr;
// Same condition: branches line up directly.
2733 if (nextIf.getCondition() == prevIf.getCondition()) {
2734 nextThen = nextIf.thenBlock();
2735 if (!nextIf.getElseRegion().empty())
2736 nextElse = nextIf.elseBlock();
// nextIf's condition is an xor whose lhs is prevIf's condition: branches are
// swapped.
2738 if (arith::XOrIOp notv =
2739 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2740 if (notv.getLhs() == prevIf.getCondition() &&
2742 nextElse = nextIf.thenBlock();
2743 if (!nextIf.getElseRegion().empty())
2744 nextThen = nextIf.elseBlock();
// prevIf's condition is an xor whose lhs is nextIf's condition: branches are
// swapped as well.
2747 if (arith::XOrIOp notv =
2748 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2749 if (notv.getLhs() == nextIf.getCondition() &&
2751 nextElse = nextIf.thenBlock();
2752 if (!nextIf.getElseRegion().empty())
2753 nextThen = nextIf.elseBlock();
2757 if (!nextThen && !nextElse)
2760 SmallVector<Value> prevElseYielded;
2761 if (!prevIf.getElseRegion().empty())
2762 prevElseYielded = prevIf.elseYield().getOperands();
// Rewrite uses of prevIf's results to the value yielded by prevIf's then or
// else branch.
2765 for (
auto it : llvm::zip(prevIf.getResults(),
2766 prevIf.thenYield().getOperands(), prevElseYielded))
2767 for (OpOperand &use :
2768 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2772 use.
set(std::get<1>(it));
2777 use.
set(std::get<2>(it));
// The combined op produces prevIf's result types followed by nextIf's.
2782 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2783 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2785 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2786 prevIf.getCondition(),
false);
2787 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2790 combinedIf.getThenRegion(),
2791 combinedIf.getThenRegion().begin());
// Merge nextIf's then-side block and create a yield with both operand lists.
2794 YieldOp thenYield = combinedIf.thenYield();
2795 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2796 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2799 SmallVector<Value> mergedYields(thenYield.getOperands());
2800 llvm::append_range(mergedYields, thenYield2.getOperands());
2801 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2807 combinedIf.getElseRegion(),
2808 combinedIf.getElseRegion().begin());
2811 if (combinedIf.getElseRegion().empty()) {
2813 combinedIf.getElseRegion(),
2814 combinedIf.getElseRegion().
begin());
// Same merge for the else side.
2816 YieldOp elseYield = combinedIf.elseYield();
2817 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2818 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2822 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2823 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2825 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
// Split the combined results: the first prevIf.getNumResults() belong to
// prevIf, the remainder to nextIf.
2831 SmallVector<Value> prevValues;
2832 SmallVector<Value> nextValues;
2833 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2834 if (pair.index() < prevIf.getNumResults())
2835 prevValues.push_back(pair.value());
2837 nextValues.push_back(pair.value());
/// Rewrite pattern on IfOp concerning an else block that holds exactly one
/// operation.
/// NOTE(review): the failure returns and the creation of `newIfOp` are on
/// lines missing from this listing.
2846struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2847 using OpRewritePattern<IfOp>::OpRewritePattern;
2849 LogicalResult matchAndRewrite(IfOp ifOp,
2850 PatternRewriter &rewriter)
const override {
2852 if (ifOp.getNumResults())
// Checks for a missing else block or one with other than a single element.
2854 Block *elseBlock = ifOp.elseBlock();
2855 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2859 newIfOp.getThenRegion().begin());
/// Rewrite pattern on IfOp whose 'then' block contains, besides its
/// terminator, a single nested IfOp; the two are combined into one IfOp
/// whose condition is the AndIOp of both conditions.
/// NOTE(review): the failure returns, several guards, the definition of
/// `newIfBlock` and the final replacement are on lines missing from this
/// listing.
2881struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2882 using OpRewritePattern<IfOp>::OpRewritePattern;
2884 LogicalResult matchAndRewrite(IfOp op,
2885 PatternRewriter &rewriter)
const override {
2886 auto nestedOps = op.thenBlock()->without_terminator();
2888 if (!llvm::hasSingleElement(nestedOps))
// An existing outer else block is checked for holding a single operation.
2892 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2895 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2899 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2902 SmallVector<Value> thenYield(op.thenYield().getOperands());
2903 SmallVector<Value> elseYield;
2905 llvm::append_range(elseYield, op.elseYield().getOperands());
// Indices of yielded values that are later turned into arith::SelectOp.
2909 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2918 for (
const auto &tup : llvm::enumerate(thenYield)) {
// Value produced by the nested if: compare the nested else-yield with the
// outer else-yield and forward the nested then-yield.
2919 if (tup.value().getDefiningOp() == nestedIf) {
2920 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2921 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2922 elseYield[tup.index()]) {
2927 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2940 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2943 elseYieldsToUpgradeToSelect.push_back(tup.index());
2946 Location loc = op.getLoc();
2947 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2948 nestedIf.getCondition());
2949 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2952 SmallVector<Value> results;
2953 llvm::append_range(results, newIf.getResults());
// Selects are built on the outer condition only.
2956 for (
auto idx : elseYieldsToUpgradeToSelect)
2958 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2959 thenYield[idx], elseYield[idx]);
2961 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2964 if (!elseYield.empty()) {
2967 YieldOp::create(rewriter, loc, elseYield);
/// Adds the eight IfOp rewrite patterns listed below to `results`, each
/// constructed with `context`.
2976void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2977 MLIRContext *context) {
2978 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2979 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2980 RemoveStaticCondition, RemoveUnusedResults,
2981 ReplaceIfYieldWithConditionOrValue>(context);
2984Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
2985YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
/// Returns a block of the else region.
/// NOTE(review): the lines that compute the return value from `r` are missing
/// from this listing; a caller in this file tests the result against null, so
/// it presumably returns nullptr for an empty else region -- confirm.
2986Block *IfOp::elseBlock() {
2987 Region &r = getElseRegion();
2992YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
/// Builds a ParallelOp from lower bounds, upper bounds, steps and initial
/// values, optionally populating the body through `bodyBuilderFn`.
/// NOTE(review): the parameter list, the start of the segment-size attribute
/// call and the block creation are on lines missing from this listing.
2998void ParallelOp::build(
// Operands are added in the order: lower bounds, upper bounds, steps, inits.
3003 result.addOperands(lowerBounds);
3004 result.addOperands(upperBounds);
3005 result.addOperands(steps);
3006 result.addOperands(initVals);
3008 ParallelOp::getOperandSegmentSizeAttr(),
3010 static_cast<int32_t>(upperBounds.size()),
3011 static_cast<int32_t>(steps.size()),
3012 static_cast<int32_t>(initVals.size())}));
3015 OpBuilder::InsertionGuard guard(builder);
// One index-typed argument (located at result.location) per step.
3016 unsigned numIVs = steps.size();
3017 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
3018 SmallVector<Location, 8> argLocs(numIVs,
result.location);
3019 Region *bodyRegion =
result.addRegion();
3022 if (bodyBuilderFn) {
3024 bodyBuilderFn(builder,
result.location,
// Without initial values, make sure the body has a terminator.
3029 if (initVals.empty())
3030 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
/// ParallelOp builder overload that wraps `bodyBuilderFn` in a lambda which
/// forwards `nestedBuilder`, `nestedLoc` and `ivs` to it.
/// NOTE(review): the parameter list and the call that uses `wrapper` are on
/// lines missing from this listing.
3033void ParallelOp::build(
3040 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
3043 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3047 wrapper = wrappedBuilderFn;
/// Verifier for ParallelOp: checks steps, induction variables, the body
/// terminator, and that results, reductions and initial values agree in
/// number and type.
/// NOTE(review): several conditions, the definitions of `stepValues`, `arg`
/// and `reduceOp`, and the final return are on lines missing from this
/// listing.
3053LogicalResult ParallelOp::verify() {
// At least one (lowerBound, upperBound, step) tuple is required.
3058 if (stepValues.empty())
3060 "needs at least one tuple element for lowerBound, upperBound and step");
3063 for (Value stepValue : stepValues)
3066 return emitOpError(
"constant step operand must be positive");
3070 Block *body = getBody();
3072 return emitOpError() <<
"expects the same number of induction variables: "
3074 <<
" as bound and step values: " << stepValues.size();
3076 if (!arg.getType().isIndex())
3078 "expects arguments for the induction variable to be of index type");
3082 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
// The result count must match both the reduction count and the init count.
3087 auto resultsSize = getResults().size();
3088 auto reductionsSize = reduceOp.getReductions().size();
3089 auto initValsSize = getInitVals().size();
3090 if (resultsSize != reductionsSize)
3091 return emitOpError() <<
"expects number of results: " << resultsSize
3092 <<
" to be the same as number of reductions: "
3094 if (resultsSize != initValsSize)
3095 return emitOpError() <<
"expects number of results: " << resultsSize
3096 <<
" to be the same as number of initial values: "
3098 if (reduceOp.getNumOperands() != initValsSize)
// The i-th reduction operand type must equal the i-th result type.
3103 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3104 auto resultType = getOperation()->getResult(i).getType();
3105 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3106 if (resultType != reductionOperandType)
3107 return reduceOp.emitOpError()
3108 <<
"expects type of " << i
3109 <<
"-th reduction operand: " << reductionOperandType
3110 <<
" to be the same as the " << i
3111 <<
"-th result type: " << resultType;
/// Custom parser for ParallelOp: collects induction variables, lower/upper
/// bounds, steps and initial values, adds the body region, and ensures the
/// body has a terminator.
/// NOTE(review): the parser calls, operand resolution and the start of the
/// segment-size attribute call are on lines missing from this listing.
3116ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
3119 SmallVector<OpAsmParser::Argument, 4> ivs;
3124 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3131 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3139 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3147 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3158 Region *body =
result.addRegion();
3159 for (
auto &iv : ivs)
3166 ParallelOp::getOperandSegmentSizeAttr(),
3168 static_cast<int32_t>(upper.size()),
3169 static_cast<int32_t>(steps.size()),
3170 static_cast<int32_t>(initVals.size())}));
3179 ParallelOp::ensureTerminator(*body, builder,
result.location);
/// Custom printer for ParallelOp. Prints
/// " (<body args>) = (<lower>) to (<upper>) step (<step>)" and, when initial
/// values exist, " init (<inits>)".
/// NOTE(review): the region and attribute-dictionary printing calls are only
/// partially visible in this listing.
3183void ParallelOp::print(OpAsmPrinter &p) {
3184 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3185 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3186 if (!getInitVals().empty())
3187 p <<
" init (" << getInitVals() <<
")";
3192 (*this)->getAttrs(),
3193 ParallelOp::getOperandSegmentSizeAttr());
3196SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
/// Returns the arguments of the body block as the loop induction variables.
3198std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3199 return SmallVector<Value>{getBody()->getArguments()};
3202std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3206std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3210std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3215 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3217 return ParallelOp();
3218 assert(ivArg.getOwner() &&
"unlinked block argument");
3219 auto *containingOp = ivArg.getOwner()->getParentOp();
3220 return dyn_cast<ParallelOp>(containingOp);
/// Rewrite pattern on ParallelOp that simplifies dimensions with zero or one
/// iteration.
/// NOTE(review): the base class, the computation of `numIterations`, the
/// induction-variable mapping and several returns are on lines missing from
/// this listing.
3225struct ParallelOpSingleOrZeroIterationDimsFolder
3229 LogicalResult matchAndRewrite(ParallelOp op,
3236 for (
auto [lb,
ub, step, iv] :
3237 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3238 op.getInductionVars())) {
3239 auto numIterations =
3241 if (numIterations.has_value()) {
// A dimension with zero iterations: the op is replaced by its init values.
3243 if (*numIterations == 0) {
3244 rewriter.
replaceOp(op, op.getInitVals());
3249 if (*numIterations == 1) {
// Dimension kept in the new loop nest.
3254 newLowerBounds.push_back(lb);
3255 newUpperBounds.push_back(ub);
3256 newSteps.push_back(step);
// Every dimension was kept.
3259 if (newLowerBounds.size() == op.getLowerBound().size())
// No dimension was kept: clone the body and the reduction bodies using
// `mapping`.
3262 if (newLowerBounds.empty()) {
3265 SmallVector<Value> results;
3266 results.reserve(op.getInitVals().size());
3267 for (
auto &bodyOp : op.getBody()->without_terminator())
3268 rewriter.
clone(bodyOp, mapping);
3269 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3270 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3271 Block &reduceBlock = reduceOp.getReductions()[i].front();
// The reduction block's first argument is mapped to the matching init value.
3272 auto initValIndex = results.size();
3273 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3277 rewriter.
clone(reduceBodyOp, mapping);
3280 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3281 results.push_back(
result);
// Remaining case: create a ParallelOp over the kept dimensions only and
// replace the original op with it.
3289 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3290 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3296 newOp.getRegion().begin(), mapping);
3297 rewriter.
replaceOp(op, newOp.getResults());
/// Rewrite pattern on ParallelOp involving a ParallelOp that is the first
/// operation of the outer body; outer and inner bounds and steps are
/// concatenated.
/// NOTE(review): the failure returns, the loop over `val`, most of the body
/// builder and the final replacement are on lines missing from this listing.
3302struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3303 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3305 LogicalResult matchAndRewrite(ParallelOp op,
3306 PatternRewriter &rewriter)
const override {
3307 Block &outerBody = *op.getBody();
3311 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
// Checks whether `val` is used as an inner bound or step.
3316 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3317 llvm::is_contained(innerOp.getUpperBound(), val) ||
3318 llvm::is_contained(innerOp.getStep(), val))
// Checks whether either loop has initial values.
3322 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3325 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3327 Block &innerBody = *innerOp.getBody();
3328 assert(iterVals.size() ==
3336 builder.
clone(op, mapping);
// Helper: returns `first` followed by `second` as one SmallVector<Value>.
3339 auto concatValues = [](
const auto &first,
const auto &second) {
3340 SmallVector<Value> ret;
3341 ret.reserve(first.size() + second.size());
3342 ret.assign(first.begin(), first.end());
3343 ret.append(second.begin(), second.end());
3347 auto newLowerBounds =
3348 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3349 auto newUpperBounds =
3350 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3351 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
/// Adds ParallelOpSingleOrZeroIterationDimsFolder and
/// MergeNestedParallelLoops as canonicalization patterns.
/// NOTE(review): the receiver of `.add` and its argument are on lines missing
/// from this listing.
3362void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3363 MLIRContext *context) {
3365 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
/// Appends region successors of the parallel op to the output vector.
/// NOTE(review): `®ions` below looks like a mis-encoded `&regions` introduced
/// by the listing extraction; lines between the signature and the first
/// push_back are missing from this listing.
3374void ParallelOp::getSuccessorRegions(
3375 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
// Successor: the body region.
3379 regions.push_back(RegionSuccessor(&getRegion()));
// Successor: the operation itself, with an empty result range (both
// iterators are getResults().end()).
3380 regions.push_back(RegionSuccessor(
3381 getOperation(), ResultRange{getResults().end(), getResults().end()}));
3388void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
/// Builds a ReduceOp from `operands`, adding one region per operand.
/// NOTE(review): the remaining parameter(s) and the block creation inside
/// each region are on lines missing from this listing.
3390void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3392 result.addOperands(operands);
3393 for (Value v : operands) {
3394 OpBuilder::InsertionGuard guard(builder);
3395 Region *bodyRegion =
result.addRegion();
/// Region verifier for ReduceOp: there must be one reduction region per
/// operand, and each region's block arguments must have the operand's type.
/// NOTE(review): the definition of `block`, several conditions and the final
/// return are on lines missing from this listing.
3402LogicalResult ReduceOp::verifyRegions() {
3403 if (getReductions().size() != getOperands().size())
3404 return emitOpError() <<
"expects number of reduction regions: "
3405 << getReductions().size()
3406 <<
" to be the same as number of reduction operands: "
3407 << getOperands().size();
3410 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3411 auto type = getOperands()[i].getType();
3414 return emitOpError() << i <<
"-th reduction has an empty body";
// True when any block argument's type differs from the operand type.
3416 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3417 return arg.getType() != type;
3419 return emitOpError() <<
"expected two block arguments with type " << type
3420 <<
" in the " << i <<
"-th reduction region";
3424 return emitOpError(
"reduction bodies must be terminated with an "
3425 "'scf.reduce.return' op");
// Returns a MutableOperandRange over this operation built with start 0 and
// length 0.
// NOTE(review): the return type line and any statements before the return
// are missing from this listing.
3432ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3434 return MutableOperandRange(getOperation(), 0, 0);
/// Verifier for ReduceReturnOp: the returned value's type must equal
/// `expectedResultType`.
/// NOTE(review): the computation of `expectedResultType` and the success
/// return are on lines missing from this listing.
3441LogicalResult ReduceReturnOp::verify() {
3444 Block *reductionBody = getOperation()->getBlock();
// The enclosing operation is asserted to be a ReduceOp.
3446 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3448 if (expectedResultType != getResult().
getType())
3449 return emitOpError() <<
"must have type " << expectedResultType
3450 <<
" (the type of the reduction inputs)";
/// Builds a WhileOp with a 'before' region whose block arguments take the
/// types of `inits`, and an 'after' region whose block arguments take
/// `resultTypes`.
/// NOTE(review): adding operands/types to `odsState`, the createBlock calls
/// and the invocation of the two builder callbacks are on lines missing from
/// this listing.
3458void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3459 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3460 ValueRange inits, BodyBuilderFn beforeBuilder,
3461 BodyBuilderFn afterBuilder) {
3465 OpBuilder::InsertionGuard guard(odsBuilder);
// 'before' argument locations are taken from the corresponding init values.
3468 SmallVector<Location, 4> beforeArgLocs;
3469 beforeArgLocs.reserve(inits.size());
3470 for (Value operand : inits) {
3471 beforeArgLocs.push_back(operand.getLoc());
3474 Region *beforeRegion = odsState.
addRegion();
3476 inits.getTypes(), beforeArgLocs);
// 'after' argument locations all use the op's own location.
3481 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3483 Region *afterRegion = odsState.
addRegion();
3485 resultTypes, afterArgLocs);
/// Returns the terminator of the 'before' body as a ConditionOp; `cast`
/// asserts the type.
3491ConditionOp WhileOp::getConditionOp() {
3492 return cast<ConditionOp>(getBeforeBody()->getTerminator());
/// Returns the terminator of the 'after' body as a YieldOp; `cast` asserts
/// the type.
3495YieldOp WhileOp::getYieldOp() {
3496 return cast<YieldOp>(getAfterBody()->getTerminator());
/// Returns the mutable operands of the YieldOp obtained from getYieldOp().
3499std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3500 return getYieldOp().getResultsMutable();
3504 return getBeforeBody()->getArguments();
3508 return getAfterBody()->getArguments();
3512 return getBeforeArguments();
/// Returns the operands forwarded when entering the op.
/// NOTE(review): the asserted condition and the return statement are on lines
/// missing from this listing; only the assertion message is visible.
3515OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3517 "WhileOp is expected to branch only to the first region");
/// Appends region successors of the while op to the output vector.
/// NOTE(review): `®ions` below looks like a mis-encoded `&regions` introduced
/// by the listing extraction; the conditions and returns separating the cases
/// are on lines missing from this listing.
3521void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3522 SmallVectorImpl<RegionSuccessor> ®ions) {
// Successor: the 'before' region with its arguments.
3525 regions.emplace_back(&getBefore(), getBefore().getArguments());
3529 assert(llvm::is_contained(
3530 {&getAfter(), &getBefore()},
3532 "there are only two regions in a WhileOp");
// Successor: the 'before' region with its arguments (second occurrence).
3536 regions.emplace_back(&getBefore(), getBefore().getArguments());
// Successors: the operation itself with its results, and the 'after' region
// with its arguments.
3540 regions.emplace_back(getOperation(), getResults());
3541 regions.emplace_back(&getAfter(), getAfter().getArguments());
/// Returns pointers to the 'before' and 'after' regions, in that order.
3544SmallVector<Region *> WhileOp::getLoopRegions() {
3545 return {&getBefore(), &getAfter()};
/// Custom parser for scf::WhileOp: adds the 'before' and 'after' regions,
/// takes the result types from a parsed function type, and assigns that
/// type's inputs to the region arguments.
/// NOTE(review): the assignment-list and type parsing calls, operand
/// resolution and the parsing of the 'after' region are on lines missing from
/// this listing.
3555ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3556 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3557 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3558 Region *before =
result.addRegion();
3559 Region *after =
result.addRegion();
3561 OptionalParseResult listResult =
3566 FunctionType functionType;
3571 result.addTypes(functionType.getResults());
// The function type must list exactly one input per operand.
3573 if (functionType.getNumInputs() != operands.size()) {
3575 <<
"expected as many input types as operands " <<
"(expected "
3576 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
// Region argument i takes the i-th input type of the function type.
3586 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3587 regionArgs[i].type = functionType.getInput(i);
3589 return failure(parser.
parseRegion(*before, regionArgs) ||
3595void scf::WhileOp::print(OpAsmPrinter &p) {
3609template <
typename OpTy>
3612 if (left.size() != right.size())
3613 return op.emitOpError(
"expects the same number of ") << message;
3615 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3616 if (left[i] != right[i]) {
3619 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3620 <<
" and " << right[i];
3628LogicalResult scf::WhileOp::verify() {
3631 "expects the 'before' region to terminate with 'scf.condition'");
3632 if (!beforeTerminator)
3637 "expects the 'after' region to terminate with 'scf.yield'");
3638 return success(afterTerminator !=
nullptr);
3675struct WhileMoveIfDown :
public OpRewritePattern<scf::WhileOp> {
3676 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3678 LogicalResult matchAndRewrite(scf::WhileOp op,
3679 PatternRewriter &rewriter)
const override {
3680 auto conditionOp = op.getConditionOp();
3688 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3694 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3695 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3698 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3699 *ifOp->user_begin() == conditionOp)) &&
3700 "ifOp has unexpected uses");
3702 Location loc = op.getLoc();
3706 for (
auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3707 auto it = llvm::find(ifOp->getResults(), arg);
3708 if (it != ifOp->getResults().end()) {
3709 size_t ifOpIdx = it.getIndex();
3710 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3711 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3721 if (&op.getBefore() == operand->get().getParentRegion())
3722 additionalUsedValuesSet.insert(operand->get());
3726 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3727 auto additionalValueTypes = llvm::map_to_vector(
3728 additionalUsedValues, [](Value val) {
return val.
getType(); });
3729 size_t additionalValueSize = additionalUsedValues.size();
3730 SmallVector<Type> newResultTypes(op.getResultTypes());
3731 newResultTypes.append(additionalValueTypes);
3734 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3737 newWhileOp.getBefore().takeBody(op.getBefore());
3738 newWhileOp.getAfter().takeBody(op.getAfter());
3739 newWhileOp.getAfter().addArguments(
3740 additionalValueTypes,
3741 SmallVector<Location>(additionalValueSize, loc));
3745 conditionOp.getArgsMutable().append(additionalUsedValues);
3751 additionalUsedValues,
3752 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3753 [&](OpOperand &use) {
3754 return ifOp.getThenRegion().isAncestor(
3755 use.getOwner()->getParentRegion());
3759 rewriter.
eraseOp(ifOp.thenYield());
3761 newWhileOp.getAfterBody()->begin());
3764 newWhileOp->getResults().drop_back(additionalValueSize));
3788struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3789 using OpRewritePattern<WhileOp>::OpRewritePattern;
3791 LogicalResult matchAndRewrite(WhileOp op,
3792 PatternRewriter &rewriter)
const override {
3793 auto term = op.getConditionOp();
3797 Value constantTrue =
nullptr;
3799 bool replaced =
false;
3800 for (
auto yieldedAndBlockArgs :
3801 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3802 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3803 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3805 constantTrue = arith::ConstantOp::create(
3806 rewriter, op.getLoc(), term.getCondition().getType(),
3867struct RemoveLoopInvariantArgsFromBeforeBlock
3868 :
public OpRewritePattern<WhileOp> {
3869 using OpRewritePattern<WhileOp>::OpRewritePattern;
3871 LogicalResult matchAndRewrite(WhileOp op,
3872 PatternRewriter &rewriter)
const override {
3873 Block &afterBlock = *op.getAfterBody();
3875 ConditionOp condOp = op.getConditionOp();
3876 OperandRange condOpArgs = condOp.getArgs();
3880 bool canSimplify =
false;
3881 for (
const auto &it :
3882 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3883 auto index =
static_cast<unsigned>(it.index());
3884 auto [initVal, yieldOpArg] = it.value();
3887 if (yieldOpArg == initVal) {
3896 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3897 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3898 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3899 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3909 SmallVector<Value> newInitArgs, newYieldOpArgs;
3911 SmallVector<Location> newBeforeBlockArgLocs;
3912 for (
const auto &it :
3913 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3914 auto index =
static_cast<unsigned>(it.index());
3915 auto [initVal, yieldOpArg] = it.value();
3919 if (yieldOpArg == initVal) {
3920 beforeBlockInitValMap.insert({index, initVal});
3928 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3929 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3930 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3931 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3932 beforeBlockInitValMap.insert({index, initVal});
3937 newInitArgs.emplace_back(initVal);
3938 newYieldOpArgs.emplace_back(yieldOpArg);
3939 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3943 OpBuilder::InsertionGuard g(rewriter);
3948 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3952 &newWhile.getBefore(), {},
3953 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3955 Block &beforeBlock = *op.getBeforeBody();
3962 for (
unsigned i = 0, j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3965 if (beforeBlockInitValMap.count(i) != 0)
3966 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3968 newBeforeBlockArgs[i] = newBeforeBlock.
getArgument(j++);
3971 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3973 newWhile.getAfter().begin());
3975 rewriter.
replaceOp(op, newWhile.getResults());
4020struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
4021 using OpRewritePattern<WhileOp>::OpRewritePattern;
4023 LogicalResult matchAndRewrite(WhileOp op,
4024 PatternRewriter &rewriter)
const override {
4025 Block &beforeBlock = *op.getBeforeBody();
4026 ConditionOp condOp = op.getConditionOp();
4027 OperandRange condOpArgs = condOp.getArgs();
4029 bool canSimplify =
false;
4030 for (Value condOpArg : condOpArgs) {
4045 SmallVector<Value> newCondOpArgs;
4046 SmallVector<Type> newAfterBlockType;
4048 SmallVector<Location> newAfterBlockArgLocs;
4049 for (
const auto &it : llvm::enumerate(condOpArgs)) {
4050 auto index =
static_cast<unsigned>(it.index());
4051 Value condOpArg = it.value();
4056 condOpInitValMap.insert({index, condOpArg});
4058 newCondOpArgs.emplace_back(condOpArg);
4059 newAfterBlockType.emplace_back(condOpArg.
getType());
4060 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
4065 OpBuilder::InsertionGuard g(rewriter);
4071 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
4074 Block &newAfterBlock =
4076 newAfterBlockType, newAfterBlockArgLocs);
4078 Block &afterBlock = *op.getAfterBody();
4085 for (
unsigned i = 0, j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
4086 Value afterBlockArg,
result;
4089 if (condOpInitValMap.count(i) != 0) {
4090 afterBlockArg = condOpInitValMap[i];
4094 result = newWhile.getResult(j);
4097 newAfterBlockArgs[i] = afterBlockArg;
4098 newWhileResults[i] =
result;
4101 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4103 newWhile.getBefore().begin());
4105 rewriter.
replaceOp(op, newWhileResults);
4136struct WhileUnusedResult :
public OpRewritePattern<WhileOp> {
4137 using OpRewritePattern<WhileOp>::OpRewritePattern;
4139 LogicalResult matchAndRewrite(WhileOp op,
4140 PatternRewriter &rewriter)
const override {
4141 auto term = op.getConditionOp();
4142 auto afterArgs = op.getAfterArguments();
4143 auto termArgs = term.getArgs();
4146 SmallVector<unsigned> newResultsIndices;
4147 SmallVector<Type> newResultTypes;
4148 SmallVector<Value> newTermArgs;
4149 SmallVector<Location> newArgLocs;
4150 bool needUpdate =
false;
4151 for (
const auto &it :
4152 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
4153 auto i =
static_cast<unsigned>(it.index());
4154 Value
result = std::get<0>(it.value());
4155 Value afterArg = std::get<1>(it.value());
4156 Value termArg = std::get<2>(it.value());
4160 newResultsIndices.emplace_back(i);
4161 newTermArgs.emplace_back(termArg);
4162 newResultTypes.emplace_back(
result.getType());
4163 newArgLocs.emplace_back(
result.getLoc());
4171 OpBuilder::InsertionGuard g(rewriter);
4178 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
4181 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
4185 SmallVector<Value> newResults(op.getNumResults());
4186 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
4187 for (
const auto &it : llvm::enumerate(newResultsIndices)) {
4188 newResults[it.value()] = newWhile.getResult(it.index());
4189 newAfterBlockArgs[it.value()] = newAfterBlock.
getArgument(it.index());
4193 newWhile.getBefore().begin());
4195 Block &afterBlock = *op.getAfterBody();
4196 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4225struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
4226 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4228 LogicalResult matchAndRewrite(scf::WhileOp op,
4229 PatternRewriter &rewriter)
const override {
4230 using namespace scf;
4231 auto cond = op.getConditionOp();
4232 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4236 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4237 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
4238 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4241 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4242 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4246 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4249 if (cmp2.getPredicate() == cmp.getPredicate())
4250 samePredicate =
true;
4251 else if (cmp2.getPredicate() ==
4252 arith::invertPredicate(cmp.getPredicate()))
4253 samePredicate =
false;
4268struct WhileRemoveUnusedArgs :
public OpRewritePattern<WhileOp> {
4269 using OpRewritePattern<WhileOp>::OpRewritePattern;
4271 LogicalResult matchAndRewrite(WhileOp op,
4272 PatternRewriter &rewriter)
const override {
4274 if (!llvm::any_of(op.getBeforeArguments(),
4275 [](Value arg) { return arg.use_empty(); }))
4278 YieldOp yield = op.getYieldOp();
4281 SmallVector<Value> newYields;
4282 SmallVector<Value> newInits;
4283 llvm::BitVector argsToErase;
4285 size_t argsCount = op.getBeforeArguments().size();
4286 newYields.reserve(argsCount);
4287 newInits.reserve(argsCount);
4288 argsToErase.reserve(argsCount);
4289 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4290 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4291 if (beforeArg.use_empty()) {
4292 argsToErase.push_back(
true);
4294 argsToErase.push_back(
false);
4295 newYields.emplace_back(yieldValue);
4296 newInits.emplace_back(initValue);
4300 Block &beforeBlock = *op.getBeforeBody();
4301 Block &afterBlock = *op.getAfterBody();
4305 Location loc = op.getLoc();
4307 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4309 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4310 Block &newAfterBlock = *newWhileOp.getAfterBody();
4312 OpBuilder::InsertionGuard g(rewriter);
4316 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4321 rewriter.
replaceOp(op, newWhileOp.getResults());
4327struct WhileRemoveDuplicatedResults :
public OpRewritePattern<WhileOp> {
4330 LogicalResult matchAndRewrite(WhileOp op,
4331 PatternRewriter &rewriter)
const override {
4332 ConditionOp condOp = op.getConditionOp();
4335 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4337 if (argsSet.size() == condOpArgs.size())
4340 llvm::SmallDenseMap<Value, unsigned> argsMap;
4341 SmallVector<Value> newArgs;
4342 argsMap.reserve(condOpArgs.size());
4343 newArgs.reserve(condOpArgs.size());
4344 for (Value arg : condOpArgs) {
4345 if (!argsMap.count(arg)) {
4346 auto pos =
static_cast<unsigned>(argsMap.size());
4347 argsMap.insert({arg, pos});
4348 newArgs.emplace_back(arg);
4354 Location loc = op.getLoc();
4356 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4359 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4360 Block &newAfterBlock = *newWhileOp.getAfterBody();
4362 SmallVector<Value> afterArgsMapping;
4363 SmallVector<Value> resultsMapping;
4364 for (
auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4365 auto it = argsMap.find(arg);
4366 assert(it != argsMap.end());
4367 auto pos = it->second;
4368 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4369 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4372 OpBuilder::InsertionGuard g(rewriter);
4377 Block &beforeBlock = *op.getBeforeBody();
4378 Block &afterBlock = *op.getAfterBody();
4380 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4382 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4390static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4392 if (args1.size() != args2.size())
4393 return std::nullopt;
4395 SmallVector<unsigned> ret(args1.size());
4396 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
4397 auto it = llvm::find(args2, arg1);
4398 if (it == args2.end())
4399 return std::nullopt;
4401 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4408 llvm::SmallDenseSet<Value> set;
4409 for (Value arg : args) {
4410 if (!set.insert(arg).second)
4420struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
4423 LogicalResult matchAndRewrite(WhileOp loop,
4424 PatternRewriter &rewriter)
const override {
4425 auto *oldBefore = loop.getBeforeBody();
4426 ConditionOp oldTerm = loop.getConditionOp();
4427 ValueRange beforeArgs = oldBefore->getArguments();
4429 if (beforeArgs == termArgs)
4432 if (hasDuplicates(termArgs))
4435 auto mapping = getArgsMapping(beforeArgs, termArgs);
4440 OpBuilder::InsertionGuard g(rewriter);
4446 auto *oldAfter = loop.getAfterBody();
4448 SmallVector<Type> newResultTypes(beforeArgs.size());
4449 for (
auto &&[i, j] : llvm::enumerate(*mapping))
4450 newResultTypes[j] = loop.getResult(i).getType();
4452 auto newLoop = WhileOp::create(
4453 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4455 auto *newBefore = newLoop.getBeforeBody();
4456 auto *newAfter = newLoop.getAfterBody();
4458 SmallVector<Value> newResults(beforeArgs.size());
4459 SmallVector<Value> newAfterArgs(beforeArgs.size());
4460 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
4461 newResults[i] = newLoop.getResult(j);
4462 newAfterArgs[i] = newAfter->getArgument(j);
4466 newBefore->getArguments());
4476void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4477 MLIRContext *context) {
4478 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4479 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4480 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4481 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4496 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4499 caseValues.push_back(value);
4508 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4510 p <<
"case " << value <<
' ';
4515LogicalResult scf::IndexSwitchOp::verify() {
4516 if (getCases().size() != getCaseRegions().size()) {
4518 << getCaseRegions().size() <<
" case regions but "
4519 << getCases().size() <<
" case values";
4523 for (int64_t value : getCases())
4524 if (!valueSet.insert(value).second)
4525 return emitOpError(
"has duplicate case value: ") << value;
4526 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
4527 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4529 return emitOpError(
"expected region to end with scf.yield, but got ")
4532 if (yield.getNumOperands() != getNumResults()) {
4533 return (
emitOpError(
"expected each region to return ")
4534 << getNumResults() <<
" values, but " << name <<
" returns "
4535 << yield.getNumOperands())
4536 .attachNote(yield.getLoc())
4537 <<
"see yield operation here";
4539 for (
auto [idx,
result, operand] :
4540 llvm::enumerate(getResultTypes(), yield.getOperands())) {
4542 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
4543 if (
result == operand.getType())
4546 << idx <<
" of each region to be " <<
result)
4547 .attachNote(yield.getLoc())
4548 << name <<
" returns " << operand.getType() <<
" here";
4555 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4562unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4564Block &scf::IndexSwitchOp::getDefaultBlock() {
4565 return getDefaultRegion().front();
4568Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4569 assert(idx < getNumCases() &&
"case index out-of-bounds");
4570 return getCaseRegions()[idx].front();
4573void IndexSwitchOp::getSuccessorRegions(
4574 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4577 successors.emplace_back(getOperation(), getResults());
4581 llvm::append_range(successors, getRegions());
4584void IndexSwitchOp::getEntrySuccessorRegions(
4585 ArrayRef<Attribute> operands,
4586 SmallVectorImpl<RegionSuccessor> &successors) {
4587 FoldAdaptor adaptor(operands, *
this);
4590 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4592 llvm::append_range(successors, getRegions());
4598 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4599 if (caseValue == arg.getInt()) {
4600 successors.emplace_back(&caseRegion);
4604 successors.emplace_back(&getDefaultRegion());
4607void IndexSwitchOp::getRegionInvocationBounds(
4608 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4609 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4610 if (!operandValue) {
4612 bounds.append(getNumRegions(), InvocationBounds(0, 1));
4616 unsigned liveIndex = getNumRegions() - 1;
4617 const auto *it = llvm::find(getCases(), operandValue.getInt());
4618 if (it != getCases().end())
4619 liveIndex = std::distance(getCases().begin(), it);
4620 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4621 bounds.emplace_back(0, i == liveIndex);
4632 if (!maybeCst.has_value())
4635 int64_t caseIdx, e = op.getNumCases();
4636 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4637 if (cst == op.getCases()[caseIdx])
4641 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4642 : op.getDefaultRegion();
4665 BitVector deadResults(op.getNumResults(),
false);
4666 for (
auto [idx,
result] : llvm::enumerate(op.getResults()))
4668 deadResults[idx] =
true;
4669 if (!deadResults.any())
4674 cast<scf::IndexSwitchOp>(rewriter.
eraseOpResults(op, deadResults));
4677 auto updateCaseRegion = [&](
Region ®ion) {
4679 assert(isa<YieldOp>(terminator) &&
"expected yield op");
4681 terminator, [&]() { terminator->
eraseOperands(deadResults); });
4683 updateCaseRegion(newOp.getDefaultRegion());
4684 for (
Region &caseRegion : newOp.getCaseRegions())
4685 updateCaseRegion(caseRegion);
4691void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4692 MLIRContext *context) {
4693 results.
add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
4700#define GET_OP_CLASSES
4701#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that 'values' has as many elements as the number of entries in 'attr' for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Canonicalization patterns that folds away dead results of "scf.index_switch" ops.
LogicalResult matchAndRewrite(IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.