30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
47struct SCFInlinerInterface :
public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
52 IRMapping &valueMapping)
const final {
57 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
62 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
67 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
78void SCFDialect::initialize() {
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
83 addInterfaces<SCFInlinerInterface>();
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
95 scf::YieldOp::create(builder, loc);
100template <
typename TerminatorTy>
102 StringRef errorMessage) {
103 Operation *terminatorOperation =
nullptr;
105 terminatorOperation = ®ion.
front().
back();
106 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
118 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
139 assert(region.
hasOneBlock() &&
"expected single-block region");
159ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
187LogicalResult ExecuteRegionOp::verify() {
188 if (getRegion().empty())
189 return emitOpError(
"region needs to have at least one block");
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError(
"region cannot have any arguments");
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
263 if (op.getNoInline())
265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
268 Block *prevBlock = op->getBlock();
272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
274 for (
Block &blk : op.getRegion()) {
275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278 yieldOp.getResults());
286 for (
auto res : op.getResults())
287 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
304 if (op.getNumResults() == 0)
308 for (
Block &block : op.getRegion()) {
309 if (
auto yield = dyn_cast<scf::YieldOp>(block.
getTerminator()))
310 yieldOps.push_back(yield.getOperation());
313 if (yieldOps.empty())
317 auto yieldOpsOperands = yieldOps[0]->getOperands();
318 for (
auto *yieldOp : yieldOps) {
319 if (yieldOp->getOperands() != yieldOpsOperands)
327 for (
auto [
index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
328 if (isValueFromInsideRegion(yieldedValue, op)) {
329 internalValues.push_back(yieldedValue);
330 opResultsToKeep.push_back(op.getResult(
index));
332 externalValues.push_back(yieldedValue);
333 opResultsToReplaceWithExternalValues.push_back(op.getResult(
index));
337 if (externalValues.empty())
343 for (
Value value : internalValues)
344 resultTypes.push_back(value.getType());
346 ExecuteRegionOp::create(rewriter, op.getLoc(),
TypeRange(resultTypes));
347 newOp->setAttrs(op->getAttrs());
351 newOp.getRegion().end());
355 for (
auto *yieldOp : yieldOps) {
372 bool isValueFromInsideRegion(
Value value,
373 ExecuteRegionOp executeRegionOp)
const {
376 return &executeRegionOp.getRegion() == defOp->getParentRegion();
380 return &executeRegionOp.getRegion() == blockArg.getParentRegion();
392void ExecuteRegionOp::getSuccessorRegions(
412 "condition op can only exit the loop or branch to the after"
415 return getArgsMutable();
418void ConditionOp::getSuccessorRegions(
420 FoldAdaptor adaptor(operands, *
this);
422 WhileOp whileOp = getParentOp();
426 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
427 if (!boolAttr || boolAttr.getValue())
428 regions.emplace_back(&whileOp.getAfter(),
429 whileOp.getAfter().getArguments());
430 if (!boolAttr || !boolAttr.getValue())
431 regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
440 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
444 result.addAttribute(getUnsignedCmpAttrName(
result.name),
447 result.addOperands(initArgs);
448 for (
Value v : initArgs)
449 result.addTypes(v.getType());
454 for (
Value v : initArgs)
460 if (initArgs.empty() && !bodyBuilder) {
461 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
462 }
else if (bodyBuilder) {
470LogicalResult ForOp::verify() {
472 if (getInitArgs().size() != getNumResults())
474 "mismatch in number of loop-carried values and defined values");
479LogicalResult ForOp::verifyRegions() {
484 "expected induction variable to be same type as bounds and step");
486 if (getNumRegionIterArgs() != getNumResults())
488 "mismatch in number of basic block args and defined values");
490 auto initArgs = getInitArgs();
491 auto iterArgs = getRegionIterArgs();
492 auto opResults = getResults();
494 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
496 return emitOpError() <<
"types mismatch between " << i
497 <<
"th iter operand and defined value";
499 return emitOpError() <<
"types mismatch between " << i
500 <<
"th iter region arg and defined value";
507std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
511std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
515std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
519std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
/// Returns this op's results, as produced by `getResults()`, wrapped in a
/// `std::optional<ResultRange>`. The returned optional always holds a value:
/// the function has a single return path that forwards `getResults()`.
523std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
527LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
528 std::optional<APInt> tripCount = getStaticTripCount();
529 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
532 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
535 if (*tripCount == 0) {
542 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
549 llvm::append_range(bbArgReplacements, getInitArgs());
553 getOperation()->getIterator(), bbArgReplacements);
569 StringRef prefix =
"") {
570 assert(blocksArgs.size() == initializers.size() &&
571 "expected same length of arguments and initializers");
572 if (initializers.empty())
576 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
577 p << std::get<0>(it) <<
" = " << std::get<1>(it);
583 if (getUnsignedCmp())
586 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
590 if (!getInitArgs().empty())
591 p <<
" -> (" << getInitArgs().getTypes() <<
')';
594 p <<
" : " << t <<
' ';
597 !getInitArgs().empty());
599 getUnsignedCmpAttrName().strref());
610 result.addAttribute(getUnsignedCmpAttrName(
result.name),
624 regionArgs.push_back(inductionVariable);
634 if (regionArgs.size() !=
result.types.size() + 1)
637 "mismatch in number of loop-carried values and defined values");
646 regionArgs.front().type = type;
647 for (
auto [iterArg, type] :
648 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
655 ForOp::ensureTerminator(*body, builder,
result.location);
664 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
665 operands,
result.types)) {
666 Type type = std::get<2>(argOperandType);
667 std::get<0>(argOperandType).type = type;
684 return getBody()->getArguments().drop_front(getNumInductionVars());
688 return getInitArgsMutable();
691FailureOr<LoopLikeOpInterface>
692ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
694 bool replaceInitOperandUsesInLoop,
699 auto inits = llvm::to_vector(getInitArgs());
700 inits.append(newInitOperands.begin(), newInitOperands.end());
701 scf::ForOp newLoop = scf::ForOp::create(
707 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
709 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
714 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
715 assert(newInitOperands.size() == newYieldedValues.size() &&
716 "expected as many new yield values as new iter operands");
718 yieldOp.getResultsMutable().append(newYieldedValues);
724 newLoop.getBody()->getArguments().take_front(
725 getBody()->getNumArguments()));
727 if (replaceInitOperandUsesInLoop) {
730 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
741 newLoop->getResults().take_front(getNumResults()));
742 return cast<LoopLikeOpInterface>(newLoop.getOperation());
746 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
749 assert(ivArg.getOwner() &&
"unlinked block argument");
750 auto *containingOp = ivArg.getOwner()->getParentOp();
751 return dyn_cast_or_null<ForOp>(containingOp);
755 return getInitArgs();
771LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
772 for (
auto [lb, ub, step] :
773 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
776 if (!tripCount.has_value() || *tripCount != 1)
785 return getBody()->getArguments().drop_front(getRank());
788MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
789 return getOutputsMutable();
795 scf::InParallelOp terminator = forallOp.getTerminator();
800 bbArgReplacements.append(forallOp.getOutputs().begin(),
801 forallOp.getOutputs().end());
805 forallOp->getIterator(), bbArgReplacements);
810 results.reserve(forallOp.getResults().size());
811 for (
auto &yieldingOp : terminator.getYieldingOps()) {
812 auto parallelInsertSliceOp =
813 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
814 if (!parallelInsertSliceOp)
817 Value dst = parallelInsertSliceOp.getDest();
818 Value src = parallelInsertSliceOp.getSource();
819 if (llvm::isa<TensorType>(src.
getType())) {
820 results.push_back(tensor::InsertSliceOp::create(
821 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
822 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
823 parallelInsertSliceOp.getStrides(),
824 parallelInsertSliceOp.getStaticOffsets(),
825 parallelInsertSliceOp.getStaticSizes(),
826 parallelInsertSliceOp.getStaticStrides()));
828 llvm_unreachable(
"unsupported terminator");
843 assert(lbs.size() == ubs.size() &&
844 "expected the same number of lower and upper bounds");
845 assert(lbs.size() == steps.size() &&
846 "expected the same number of lower bounds and steps");
851 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
853 assert(results.size() == iterArgs.size() &&
854 "loop nest body must return as many values as loop has iteration "
856 return LoopNest{{}, std::move(results)};
864 loops.reserve(lbs.size());
865 ivs.reserve(lbs.size());
868 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
869 auto loop = scf::ForOp::create(
870 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
876 currentIterArgs = args;
877 currentLoc = nestedLoc;
883 loops.push_back(loop);
887 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
889 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
896 ? bodyBuilder(builder, currentLoc, ivs,
897 loops.back().getRegionIterArgs())
899 assert(results.size() == iterArgs.size() &&
900 "loop nest body must return as many values as loop has iteration "
903 scf::YieldOp::create(builder, loc, results);
907 llvm::append_range(nestResults, loops.front().getResults());
908 return LoopNest{std::move(loops), std::move(nestResults)};
921 bodyBuilder(nestedBuilder, nestedLoc, ivs);
930 assert(operand.
getOwner() == forOp);
935 "expected an iter OpOperand");
937 "Expected a different type");
939 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
944 newIterOperands.push_back(opOperand.get());
948 scf::ForOp newForOp = scf::ForOp::create(
949 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
950 forOp.getStep(), newIterOperands,
nullptr,
951 forOp.getUnsignedCmp());
952 newForOp->setAttrs(forOp->getAttrs());
953 Block &newBlock = newForOp.getRegion().
front();
961 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
963 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
964 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
968 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
971 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
974 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
975 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
976 clonedYieldOp.getOperand(yieldIdx));
978 newYieldOperands[yieldIdx] = castOut;
979 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
980 rewriter.
eraseOp(clonedYieldOp);
985 newResults[yieldIdx] =
986 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
1006 LogicalResult matchAndRewrite(scf::ForOp forOp,
1008 bool canonicalize =
false;
1015 int64_t numResults = forOp.getNumResults();
1017 keepMask.reserve(numResults);
1020 newBlockTransferArgs.reserve(1 + numResults);
1021 newBlockTransferArgs.push_back(
Value());
1022 newIterArgs.reserve(forOp.getInitArgs().size());
1023 newYieldValues.reserve(numResults);
1024 newResultValues.reserve(numResults);
1026 for (
auto [init, arg,
result, yielded] :
1027 llvm::zip(forOp.getInitArgs(),
1028 forOp.getRegionIterArgs(),
1030 forOp.getYieldedValues()
1037 bool forwarded = (arg == yielded) || (init == yielded) ||
1038 (arg.use_empty() &&
result.use_empty());
1040 canonicalize =
true;
1041 keepMask.push_back(
false);
1042 newBlockTransferArgs.push_back(init);
1043 newResultValues.push_back(init);
1049 if (
auto it = initYieldToArg.find({init, yielded});
1050 it != initYieldToArg.end()) {
1051 canonicalize =
true;
1052 keepMask.push_back(
false);
1053 auto [sameArg, sameResult] = it->second;
1057 newBlockTransferArgs.push_back(init);
1058 newResultValues.push_back(init);
1063 initYieldToArg.insert({{init, yielded}, {arg,
result}});
1064 keepMask.push_back(
true);
1065 newIterArgs.push_back(init);
1066 newYieldValues.push_back(yielded);
1067 newBlockTransferArgs.push_back(Value());
1068 newResultValues.push_back(Value());
1074 scf::ForOp newForOp =
1075 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
1076 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
1077 nullptr, forOp.getUnsignedCmp());
1078 newForOp->setAttrs(forOp->getAttrs());
1079 Block &newBlock = newForOp.getRegion().front();
1082 newBlockTransferArgs[0] = newBlock.
getArgument(0);
1083 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
1085 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
1086 Value &newResultVal = newResultValues[idx];
1087 assert((blockTransferArg && newResultVal) ||
1088 (!blockTransferArg && !newResultVal));
1089 if (!blockTransferArg) {
1090 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
1091 newResultVal = newForOp.getResult(collapsedIdx++);
1095 Block &oldBlock = forOp.getRegion().front();
1097 "unexpected argument size mismatch");
1102 if (newIterArgs.empty()) {
1103 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1106 rewriter.
replaceOp(forOp, newResultValues);
1111 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1112 OpBuilder::InsertionGuard g(rewriter);
1114 SmallVector<Value, 4> filteredOperands;
1115 filteredOperands.reserve(newResultValues.size());
1116 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1118 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1119 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1123 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1124 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1125 cloneFilteredTerminator(mergedYieldOp);
1126 rewriter.
eraseOp(mergedYieldOp);
1127 rewriter.
replaceOp(forOp, newResultValues);
1135struct SimplifyTrivialLoops :
public OpRewritePattern<ForOp> {
1136 using OpRewritePattern<ForOp>::OpRewritePattern;
1138 LogicalResult matchAndRewrite(ForOp op,
1139 PatternRewriter &rewriter)
const override {
1140 std::optional<APInt> tripCount = op.getStaticTripCount();
1141 if (!tripCount.has_value())
1143 "can't compute constant trip count");
1145 if (tripCount->isZero()) {
1146 LDBG() <<
"SimplifyTrivialLoops tripCount is 0 for loop "
1147 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1148 rewriter.
replaceOp(op, op.getInitArgs());
1152 if (tripCount->getSExtValue() == 1) {
1153 LDBG() <<
"SimplifyTrivialLoops tripCount is 1 for loop "
1154 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1155 SmallVector<Value, 4> blockArgs;
1156 blockArgs.reserve(op.getInitArgs().size() + 1);
1157 blockArgs.push_back(op.getLowerBound());
1158 llvm::append_range(blockArgs, op.getInitArgs());
1165 if (!llvm::hasSingleElement(block))
1169 if (llvm::any_of(op.getYieldedValues(),
1170 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1172 LDBG() <<
"SimplifyTrivialLoops empty body loop allows replacement with "
1173 "yield operands for loop "
1174 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1175 rewriter.
replaceOp(op, op.getYieldedValues());
1206struct ForOpTensorCastFolder :
public OpRewritePattern<ForOp> {
1207 using OpRewritePattern<ForOp>::OpRewritePattern;
1209 LogicalResult matchAndRewrite(ForOp op,
1210 PatternRewriter &rewriter)
const override {
1211 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1212 OpOperand &iterOpOperand = std::get<0>(it);
1214 if (!incomingCast ||
1215 incomingCast.getSource().getType() == incomingCast.getType())
1220 incomingCast.getDest().getType(),
1221 incomingCast.getSource().getType()))
1223 if (!std::get<1>(it).hasOneUse())
1229 rewriter, op, iterOpOperand, incomingCast.getSource(),
1230 [](OpBuilder &
b, Location loc, Type type, Value source) {
1231 return tensor::CastOp::create(b, loc, type, source);
1241void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1242 MLIRContext *context) {
1243 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1247std::optional<APInt> ForOp::getConstantStep() {
1250 return step.getValue();
1254std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1255 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1261 if (
auto constantStep = getConstantStep())
1262 if (*constantStep == 1)
1270std::optional<APInt> ForOp::getStaticTripCount() {
1279LogicalResult ForallOp::verify() {
1280 unsigned numLoops = getRank();
1282 if (getNumResults() != getOutputs().size())
1284 << getNumResults() <<
" results, but has only "
1285 << getOutputs().size() <<
" outputs";
1288 auto *body = getBody();
1290 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1291 for (int64_t i = 0; i < numLoops; ++i)
1294 << i <<
"-th block argument to be an index";
1295 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1298 << i <<
"-th output and corresponding block argument";
1299 if (getMapping().has_value() && !getMapping()->empty()) {
1300 if (getDeviceMappingAttrs().size() != numLoops)
1301 return emitOpError() <<
"mapping attribute size must match op rank";
1302 if (
failed(getDeviceMaskingAttr()))
1304 <<
" supports at most one device masking attribute";
1308 Operation *op = getOperation();
1310 getStaticLowerBound(),
1311 getDynamicLowerBound())))
1314 getStaticUpperBound(),
1315 getDynamicUpperBound())))
1318 getStaticStep(), getDynamicStep())))
1324void ForallOp::print(OpAsmPrinter &p) {
1325 Operation *op = getOperation();
1326 p <<
" (" << getInductionVars();
1327 if (isNormalized()) {
1348 if (!getRegionOutArgs().empty())
1349 p <<
"-> (" << getResultTypes() <<
") ";
1350 p.printRegion(getRegion(),
1352 getNumResults() > 0);
1353 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1354 getStaticLowerBoundAttrName(),
1355 getStaticUpperBoundAttrName(),
1356 getStaticStepAttrName()});
1359ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1361 auto indexType =
b.getIndexType();
1366 SmallVector<OpAsmParser::Argument, 4> ivs;
1371 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1381 unsigned numLoops = ivs.size();
1382 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1383 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1412 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1413 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1416 if (outOperands.size() !=
result.types.size())
1418 "mismatch between out operands and types");
1427 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1428 std::unique_ptr<Region> region = std::make_unique<Region>();
1429 for (
auto &iv : ivs) {
1430 iv.type =
b.getIndexType();
1431 regionArgs.push_back(iv);
1433 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1434 auto &out = it.value();
1435 out.type =
result.types[it.index()];
1436 regionArgs.push_back(out);
1442 ForallOp::ensureTerminator(*region,
b,
result.location);
1443 result.addRegion(std::move(region));
1449 result.addAttribute(
"staticLowerBound", staticLbs);
1450 result.addAttribute(
"staticUpperBound", staticUbs);
1451 result.addAttribute(
"staticStep", staticSteps);
1452 result.addAttribute(
"operandSegmentSizes",
1454 {static_cast<int32_t>(dynamicLbs.size()),
1455 static_cast<int32_t>(dynamicUbs.size()),
1456 static_cast<int32_t>(dynamicSteps.size()),
1457 static_cast<int32_t>(outOperands.size())}));
1462void ForallOp::build(
1463 mlir::OpBuilder &
b, mlir::OperationState &
result,
1464 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1465 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1466 std::optional<ArrayAttr> mapping,
1468 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1469 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1474 result.addOperands(dynamicLbs);
1475 result.addOperands(dynamicUbs);
1476 result.addOperands(dynamicSteps);
1477 result.addOperands(outputs);
1480 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1481 b.getDenseI64ArrayAttr(staticLbs));
1482 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1483 b.getDenseI64ArrayAttr(staticUbs));
1484 result.addAttribute(getStaticStepAttrName(
result.name),
1485 b.getDenseI64ArrayAttr(staticSteps));
1487 "operandSegmentSizes",
1488 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1489 static_cast<int32_t>(dynamicUbs.size()),
1490 static_cast<int32_t>(dynamicSteps.size()),
1491 static_cast<int32_t>(outputs.size())}));
1492 if (mapping.has_value()) {
1493 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
1497 Region *bodyRegion =
result.addRegion();
1498 OpBuilder::InsertionGuard g(
b);
1499 b.createBlock(bodyRegion);
1504 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1505 SmallVector<Location>(staticLbs.size(),
result.location));
1508 SmallVector<Location>(outputs.size(),
result.location));
1510 b.setInsertionPointToStart(&bodyBlock);
1511 if (!bodyBuilderFn) {
1512 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
1519void ForallOp::build(
1520 mlir::OpBuilder &
b, mlir::OperationState &
result,
1521 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1522 std::optional<ArrayAttr> mapping,
1524 unsigned numLoops = ubs.size();
1525 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
1526 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1527 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1531bool ForallOp::isNormalized() {
1532 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1533 return llvm::all_of(results, [&](OpFoldResult ofr) {
1535 return intValue.has_value() && intValue == val;
1538 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1541InParallelOp ForallOp::getTerminator() {
1542 return cast<InParallelOp>(getBody()->getTerminator());
1545SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1546 SmallVector<Operation *> storeOps;
1547 for (Operation *user : bbArg.
getUsers()) {
1548 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1549 storeOps.push_back(parallelOp);
1555SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1556 SmallVector<DeviceMappingAttrInterface> res;
1559 for (
auto attr : getMapping()->getValue()) {
1560 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1567FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1568 DeviceMaskingAttrInterface res;
1571 for (
auto attr : getMapping()->getValue()) {
1572 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1581bool ForallOp::usesLinearMapping() {
1582 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1585 return ifaces.front().isLinearMapping();
1588std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1589 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1593std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1595 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
1599std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1601 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1605std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1611 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1614 assert(tidxArg.getOwner() &&
"unlinked block argument");
1615 auto *containingOp = tidxArg.getOwner()->getParentOp();
1616 return dyn_cast<ForallOp>(containingOp);
1624 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1626 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1630 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1633 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1638class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1640 using OpRewritePattern<ForallOp>::OpRewritePattern;
1642 LogicalResult matchAndRewrite(ForallOp op,
1643 PatternRewriter &rewriter)
const override {
1644 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1645 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1646 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1653 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1654 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1657 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1658 op.setStaticLowerBound(staticLowerBound);
1662 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1663 op.setStaticUpperBound(staticUpperBound);
1666 op.getDynamicStepMutable().assign(dynamicStep);
1667 op.setStaticStep(staticStep);
1669 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1671 {static_cast<int32_t>(dynamicLowerBound.size()),
1672 static_cast<int32_t>(dynamicUpperBound.size()),
1673 static_cast<int32_t>(dynamicStep.size()),
1674 static_cast<int32_t>(op.getNumResults())}));
1753struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1754 using OpRewritePattern<ForallOp>::OpRewritePattern;
1756 LogicalResult matchAndRewrite(ForallOp forallOp,
1757 PatternRewriter &rewriter)
const final {
1773 SmallVector<Value> resultToReplace;
1774 SmallVector<Value> newOuts;
1775 for (OpResult
result : forallOp.getResults()) {
1776 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1777 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1778 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1779 resultToDelete.insert(
result);
1781 resultToReplace.push_back(
result);
1782 newOuts.push_back(opOperand->
get());
1788 if (resultToDelete.empty())
1796 for (OpResult
result : resultToDelete) {
1797 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1798 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1799 SmallVector<Operation *> combiningOps =
1800 forallOp.getCombiningOps(blockArg);
1801 for (Operation *combiningOp : combiningOps)
1802 rewriter.
eraseOp(combiningOp);
1807 auto newForallOp = scf::ForallOp::create(
1808 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1809 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1810 forallOp.getMapping(),
1815 Block *loopBody = forallOp.getBody();
1816 Block *newLoopBody = newForallOp.getBody();
1817 ArrayRef<BlockArgument> newBbArgs = newLoopBody->
getArguments();
1820 SmallVector<Value> newBlockArgs =
1821 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1822 [](BlockArgument
b) -> Value { return b; });
1828 for (OpResult
result : forallOp.getResults()) {
1829 if (resultToDelete.count(
result)) {
1830 newBlockArgs.push_back(forallOp.getTiedOpOperand(
result)->get());
1832 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1835 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1839 for (
auto &&[oldResult, newResult] :
1840 llvm::zip(resultToReplace, newForallOp->getResults()))
1846 for (OpResult oldResult : resultToDelete)
1848 forallOp.getTiedOpOperand(oldResult)->get());
1853struct ForallOpSingleOrZeroIterationDimsFolder
1854 :
public OpRewritePattern<ForallOp> {
1855 using OpRewritePattern<ForallOp>::OpRewritePattern;
1857 LogicalResult matchAndRewrite(ForallOp op,
1858 PatternRewriter &rewriter)
const override {
1860 if (op.getMapping().has_value() && !op.getMapping()->empty())
1862 Location loc = op.getLoc();
1865 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1868 for (
auto [lb, ub, step, iv] :
1869 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1870 op.getMixedStep(), op.getInductionVars())) {
1871 auto numIterations =
1873 if (numIterations.has_value()) {
1875 if (*numIterations == 0) {
1876 rewriter.
replaceOp(op, op.getOutputs());
1881 if (*numIterations == 1) {
1886 newMixedLowerBounds.push_back(lb);
1887 newMixedUpperBounds.push_back(ub);
1888 newMixedSteps.push_back(step);
1892 if (newMixedLowerBounds.empty()) {
1898 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1900 op,
"no dimensions have 0 or 1 iterations");
1905 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1906 newMixedUpperBounds, newMixedSteps,
1907 op.getOutputs(), std::nullopt,
nullptr);
1908 newOp.getBodyRegion().getBlocks().clear();
1912 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1913 newOp.getStaticLowerBoundAttrName(),
1914 newOp.getStaticUpperBoundAttrName(),
1915 newOp.getStaticStepAttrName()};
1916 for (
const auto &namedAttr : op->getAttrs()) {
1917 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1920 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1924 newOp.getRegion().begin(), mapping);
1925 rewriter.
replaceOp(op, newOp.getResults());
1931struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1932 using OpRewritePattern<ForallOp>::OpRewritePattern;
1934 LogicalResult matchAndRewrite(ForallOp op,
1935 PatternRewriter &rewriter)
const override {
1936 Location loc = op.getLoc();
1938 for (
auto [lb, ub, step, iv] :
1939 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1940 op.getMixedStep(), op.getInductionVars())) {
1943 auto numIterations =
1945 if (!numIterations.has_value() || numIterations.value() != 1) {
1956struct FoldTensorCastOfOutputIntoForallOp
1957 :
public OpRewritePattern<scf::ForallOp> {
1958 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1965 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1966 PatternRewriter &rewriter)
const final {
1967 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1968 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1969 for (
auto en : llvm::enumerate(newOutputTensors)) {
1970 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1977 castOp.getSource().getType())) {
1981 tensorCastProducers[en.index()] =
1982 TypeCast{castOp.getSource().getType(), castOp.getType()};
1983 newOutputTensors[en.index()] = castOp.getSource();
1986 if (tensorCastProducers.empty())
1990 Location loc = forallOp.getLoc();
1991 auto newForallOp = ForallOp::create(
1992 rewriter, loc, forallOp.getMixedLowerBound(),
1993 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1994 newOutputTensors, forallOp.getMapping(),
1995 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
1996 auto castBlockArgs =
1997 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1998 for (auto [index, cast] : tensorCastProducers) {
1999 Value &oldTypeBBArg = castBlockArgs[index];
2000 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
2001 cast.dstType, oldTypeBBArg);
2005 SmallVector<Value> ivsBlockArgs =
2006 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
2007 ivsBlockArgs.append(castBlockArgs);
2009 bbArgs.front().getParentBlock(), ivsBlockArgs);
2015 auto terminator = newForallOp.getTerminator();
2016 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
2017 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
2018 if (
auto parallelCombingingOp =
2019 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
2020 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
2026 SmallVector<Value> castResults = newForallOp.getResults();
2027 for (
auto &item : tensorCastProducers) {
2028 Value &oldTypeResult = castResults[item.first];
2029 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
2032 rewriter.
replaceOp(forallOp, castResults);
/// Adds the ForallOp canonicalization patterns to `results`, each constructed
/// with `context`: DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
/// ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
/// ForallOpSingleOrZeroIterationDimsFolder and
/// ForallOpReplaceConstantInductionVar.
2039void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
2040 MLIRContext *context) {
2041 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
2042 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
2043 ForallOpSingleOrZeroIterationDimsFolder,
2044 ForallOpReplaceConstantInductionVar>(context);
2052void ForallOp::getSuccessorRegions(RegionBranchPoint point,
2053 SmallVectorImpl<RegionSuccessor> ®ions) {
2058 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2061 RegionSuccessor(getOperation(), getOperation()->getResults()));
/// Builds an InParallelOp: adds one region to `result` and creates a block
/// inside that region using builder `b`.
2069void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
// The guard is scoped to this function so that the block creation below does
// not leave the caller's builder positioned inside the new block.
2070 OpBuilder::InsertionGuard g(
b);
2071 Region *bodyRegion =
result.addRegion();
2072 b.createBlock(bodyRegion);
2075LogicalResult InParallelOp::verify() {
2076 scf::ForallOp forallOp =
2077 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
2079 return this->
emitOpError(
"expected forall op parent");
2081 for (Operation &op : getRegion().front().getOperations()) {
2082 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
2083 if (!parallelCombiningOp) {
2084 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
2089 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
2090 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
2091 for (OpOperand &dest : dests) {
2092 if (!llvm::is_contained(regionOutArgs, dest.get()))
2093 return op.emitOpError(
"may only insert into an output block argument");
2100void InParallelOp::print(OpAsmPrinter &p) {
2108ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2111 SmallVector<OpAsmParser::Argument, 8> regionOperands;
2112 std::unique_ptr<Region> region = std::make_unique<Region>();
2116 if (region->empty())
2117 OpBuilder(builder.
getContext()).createBlock(region.get());
2118 result.addRegion(std::move(region));
/// Returns result number `idx` of the operation that directly encloses this
/// op (its parent op).
2126OpResult InParallelOp::getParentResult(int64_t idx) {
2127 return getOperation()->getParentOp()->getResult(idx);
/// Collects, in iteration order over getYieldingOps(), every updated
/// destination of the yielding ops that implement
/// ParallelCombiningOpInterface. Each destination value is cast to a
/// BlockArgument, so the cast requires it to be one.
2130SmallVector<BlockArgument> InParallelOp::getDests() {
2131 SmallVector<BlockArgument> updatedDests;
2132 for (Operation &yieldingOp : getYieldingOps()) {
2133 auto parallelCombiningOp =
2134 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
// NOTE(review): the statement guarded by this check is not visible in this
// excerpt; presumably ops without the interface are skipped -- confirm.
2135 if (!parallelCombiningOp)
2137 for (OpOperand &updatedOperand :
2138 parallelCombiningOp.getUpdatedDestinations())
2139 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2141 return updatedDests;
/// Returns the operations contained in the first block of this op's region.
2144llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2145 return getRegion().front().getOperations();
2153 assert(a &&
"expected non-empty operation");
2154 assert(
b &&
"expected non-empty operation");
2159 if (ifOp->isProperAncestor(
b))
2162 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2163 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
2165 ifOp = ifOp->getParentOfType<IfOp>();
2173IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2174 IfOp::Adaptor adaptor,
2176 if (adaptor.getRegions().empty())
2178 Region *r = &adaptor.getThenRegion();
2184 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
2187 TypeRange types = yieldOp.getOperandTypes();
2188 llvm::append_range(inferredReturnTypes, types);
2194 return build(builder,
result, resultTypes, cond,
false,
2198void IfOp::build(OpBuilder &builder, OperationState &
result,
2199 TypeRange resultTypes, Value cond,
bool addThenBlock,
2200 bool addElseBlock) {
2201 assert((!addElseBlock || addThenBlock) &&
2202 "must not create else block w/o then block");
2203 result.addTypes(resultTypes);
2204 result.addOperands(cond);
2207 OpBuilder::InsertionGuard guard(builder);
2208 Region *thenRegion =
result.addRegion();
2211 Region *elseRegion =
result.addRegion();
2216void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2217 bool withElseRegion) {
2221void IfOp::build(OpBuilder &builder, OperationState &
result,
2222 TypeRange resultTypes, Value cond,
bool withElseRegion) {
2223 result.addTypes(resultTypes);
2224 result.addOperands(cond);
2227 OpBuilder::InsertionGuard guard(builder);
2228 Region *thenRegion =
result.addRegion();
2230 if (resultTypes.empty())
2231 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
2234 Region *elseRegion =
result.addRegion();
2235 if (withElseRegion) {
2237 if (resultTypes.empty())
2238 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
2242void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2244 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
2245 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2246 result.addOperands(cond);
2249 OpBuilder::InsertionGuard guard(builder);
2250 Region *thenRegion =
result.addRegion();
2252 thenBuilder(builder,
result.location);
2255 Region *elseRegion =
result.addRegion();
2258 elseBuilder(builder,
result.location);
2262 SmallVector<Type> inferredReturnTypes;
2264 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
2265 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2267 inferredReturnTypes))) {
2268 result.addTypes(inferredReturnTypes);
/// Verifier for IfOp. An op that defines at least one result but whose else
/// region is empty is rejected with an op error.
2272LogicalResult IfOp::verify() {
2273 if (getNumResults() != 0 && getElseRegion().empty())
2274 return emitOpError(
"must have an else block if defining values");
2278ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
2280 result.regions.reserve(2);
2281 Region *thenRegion =
result.addRegion();
2282 Region *elseRegion =
result.addRegion();
2285 OpAsmParser::UnresolvedOperand cond;
2311void IfOp::print(OpAsmPrinter &p) {
2312 bool printBlockTerminators =
false;
2314 p <<
" " << getCondition();
2315 if (!getResults().empty()) {
2316 p <<
" -> (" << getResultTypes() <<
")";
2318 printBlockTerminators =
true;
2323 printBlockTerminators);
2326 auto &elseRegion = getElseRegion();
2327 if (!elseRegion.
empty()) {
2331 printBlockTerminators);
2337void IfOp::getSuccessorRegions(RegionBranchPoint point,
2338 SmallVectorImpl<RegionSuccessor> ®ions) {
2342 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2346 regions.push_back(RegionSuccessor(&getThenRegion()));
2349 Region *elseRegion = &this->getElseRegion();
2350 if (elseRegion->
empty())
2352 RegionSuccessor(getOperation(), getOperation()->getResults()));
2354 regions.push_back(RegionSuccessor(elseRegion));
2357void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2358 SmallVectorImpl<RegionSuccessor> ®ions) {
2359 FoldAdaptor adaptor(operands, *
this);
2360 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2361 if (!boolAttr || boolAttr.getValue())
2362 regions.emplace_back(&getThenRegion());
2365 if (!boolAttr || !boolAttr.getValue()) {
2366 if (!getElseRegion().empty())
2367 regions.emplace_back(&getElseRegion());
2369 regions.emplace_back(getOperation(), getResults());
2373LogicalResult IfOp::fold(FoldAdaptor adaptor,
2374 SmallVectorImpl<OpFoldResult> &results) {
2376 if (getElseRegion().empty())
2379 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2386 getConditionMutable().assign(xorStmt.getLhs());
2387 Block *thenBlock = &getThenRegion().front();
2390 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2391 getElseRegion().getBlocks());
2392 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2393 getThenRegion().getBlocks(), thenBlock);
/// Fills `invocationBounds` with one (lower, upper) invocation bound per
/// region, based on operand 0 (the entry of `operands` that is tested for
/// being a BoolAttr).
2397void IfOp::getRegionInvocationBounds(
2398 ArrayRef<Attribute> operands,
2399 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2400 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
// Constant condition: the first entry allows one invocation only when the
// condition is true, the second only when it is false.
2403 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2404 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
// NOTE(review): the branch structure around this line is not visible here;
// presumably this is the non-constant fallback giving both entries the
// bounds {0, 1} -- confirm.
2407 invocationBounds.assign(2, {0, 1});
2413struct RemoveUnusedResults :
public OpRewritePattern<IfOp> {
2414 using OpRewritePattern<IfOp>::OpRewritePattern;
2416 void transferBody(
Block *source,
Block *dest, ArrayRef<OpResult> usedResults,
2417 PatternRewriter &rewriter)
const {
2422 SmallVector<Value, 4> usedOperands;
2423 llvm::transform(usedResults, std::back_inserter(usedOperands),
2425 return yieldOp.getOperand(
result.getResultNumber());
2428 [&]() { yieldOp->setOperands(usedOperands); });
2431 LogicalResult matchAndRewrite(IfOp op,
2432 PatternRewriter &rewriter)
const override {
2434 SmallVector<OpResult, 4> usedResults;
2435 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2436 [](OpResult
result) { return !result.use_empty(); });
2439 if (usedResults.size() == op.getNumResults())
2443 SmallVector<Type, 4> newTypes;
2444 llvm::transform(usedResults, std::back_inserter(newTypes),
2449 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2455 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2456 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2459 SmallVector<Value, 4> repResults(op.getNumResults());
2460 for (
const auto &en : llvm::enumerate(usedResults))
2461 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2467struct RemoveStaticCondition :
public OpRewritePattern<IfOp> {
2468 using OpRewritePattern<IfOp>::OpRewritePattern;
2470 LogicalResult matchAndRewrite(IfOp op,
2471 PatternRewriter &rewriter)
const override {
2478 else if (!op.getElseRegion().empty())
2489struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2490 using OpRewritePattern<IfOp>::OpRewritePattern;
2492 LogicalResult matchAndRewrite(IfOp op,
2493 PatternRewriter &rewriter)
const override {
2494 if (op->getNumResults() == 0)
2497 auto cond = op.getCondition();
2498 auto thenYieldArgs = op.thenYield().getOperands();
2499 auto elseYieldArgs = op.elseYield().getOperands();
2501 SmallVector<Type> nonHoistable;
2502 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2503 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2504 &op.getElseRegion() == falseVal.getParentRegion())
2505 nonHoistable.push_back(trueVal.getType());
2509 if (nonHoistable.size() == op->getNumResults())
2512 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2516 replacement.getThenRegion().takeBody(op.getThenRegion());
2517 replacement.getElseRegion().takeBody(op.getElseRegion());
2519 SmallVector<Value> results(op->getNumResults());
2520 assert(thenYieldArgs.size() == results.size());
2521 assert(elseYieldArgs.size() == results.size());
2523 SmallVector<Value> trueYields;
2524 SmallVector<Value> falseYields;
2526 for (
const auto &it :
2527 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2528 Value trueVal = std::get<0>(it.value());
2529 Value falseVal = std::get<1>(it.value());
2532 results[it.index()] =
replacement.getResult(trueYields.size());
2533 trueYields.push_back(trueVal);
2534 falseYields.push_back(falseVal);
2535 }
else if (trueVal == falseVal)
2536 results[it.index()] = trueVal;
2538 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2539 cond, trueVal, falseVal);
2566struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2567 using OpRewritePattern<IfOp>::OpRewritePattern;
2570 enum class Parent { Then, Else,
None };
2575 static Parent getParentType(Region *toCheck, IfOp op,
2577 Region *endRegion) {
2578 SmallVector<Region *> seen;
2579 while (toCheck != endRegion) {
2580 auto found = cache.find(toCheck);
2581 if (found != cache.end())
2582 return found->second;
2583 seen.push_back(toCheck);
2584 if (&op.getThenRegion() == toCheck) {
2585 for (Region *region : seen)
2586 cache[region] = Parent::Then;
2587 return Parent::Then;
2589 if (&op.getElseRegion() == toCheck) {
2590 for (Region *region : seen)
2591 cache[region] = Parent::Else;
2592 return Parent::Else;
2597 for (Region *region : seen)
2598 cache[region] = Parent::None;
2599 return Parent::None;
2602 LogicalResult matchAndRewrite(IfOp op,
2603 PatternRewriter &rewriter)
const override {
2614 Value constantTrue =
nullptr;
2615 Value constantFalse =
nullptr;
2618 for (OpOperand &use :
2619 llvm::make_early_inc_range(op.getCondition().getUses())) {
2622 case Parent::Then: {
2626 constantTrue = arith::ConstantOp::create(
2630 [&]() { use.set(constantTrue); });
2633 case Parent::Else: {
2637 constantFalse = arith::ConstantOp::create(
2641 [&]() { use.set(constantFalse); });
2689struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2690 using OpRewritePattern<IfOp>::OpRewritePattern;
2692 LogicalResult matchAndRewrite(IfOp op,
2693 PatternRewriter &rewriter)
const override {
2695 if (op.getNumResults() == 0)
2699 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2701 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2704 op.getOperation()->getIterator());
2707 for (
auto [trueResult, falseResult, opResult] :
2708 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2710 if (trueResult == falseResult) {
2711 if (!opResult.use_empty()) {
2712 opResult.replaceAllUsesWith(trueResult);
2718 BoolAttr trueYield, falseYield;
2723 bool trueVal = trueYield.
getValue();
2724 bool falseVal = falseYield.
getValue();
2725 if (!trueVal && falseVal) {
2726 if (!opResult.use_empty()) {
2727 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2728 Value notCond = arith::XOrIOp::create(
2729 rewriter, op.getLoc(), op.getCondition(),
2735 opResult.replaceAllUsesWith(notCond);
2739 if (trueVal && !falseVal) {
2740 if (!opResult.use_empty()) {
2741 opResult.replaceAllUsesWith(op.getCondition());
2771struct CombineIfs :
public OpRewritePattern<IfOp> {
2772 using OpRewritePattern<IfOp>::OpRewritePattern;
2774 LogicalResult matchAndRewrite(IfOp nextIf,
2775 PatternRewriter &rewriter)
const override {
2776 Block *parent = nextIf->getBlock();
2777 if (nextIf == &parent->
front())
2780 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2788 Block *nextThen =
nullptr;
2789 Block *nextElse =
nullptr;
2790 if (nextIf.getCondition() == prevIf.getCondition()) {
2791 nextThen = nextIf.thenBlock();
2792 if (!nextIf.getElseRegion().empty())
2793 nextElse = nextIf.elseBlock();
2795 if (arith::XOrIOp notv =
2796 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2797 if (notv.getLhs() == prevIf.getCondition() &&
2799 nextElse = nextIf.thenBlock();
2800 if (!nextIf.getElseRegion().empty())
2801 nextThen = nextIf.elseBlock();
2804 if (arith::XOrIOp notv =
2805 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2806 if (notv.getLhs() == nextIf.getCondition() &&
2808 nextElse = nextIf.thenBlock();
2809 if (!nextIf.getElseRegion().empty())
2810 nextThen = nextIf.elseBlock();
2814 if (!nextThen && !nextElse)
2817 SmallVector<Value> prevElseYielded;
2818 if (!prevIf.getElseRegion().empty())
2819 prevElseYielded = prevIf.elseYield().getOperands();
2822 for (
auto it : llvm::zip(prevIf.getResults(),
2823 prevIf.thenYield().getOperands(), prevElseYielded))
2824 for (OpOperand &use :
2825 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2829 use.
set(std::get<1>(it));
2834 use.
set(std::get<2>(it));
2839 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2840 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2842 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2843 prevIf.getCondition(),
false);
2844 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2847 combinedIf.getThenRegion(),
2848 combinedIf.getThenRegion().begin());
2851 YieldOp thenYield = combinedIf.thenYield();
2852 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2853 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2856 SmallVector<Value> mergedYields(thenYield.getOperands());
2857 llvm::append_range(mergedYields, thenYield2.getOperands());
2858 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2864 combinedIf.getElseRegion(),
2865 combinedIf.getElseRegion().begin());
2868 if (combinedIf.getElseRegion().empty()) {
2870 combinedIf.getElseRegion(),
2871 combinedIf.getElseRegion().
begin());
2873 YieldOp elseYield = combinedIf.elseYield();
2874 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2875 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2879 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2880 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2882 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2888 SmallVector<Value> prevValues;
2889 SmallVector<Value> nextValues;
2890 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2891 if (pair.index() < prevIf.getNumResults())
2892 prevValues.push_back(pair.value());
2894 nextValues.push_back(pair.value());
2903struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2904 using OpRewritePattern<IfOp>::OpRewritePattern;
2906 LogicalResult matchAndRewrite(IfOp ifOp,
2907 PatternRewriter &rewriter)
const override {
2909 if (ifOp.getNumResults())
2911 Block *elseBlock = ifOp.elseBlock();
2912 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2916 newIfOp.getThenRegion().begin());
2938struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2939 using OpRewritePattern<IfOp>::OpRewritePattern;
2941 LogicalResult matchAndRewrite(IfOp op,
2942 PatternRewriter &rewriter)
const override {
2943 auto nestedOps = op.thenBlock()->without_terminator();
2945 if (!llvm::hasSingleElement(nestedOps))
2949 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2952 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2956 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2959 SmallVector<Value> thenYield(op.thenYield().getOperands());
2960 SmallVector<Value> elseYield;
2962 llvm::append_range(elseYield, op.elseYield().getOperands());
2966 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2975 for (
const auto &tup : llvm::enumerate(thenYield)) {
2976 if (tup.value().getDefiningOp() == nestedIf) {
2977 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2978 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2979 elseYield[tup.index()]) {
2984 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2997 if (tup.value().getParentRegion() == &op.getThenRegion()) {
3000 elseYieldsToUpgradeToSelect.push_back(tup.index());
3003 Location loc = op.getLoc();
3004 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
3005 nestedIf.getCondition());
3006 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
3009 SmallVector<Value> results;
3010 llvm::append_range(results, newIf.getResults());
3013 for (
auto idx : elseYieldsToUpgradeToSelect)
3015 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
3016 thenYield[idx], elseYield[idx]);
3018 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
3021 if (!elseYield.empty()) {
3024 YieldOp::create(rewriter, loc, elseYield);
/// Adds the IfOp canonicalization patterns to `results`, each constructed
/// with `context`: CombineIfs, CombineNestedIfs, ConditionPropagation,
/// ConvertTrivialIfToSelect, RemoveEmptyElseBranch, RemoveStaticCondition,
/// RemoveUnusedResults and ReplaceIfYieldWithConditionOrValue.
3033void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3034 MLIRContext *context) {
3035 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
3036 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
3037 RemoveStaticCondition, RemoveUnusedResults,
3038 ReplaceIfYieldWithConditionOrValue>(context);
3041Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
3042YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
3043Block *IfOp::elseBlock() {
3044 Region &r = getElseRegion();
3049YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
3055void ParallelOp::build(
3060 result.addOperands(lowerBounds);
3061 result.addOperands(upperBounds);
3062 result.addOperands(steps);
3063 result.addOperands(initVals);
3065 ParallelOp::getOperandSegmentSizeAttr(),
3067 static_cast<int32_t>(upperBounds.size()),
3068 static_cast<int32_t>(steps.size()),
3069 static_cast<int32_t>(initVals.size())}));
3072 OpBuilder::InsertionGuard guard(builder);
3073 unsigned numIVs = steps.size();
3074 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
3075 SmallVector<Location, 8> argLocs(numIVs,
result.location);
3076 Region *bodyRegion =
result.addRegion();
3079 if (bodyBuilderFn) {
3081 bodyBuilderFn(builder,
result.location,
3086 if (initVals.empty())
3087 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
3090void ParallelOp::build(
3097 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
3100 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3104 wrapper = wrappedBuilderFn;
3110LogicalResult ParallelOp::verify() {
3115 if (stepValues.empty())
3117 "needs at least one tuple element for lowerBound, upperBound and step");
3120 for (Value stepValue : stepValues)
3123 return emitOpError(
"constant step operand must be positive");
3127 Block *body = getBody();
3129 return emitOpError() <<
"expects the same number of induction variables: "
3131 <<
" as bound and step values: " << stepValues.size();
3133 if (!arg.getType().isIndex())
3135 "expects arguments for the induction variable to be of index type");
3139 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
3144 auto resultsSize = getResults().size();
3145 auto reductionsSize = reduceOp.getReductions().size();
3146 auto initValsSize = getInitVals().size();
3147 if (resultsSize != reductionsSize)
3148 return emitOpError() <<
"expects number of results: " << resultsSize
3149 <<
" to be the same as number of reductions: "
3151 if (resultsSize != initValsSize)
3152 return emitOpError() <<
"expects number of results: " << resultsSize
3153 <<
" to be the same as number of initial values: "
3155 if (reduceOp.getNumOperands() != initValsSize)
3160 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3161 auto resultType = getOperation()->getResult(i).getType();
3162 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3163 if (resultType != reductionOperandType)
3164 return reduceOp.emitOpError()
3165 <<
"expects type of " << i
3166 <<
"-th reduction operand: " << reductionOperandType
3167 <<
" to be the same as the " << i
3168 <<
"-th result type: " << resultType;
3173ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
3176 SmallVector<OpAsmParser::Argument, 4> ivs;
3181 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3188 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3196 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3204 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3215 Region *body =
result.addRegion();
3216 for (
auto &iv : ivs)
3223 ParallelOp::getOperandSegmentSizeAttr(),
3225 static_cast<int32_t>(upper.size()),
3226 static_cast<int32_t>(steps.size()),
3227 static_cast<int32_t>(initVals.size())}));
3236 ParallelOp::ensureTerminator(*body, builder,
result.location);
3240void ParallelOp::print(OpAsmPrinter &p) {
3241 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3242 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3243 if (!getInitVals().empty())
3244 p <<
" init (" << getInitVals() <<
")";
3249 (*this)->getAttrs(),
3250 ParallelOp::getOperandSegmentSizeAttr());
3253SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
/// Returns the loop induction variables: all block arguments of the body
/// block, copied into a SmallVector<Value>.
3255std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3256 return SmallVector<Value>{getBody()->getArguments()};
3259std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3263std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3267std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3272 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3274 return ParallelOp();
3275 assert(ivArg.getOwner() &&
"unlinked block argument");
3276 auto *containingOp = ivArg.getOwner()->getParentOp();
3277 return dyn_cast<ParallelOp>(containingOp);
3282struct ParallelOpSingleOrZeroIterationDimsFolder
3286 LogicalResult matchAndRewrite(ParallelOp op,
3293 for (
auto [lb,
ub, step, iv] :
3294 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3295 op.getInductionVars())) {
3296 auto numIterations =
3298 if (numIterations.has_value()) {
3300 if (*numIterations == 0) {
3301 rewriter.
replaceOp(op, op.getInitVals());
3306 if (*numIterations == 1) {
3311 newLowerBounds.push_back(lb);
3312 newUpperBounds.push_back(ub);
3313 newSteps.push_back(step);
3316 if (newLowerBounds.size() == op.getLowerBound().size())
3319 if (newLowerBounds.empty()) {
3322 SmallVector<Value> results;
3323 results.reserve(op.getInitVals().size());
3324 for (
auto &bodyOp : op.getBody()->without_terminator())
3325 rewriter.
clone(bodyOp, mapping);
3326 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3327 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3328 Block &reduceBlock = reduceOp.getReductions()[i].front();
3329 auto initValIndex = results.size();
3330 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3334 rewriter.
clone(reduceBodyOp, mapping);
3337 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3338 results.push_back(
result);
3346 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3347 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3353 newOp.getRegion().begin(), mapping);
3354 rewriter.
replaceOp(op, newOp.getResults());
3359struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3360 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3362 LogicalResult matchAndRewrite(ParallelOp op,
3363 PatternRewriter &rewriter)
const override {
3364 Block &outerBody = *op.getBody();
3368 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3373 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3374 llvm::is_contained(innerOp.getUpperBound(), val) ||
3375 llvm::is_contained(innerOp.getStep(), val))
3379 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3382 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3384 Block &innerBody = *innerOp.getBody();
3385 assert(iterVals.size() ==
3393 builder.
clone(op, mapping);
3396 auto concatValues = [](
const auto &first,
const auto &second) {
3397 SmallVector<Value> ret;
3398 ret.reserve(first.size() + second.size());
3399 ret.assign(first.begin(), first.end());
3400 ret.append(second.begin(), second.end());
3404 auto newLowerBounds =
3405 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3406 auto newUpperBounds =
3407 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3408 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3419void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3420 MLIRContext *context) {
3422 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3431void ParallelOp::getSuccessorRegions(
3432 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
3436 regions.push_back(RegionSuccessor(&getRegion()));
3437 regions.push_back(RegionSuccessor(
3438 getOperation(), ResultRange{getResults().end(), getResults().end()}));
3445void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
3447void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3449 result.addOperands(operands);
3450 for (Value v : operands) {
3451 OpBuilder::InsertionGuard guard(builder);
3452 Region *bodyRegion =
result.addRegion();
3459LogicalResult ReduceOp::verifyRegions() {
3460 if (getReductions().size() != getOperands().size())
3461 return emitOpError() <<
"expects number of reduction regions: "
3462 << getReductions().size()
3463 <<
" to be the same as number of reduction operands: "
3464 << getOperands().size();
3467 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3468 auto type = getOperands()[i].getType();
3471 return emitOpError() << i <<
"-th reduction has an empty body";
3473 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3474 return arg.getType() != type;
3476 return emitOpError() <<
"expected two block arguments with type " << type
3477 <<
" in the " << i <<
"-th reduction region";
3481 return emitOpError(
"reduction bodies must be terminated with an "
3482 "'scf.reduce.return' op");
3489ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3491 return MutableOperandRange(getOperation(), 0, 0);
3498LogicalResult ReduceReturnOp::verify() {
3501 Block *reductionBody = getOperation()->getBlock();
3503 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3505 if (expectedResultType != getResult().
getType())
3506 return emitOpError() <<
"must have type " << expectedResultType
3507 <<
" (the type of the reduction inputs)";
3515void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3516 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3517 ValueRange inits, BodyBuilderFn beforeBuilder,
3518 BodyBuilderFn afterBuilder) {
3522 OpBuilder::InsertionGuard guard(odsBuilder);
3525 SmallVector<Location, 4> beforeArgLocs;
3526 beforeArgLocs.reserve(inits.size());
3527 for (Value operand : inits) {
3528 beforeArgLocs.push_back(operand.getLoc());
3531 Region *beforeRegion = odsState.
addRegion();
3533 inits.getTypes(), beforeArgLocs);
3538 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3540 Region *afterRegion = odsState.
addRegion();
3542 resultTypes, afterArgLocs);
/// Returns the terminator of the 'before' body as a ConditionOp. The cast
/// requires that terminator to be a ConditionOp.
3548ConditionOp WhileOp::getConditionOp() {
3549 return cast<ConditionOp>(getBeforeBody()->getTerminator());
/// Returns the terminator of the 'after' body as a YieldOp. The cast requires
/// that terminator to be a YieldOp.
3552YieldOp WhileOp::getYieldOp() {
3553 return cast<YieldOp>(getAfterBody()->getTerminator());
/// Returns the mutable results operands of the op returned by getYieldOp().
3556std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3557 return getYieldOp().getResultsMutable();
3561 return getBeforeBody()->getArguments();
3565 return getAfterBody()->getArguments();
3569 return getBeforeArguments();
3572OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3574 "WhileOp is expected to branch only to the first region");
3578void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3579 SmallVectorImpl<RegionSuccessor> ®ions) {
3582 regions.emplace_back(&getBefore(), getBefore().getArguments());
3586 assert(llvm::is_contained(
3587 {&getAfter(), &getBefore()},
3589 "there are only two regions in a WhileOp");
3593 regions.emplace_back(&getBefore(), getBefore().getArguments());
3597 regions.emplace_back(getOperation(), getResults());
3598 regions.emplace_back(&getAfter(), getAfter().getArguments());
/// Returns the loop regions of this op: the 'before' region followed by the
/// 'after' region.
3601SmallVector<Region *> WhileOp::getLoopRegions() {
3602 return {&getBefore(), &getAfter()};
3612ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3613 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3614 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3615 Region *before =
result.addRegion();
3616 Region *after =
result.addRegion();
3618 OptionalParseResult listResult =
3623 FunctionType functionType;
3628 result.addTypes(functionType.getResults());
3630 if (functionType.getNumInputs() != operands.size()) {
3632 <<
"expected as many input types as operands " <<
"(expected "
3633 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3643 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3644 regionArgs[i].type = functionType.getInput(i);
3646 return failure(parser.
parseRegion(*before, regionArgs) ||
3652void scf::WhileOp::print(OpAsmPrinter &p) {
3666template <
typename OpTy>
3669 if (left.size() != right.size())
3670 return op.emitOpError(
"expects the same number of ") << message;
3672 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3673 if (left[i] != right[i]) {
3676 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3677 <<
" and " << right[i];
3685LogicalResult scf::WhileOp::verify() {
3688 "expects the 'before' region to terminate with 'scf.condition'");
3689 if (!beforeTerminator)
3694 "expects the 'after' region to terminate with 'scf.yield'");
3695 return success(afterTerminator !=
nullptr);
3732struct WhileMoveIfDown :
public OpRewritePattern<scf::WhileOp> {
3733 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3735 LogicalResult matchAndRewrite(scf::WhileOp op,
3736 PatternRewriter &rewriter)
const override {
3737 auto conditionOp = op.getConditionOp();
3745 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3751 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3752 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3755 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3756 *ifOp->user_begin() == conditionOp)) &&
3757 "ifOp has unexpected uses");
3759 Location loc = op.getLoc();
3763 for (
auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3764 auto it = llvm::find(ifOp->getResults(), arg);
3765 if (it != ifOp->getResults().end()) {
3766 size_t ifOpIdx = it.getIndex();
3767 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3768 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3778 if (&op.getBefore() == operand->get().getParentRegion())
3779 additionalUsedValuesSet.insert(operand->get());
3783 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3784 auto additionalValueTypes = llvm::map_to_vector(
3785 additionalUsedValues, [](Value val) {
return val.
getType(); });
3786 size_t additionalValueSize = additionalUsedValues.size();
3787 SmallVector<Type> newResultTypes(op.getResultTypes());
3788 newResultTypes.append(additionalValueTypes);
3791 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3794 newWhileOp.getBefore().takeBody(op.getBefore());
3795 newWhileOp.getAfter().takeBody(op.getAfter());
3796 newWhileOp.getAfter().addArguments(
3797 additionalValueTypes,
3798 SmallVector<Location>(additionalValueSize, loc));
3802 conditionOp.getArgsMutable().append(additionalUsedValues);
3808 additionalUsedValues,
3809 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3810 [&](OpOperand &use) {
3811 return ifOp.getThenRegion().isAncestor(
3812 use.getOwner()->getParentRegion());
3816 rewriter.
eraseOp(ifOp.thenYield());
3818 newWhileOp.getAfterBody()->begin());
3821 newWhileOp->getResults().drop_back(additionalValueSize));
3845struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3846 using OpRewritePattern<WhileOp>::OpRewritePattern;
3848 LogicalResult matchAndRewrite(WhileOp op,
3849 PatternRewriter &rewriter)
const override {
3850 auto term = op.getConditionOp();
3854 Value constantTrue =
nullptr;
3856 bool replaced =
false;
3857 for (
auto yieldedAndBlockArgs :
3858 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3859 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3860 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3862 constantTrue = arith::ConstantOp::create(
3863 rewriter, op.getLoc(), term.getCondition().getType(),
3924struct RemoveLoopInvariantArgsFromBeforeBlock
3925 :
public OpRewritePattern<WhileOp> {
3926 using OpRewritePattern<WhileOp>::OpRewritePattern;
3928 LogicalResult matchAndRewrite(WhileOp op,
3929 PatternRewriter &rewriter)
const override {
3930 Block &afterBlock = *op.getAfterBody();
3932 ConditionOp condOp = op.getConditionOp();
3933 OperandRange condOpArgs = condOp.getArgs();
3937 bool canSimplify =
false;
3938 for (
const auto &it :
3939 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3940 auto index =
static_cast<unsigned>(it.index());
3941 auto [initVal, yieldOpArg] = it.value();
3944 if (yieldOpArg == initVal) {
3953 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3954 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3955 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3956 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3966 SmallVector<Value> newInitArgs, newYieldOpArgs;
3968 SmallVector<Location> newBeforeBlockArgLocs;
3969 for (
const auto &it :
3970 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3971 auto index =
static_cast<unsigned>(it.index());
3972 auto [initVal, yieldOpArg] = it.value();
3976 if (yieldOpArg == initVal) {
3977 beforeBlockInitValMap.insert({index, initVal});
3985 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3986 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3987 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3988 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3989 beforeBlockInitValMap.insert({index, initVal});
3994 newInitArgs.emplace_back(initVal);
3995 newYieldOpArgs.emplace_back(yieldOpArg);
3996 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
4000 OpBuilder::InsertionGuard g(rewriter);
4005 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
4009 &newWhile.getBefore(), {},
4010 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
4012 Block &beforeBlock = *op.getBeforeBody();
4019 for (
unsigned i = 0, j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
4022 if (beforeBlockInitValMap.count(i) != 0)
4023 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
4025 newBeforeBlockArgs[i] = newBeforeBlock.
getArgument(j++);
4028 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
4030 newWhile.getAfter().begin());
4032 rewriter.
replaceOp(op, newWhile.getResults());
4077struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
4078 using OpRewritePattern<WhileOp>::OpRewritePattern;
4080 LogicalResult matchAndRewrite(WhileOp op,
4081 PatternRewriter &rewriter)
const override {
4082 Block &beforeBlock = *op.getBeforeBody();
4083 ConditionOp condOp = op.getConditionOp();
4084 OperandRange condOpArgs = condOp.getArgs();
4086 bool canSimplify =
false;
4087 for (Value condOpArg : condOpArgs) {
4102 SmallVector<Value> newCondOpArgs;
4103 SmallVector<Type> newAfterBlockType;
4105 SmallVector<Location> newAfterBlockArgLocs;
4106 for (
const auto &it : llvm::enumerate(condOpArgs)) {
4107 auto index =
static_cast<unsigned>(it.index());
4108 Value condOpArg = it.value();
4113 condOpInitValMap.insert({index, condOpArg});
4115 newCondOpArgs.emplace_back(condOpArg);
4116 newAfterBlockType.emplace_back(condOpArg.
getType());
4117 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
4122 OpBuilder::InsertionGuard g(rewriter);
4128 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
4131 Block &newAfterBlock =
4133 newAfterBlockType, newAfterBlockArgLocs);
4135 Block &afterBlock = *op.getAfterBody();
4142 for (
unsigned i = 0, j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
4143 Value afterBlockArg,
result;
4146 if (condOpInitValMap.count(i) != 0) {
4147 afterBlockArg = condOpInitValMap[i];
4151 result = newWhile.getResult(j);
4154 newAfterBlockArgs[i] = afterBlockArg;
4155 newWhileResults[i] =
result;
4158 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4160 newWhile.getBefore().begin());
4162 rewriter.
replaceOp(op, newWhileResults);
4193struct WhileUnusedResult :
public OpRewritePattern<WhileOp> {
4194 using OpRewritePattern<WhileOp>::OpRewritePattern;
4196 LogicalResult matchAndRewrite(WhileOp op,
4197 PatternRewriter &rewriter)
const override {
4198 auto term = op.getConditionOp();
4199 auto afterArgs = op.getAfterArguments();
4200 auto termArgs = term.getArgs();
4203 SmallVector<unsigned> newResultsIndices;
4204 SmallVector<Type> newResultTypes;
4205 SmallVector<Value> newTermArgs;
4206 SmallVector<Location> newArgLocs;
4207 bool needUpdate =
false;
4208 for (
const auto &it :
4209 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
4210 auto i =
static_cast<unsigned>(it.index());
4211 Value
result = std::get<0>(it.value());
4212 Value afterArg = std::get<1>(it.value());
4213 Value termArg = std::get<2>(it.value());
4217 newResultsIndices.emplace_back(i);
4218 newTermArgs.emplace_back(termArg);
4219 newResultTypes.emplace_back(
result.getType());
4220 newArgLocs.emplace_back(
result.getLoc());
4228 OpBuilder::InsertionGuard g(rewriter);
4235 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
4238 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
4242 SmallVector<Value> newResults(op.getNumResults());
4243 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
4244 for (
const auto &it : llvm::enumerate(newResultsIndices)) {
4245 newResults[it.value()] = newWhile.getResult(it.index());
4246 newAfterBlockArgs[it.value()] = newAfterBlock.
getArgument(it.index());
4250 newWhile.getBefore().begin());
4252 Block &afterBlock = *op.getAfterBody();
4253 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4282struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
4283 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4285 LogicalResult matchAndRewrite(scf::WhileOp op,
4286 PatternRewriter &rewriter)
const override {
4287 using namespace scf;
4288 auto cond = op.getConditionOp();
4289 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4293 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4294 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
4295 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4298 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4299 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4303 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4306 if (cmp2.getPredicate() == cmp.getPredicate())
4307 samePredicate =
true;
4308 else if (cmp2.getPredicate() ==
4309 arith::invertPredicate(cmp.getPredicate()))
4310 samePredicate =
false;
4325struct WhileRemoveUnusedArgs :
public OpRewritePattern<WhileOp> {
4326 using OpRewritePattern<WhileOp>::OpRewritePattern;
4328 LogicalResult matchAndRewrite(WhileOp op,
4329 PatternRewriter &rewriter)
const override {
4331 if (!llvm::any_of(op.getBeforeArguments(),
4332 [](Value arg) { return arg.use_empty(); }))
4335 YieldOp yield = op.getYieldOp();
4338 SmallVector<Value> newYields;
4339 SmallVector<Value> newInits;
4340 llvm::BitVector argsToErase;
4342 size_t argsCount = op.getBeforeArguments().size();
4343 newYields.reserve(argsCount);
4344 newInits.reserve(argsCount);
4345 argsToErase.reserve(argsCount);
4346 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4347 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4348 if (beforeArg.use_empty()) {
4349 argsToErase.push_back(
true);
4351 argsToErase.push_back(
false);
4352 newYields.emplace_back(yieldValue);
4353 newInits.emplace_back(initValue);
4357 Block &beforeBlock = *op.getBeforeBody();
4358 Block &afterBlock = *op.getAfterBody();
4362 Location loc = op.getLoc();
4364 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4366 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4367 Block &newAfterBlock = *newWhileOp.getAfterBody();
4369 OpBuilder::InsertionGuard g(rewriter);
4373 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4378 rewriter.
replaceOp(op, newWhileOp.getResults());
4384struct WhileRemoveDuplicatedResults :
public OpRewritePattern<WhileOp> {
4387 LogicalResult matchAndRewrite(WhileOp op,
4388 PatternRewriter &rewriter)
const override {
4389 ConditionOp condOp = op.getConditionOp();
4392 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4394 if (argsSet.size() == condOpArgs.size())
4397 llvm::SmallDenseMap<Value, unsigned> argsMap;
4398 SmallVector<Value> newArgs;
4399 argsMap.reserve(condOpArgs.size());
4400 newArgs.reserve(condOpArgs.size());
4401 for (Value arg : condOpArgs) {
4402 if (!argsMap.count(arg)) {
4403 auto pos =
static_cast<unsigned>(argsMap.size());
4404 argsMap.insert({arg, pos});
4405 newArgs.emplace_back(arg);
4411 Location loc = op.getLoc();
4413 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4416 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4417 Block &newAfterBlock = *newWhileOp.getAfterBody();
4419 SmallVector<Value> afterArgsMapping;
4420 SmallVector<Value> resultsMapping;
4421 for (
auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4422 auto it = argsMap.find(arg);
4423 assert(it != argsMap.end());
4424 auto pos = it->second;
4425 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4426 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4429 OpBuilder::InsertionGuard g(rewriter);
4434 Block &beforeBlock = *op.getBeforeBody();
4435 Block &afterBlock = *op.getAfterBody();
4437 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4439 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4447static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4449 if (args1.size() != args2.size())
4450 return std::nullopt;
4452 SmallVector<unsigned> ret(args1.size());
4453 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
4454 auto it = llvm::find(args2, arg1);
4455 if (it == args2.end())
4456 return std::nullopt;
4458 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4465 llvm::SmallDenseSet<Value> set;
4466 for (Value arg : args) {
4467 if (!set.insert(arg).second)
4477struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
4480 LogicalResult matchAndRewrite(WhileOp loop,
4481 PatternRewriter &rewriter)
const override {
4482 auto *oldBefore = loop.getBeforeBody();
4483 ConditionOp oldTerm = loop.getConditionOp();
4484 ValueRange beforeArgs = oldBefore->getArguments();
4486 if (beforeArgs == termArgs)
4489 if (hasDuplicates(termArgs))
4492 auto mapping = getArgsMapping(beforeArgs, termArgs);
4497 OpBuilder::InsertionGuard g(rewriter);
4503 auto *oldAfter = loop.getAfterBody();
4505 SmallVector<Type> newResultTypes(beforeArgs.size());
4506 for (
auto &&[i, j] : llvm::enumerate(*mapping))
4507 newResultTypes[j] = loop.getResult(i).getType();
4509 auto newLoop = WhileOp::create(
4510 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4512 auto *newBefore = newLoop.getBeforeBody();
4513 auto *newAfter = newLoop.getAfterBody();
4515 SmallVector<Value> newResults(beforeArgs.size());
4516 SmallVector<Value> newAfterArgs(beforeArgs.size());
4517 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
4518 newResults[i] = newLoop.getResult(j);
4519 newAfterArgs[i] = newAfter->getArgument(j);
4523 newBefore->getArguments());
4533void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4534 MLIRContext *context) {
4535 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4536 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4537 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4538 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4553 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4556 caseValues.push_back(value);
4565 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4567 p <<
"case " << value <<
' ';
4572LogicalResult scf::IndexSwitchOp::verify() {
4573 if (getCases().size() != getCaseRegions().size()) {
4575 << getCaseRegions().size() <<
" case regions but "
4576 << getCases().size() <<
" case values";
4580 for (int64_t value : getCases())
4581 if (!valueSet.insert(value).second)
4582 return emitOpError(
"has duplicate case value: ") << value;
4583 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
4584 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4586 return emitOpError(
"expected region to end with scf.yield, but got ")
4589 if (yield.getNumOperands() != getNumResults()) {
4590 return (
emitOpError(
"expected each region to return ")
4591 << getNumResults() <<
" values, but " << name <<
" returns "
4592 << yield.getNumOperands())
4593 .attachNote(yield.getLoc())
4594 <<
"see yield operation here";
4596 for (
auto [idx,
result, operand] :
4597 llvm::enumerate(getResultTypes(), yield.getOperands())) {
4599 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
4600 if (
result == operand.getType())
4603 << idx <<
" of each region to be " <<
result)
4604 .attachNote(yield.getLoc())
4605 << name <<
" returns " << operand.getType() <<
" here";
4612 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4619unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4621Block &scf::IndexSwitchOp::getDefaultBlock() {
4622 return getDefaultRegion().front();
4625Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4626 assert(idx < getNumCases() &&
"case index out-of-bounds");
4627 return getCaseRegions()[idx].front();
4630void IndexSwitchOp::getSuccessorRegions(
4631 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4634 successors.emplace_back(getOperation(), getResults());
4638 llvm::append_range(successors, getRegions());
4641void IndexSwitchOp::getEntrySuccessorRegions(
4642 ArrayRef<Attribute> operands,
4643 SmallVectorImpl<RegionSuccessor> &successors) {
4644 FoldAdaptor adaptor(operands, *
this);
4647 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4649 llvm::append_range(successors, getRegions());
4655 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4656 if (caseValue == arg.getInt()) {
4657 successors.emplace_back(&caseRegion);
4661 successors.emplace_back(&getDefaultRegion());
4664void IndexSwitchOp::getRegionInvocationBounds(
4665 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4666 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4667 if (!operandValue) {
4669 bounds.append(getNumRegions(), InvocationBounds(0, 1));
4673 unsigned liveIndex = getNumRegions() - 1;
4674 const auto *it = llvm::find(getCases(), operandValue.getInt());
4675 if (it != getCases().end())
4676 liveIndex = std::distance(getCases().begin(), it);
4677 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4678 bounds.emplace_back(0, i == liveIndex);
4689 if (!maybeCst.has_value())
4692 int64_t caseIdx, e = op.getNumCases();
4693 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4694 if (cst == op.getCases()[caseIdx])
4698 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4699 : op.getDefaultRegion();
4714void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4715 MLIRContext *context) {
4716 results.
add<FoldConstantCase>(context);
4723#define GET_OP_CLASSES
4724#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.