29#include "llvm/ADT/MapVector.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallPtrSet.h"
32#include "llvm/Support/Casting.h"
33#include "llvm/Support/DebugLog.h"
39#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
51 IRMapping &valueMapping)
const final {
56 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
61 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
62 auto retValOp = dyn_cast<scf::YieldOp>(op);
66 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
67 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
77void SCFDialect::initialize() {
80#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
82 addInterfaces<SCFInlinerInterface>();
83 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
84 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
85 InParallelOp, ReduceReturnOp>();
86 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
87 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
88 ForallOp, InParallelOp, WhileOp, YieldOp>();
89 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
94 scf::YieldOp::create(builder, loc);
99template <
typename TerminatorTy>
101 StringRef errorMessage) {
102 Operation *terminatorOperation =
nullptr;
104 terminatorOperation = ®ion.
front().
back();
105 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
109 if (terminatorOperation)
110 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
117 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
120 if ((isSigned && !addOp.hasNoSignedWrap()) ||
121 (!isSigned && !addOp.hasNoUnsignedWrap()))
124 if (addOp.getLhs() != lb ||
138 assert(region.
hasOneBlock() &&
"expected single-block region");
158ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
186LogicalResult ExecuteRegionOp::verify() {
187 if (getRegion().empty())
188 return emitOpError(
"region needs to have at least one block");
189 if (getRegion().front().getNumArguments() > 0)
190 return emitOpError(
"region cannot have any arguments");
213 if (!op.getRegion().hasOneBlock() || op.getNoInline())
262 if (op.getNoInline())
264 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
267 Block *prevBlock = op->getBlock();
271 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
273 for (
Block &blk : op.getRegion()) {
274 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
276 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
277 yieldOp.getResults());
285 for (
auto res : op.getResults())
286 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
303 if (op.getNumResults() == 0)
307 for (
Block &block : op.getRegion()) {
308 if (
auto yield = dyn_cast<scf::YieldOp>(block.
getTerminator()))
309 yieldOps.push_back(yield.getOperation());
312 if (yieldOps.empty())
316 auto yieldOpsOperands = yieldOps[0]->getOperands();
317 for (
auto *yieldOp : yieldOps) {
318 if (yieldOp->getOperands() != yieldOpsOperands)
326 for (
auto [
index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
327 if (isValueFromInsideRegion(yieldedValue, op)) {
328 internalValues.push_back(yieldedValue);
329 opResultsToKeep.push_back(op.getResult(
index));
331 externalValues.push_back(yieldedValue);
332 opResultsToReplaceWithExternalValues.push_back(op.getResult(
index));
336 if (externalValues.empty())
342 for (
Value value : internalValues)
343 resultTypes.push_back(value.getType());
345 ExecuteRegionOp::create(rewriter, op.getLoc(),
TypeRange(resultTypes));
346 newOp->setAttrs(op->getAttrs());
350 newOp.getRegion().end());
354 for (
auto *yieldOp : yieldOps) {
371 bool isValueFromInsideRegion(
Value value,
372 ExecuteRegionOp executeRegionOp)
const {
375 return &executeRegionOp.getRegion() == defOp->getParentRegion();
379 return &executeRegionOp.getRegion() == blockArg.getParentRegion();
391void ExecuteRegionOp::getSuccessorRegions(
411 "condition op can only exit the loop or branch to the after"
414 return getArgsMutable();
417void ConditionOp::getSuccessorRegions(
419 FoldAdaptor adaptor(operands, *
this);
421 WhileOp whileOp = getParentOp();
425 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
426 if (!boolAttr || boolAttr.getValue())
427 regions.emplace_back(&whileOp.getAfter(),
428 whileOp.getAfter().getArguments());
429 if (!boolAttr || !boolAttr.getValue())
430 regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
439 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
443 result.addAttribute(getUnsignedCmpAttrName(
result.name),
446 result.addOperands(initArgs);
447 for (
Value v : initArgs)
448 result.addTypes(v.getType());
453 for (
Value v : initArgs)
459 if (initArgs.empty() && !bodyBuilder) {
460 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
461 }
else if (bodyBuilder) {
469LogicalResult ForOp::verify() {
471 if (getInitArgs().size() != getNumResults())
473 "mismatch in number of loop-carried values and defined values");
478LogicalResult ForOp::verifyRegions() {
483 "expected induction variable to be same type as bounds and step");
485 if (getNumRegionIterArgs() != getNumResults())
487 "mismatch in number of basic block args and defined values");
489 auto initArgs = getInitArgs();
490 auto iterArgs = getRegionIterArgs();
491 auto opResults = getResults();
493 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
495 return emitOpError() <<
"types mismatch between " << i
496 <<
"th iter operand and defined value";
498 return emitOpError() <<
"types mismatch between " << i
499 <<
"th iter region arg and defined value";
506std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
510std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
514std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
518std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
522std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
526LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
527 std::optional<APInt> tripCount = getStaticTripCount();
528 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
531 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
534 if (*tripCount == 0) {
541 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
548 llvm::append_range(bbArgReplacements, getInitArgs());
552 getOperation()->getIterator(), bbArgReplacements);
568 StringRef prefix =
"") {
569 assert(blocksArgs.size() == initializers.size() &&
570 "expected same length of arguments and initializers");
571 if (initializers.empty())
575 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
576 p << std::get<0>(it) <<
" = " << std::get<1>(it);
582 if (getUnsignedCmp())
585 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
589 if (!getInitArgs().empty())
590 p <<
" -> (" << getInitArgs().getTypes() <<
')';
593 p <<
" : " << t <<
' ';
596 !getInitArgs().empty());
598 getUnsignedCmpAttrName().strref());
609 result.addAttribute(getUnsignedCmpAttrName(
result.name),
623 regionArgs.push_back(inductionVariable);
633 if (regionArgs.size() !=
result.types.size() + 1)
636 "mismatch in number of loop-carried values and defined values");
645 regionArgs.front().type = type;
646 for (
auto [iterArg, type] :
647 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
654 ForOp::ensureTerminator(*body, builder,
result.location);
663 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
664 operands,
result.types)) {
665 Type type = std::get<2>(argOperandType);
666 std::get<0>(argOperandType).type = type;
683 return getBody()->getArguments().drop_front(getNumInductionVars());
687 return getInitArgsMutable();
690FailureOr<LoopLikeOpInterface>
691ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
693 bool replaceInitOperandUsesInLoop,
698 auto inits = llvm::to_vector(getInitArgs());
699 inits.append(newInitOperands.begin(), newInitOperands.end());
700 scf::ForOp newLoop = scf::ForOp::create(
706 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
708 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
713 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
714 assert(newInitOperands.size() == newYieldedValues.size() &&
715 "expected as many new yield values as new iter operands");
717 yieldOp.getResultsMutable().append(newYieldedValues);
723 newLoop.getBody()->getArguments().take_front(
724 getBody()->getNumArguments()));
726 if (replaceInitOperandUsesInLoop) {
729 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
740 newLoop->getResults().take_front(getNumResults()));
741 return cast<LoopLikeOpInterface>(newLoop.getOperation());
745 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
748 assert(ivArg.getOwner() &&
"unlinked block argument");
749 auto *containingOp = ivArg.getOwner()->getParentOp();
750 return dyn_cast_or_null<ForOp>(containingOp);
754 return getInitArgs();
770LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
771 for (
auto [lb, ub, step] :
772 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
775 if (!tripCount.has_value() || *tripCount != 1)
784 return getBody()->getArguments().drop_front(getRank());
787MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
788 return getOutputsMutable();
794 scf::InParallelOp terminator = forallOp.getTerminator();
799 bbArgReplacements.append(forallOp.getOutputs().begin(),
800 forallOp.getOutputs().end());
804 forallOp->getIterator(), bbArgReplacements);
809 results.reserve(forallOp.getResults().size());
810 for (
auto &yieldingOp : terminator.getYieldingOps()) {
811 auto parallelInsertSliceOp =
812 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
813 if (!parallelInsertSliceOp)
816 Value dst = parallelInsertSliceOp.getDest();
817 Value src = parallelInsertSliceOp.getSource();
818 if (llvm::isa<TensorType>(src.
getType())) {
819 results.push_back(tensor::InsertSliceOp::create(
820 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
821 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
822 parallelInsertSliceOp.getStrides(),
823 parallelInsertSliceOp.getStaticOffsets(),
824 parallelInsertSliceOp.getStaticSizes(),
825 parallelInsertSliceOp.getStaticStrides()));
827 llvm_unreachable(
"unsupported terminator");
842 assert(lbs.size() == ubs.size() &&
843 "expected the same number of lower and upper bounds");
844 assert(lbs.size() == steps.size() &&
845 "expected the same number of lower bounds and steps");
850 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
852 assert(results.size() == iterArgs.size() &&
853 "loop nest body must return as many values as loop has iteration "
855 return LoopNest{{}, std::move(results)};
863 loops.reserve(lbs.size());
864 ivs.reserve(lbs.size());
867 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
868 auto loop = scf::ForOp::create(
869 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
875 currentIterArgs = args;
876 currentLoc = nestedLoc;
882 loops.push_back(loop);
886 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
888 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
895 ? bodyBuilder(builder, currentLoc, ivs,
896 loops.back().getRegionIterArgs())
898 assert(results.size() == iterArgs.size() &&
899 "loop nest body must return as many values as loop has iteration "
902 scf::YieldOp::create(builder, loc, results);
906 llvm::append_range(nestResults, loops.front().getResults());
907 return LoopNest{std::move(loops), std::move(nestResults)};
920 bodyBuilder(nestedBuilder, nestedLoc, ivs);
929 assert(operand.
getOwner() == forOp);
934 "expected an iter OpOperand");
936 "Expected a different type");
938 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
943 newIterOperands.push_back(opOperand.get());
947 scf::ForOp newForOp = scf::ForOp::create(
948 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
949 forOp.getStep(), newIterOperands,
nullptr,
950 forOp.getUnsignedCmp());
951 newForOp->setAttrs(forOp->getAttrs());
952 Block &newBlock = newForOp.getRegion().
front();
960 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
962 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
963 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
967 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
970 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
973 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
974 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
975 clonedYieldOp.getOperand(yieldIdx));
977 newYieldOperands[yieldIdx] = castOut;
978 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
979 rewriter.
eraseOp(clonedYieldOp);
984 newResults[yieldIdx] =
985 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
1005 LogicalResult matchAndRewrite(scf::ForOp forOp,
1007 bool canonicalize =
false;
1014 int64_t numResults = forOp.getNumResults();
1016 keepMask.reserve(numResults);
1019 newBlockTransferArgs.reserve(1 + numResults);
1020 newBlockTransferArgs.push_back(
Value());
1021 newIterArgs.reserve(forOp.getInitArgs().size());
1022 newYieldValues.reserve(numResults);
1023 newResultValues.reserve(numResults);
1025 for (
auto [init, arg,
result, yielded] :
1026 llvm::zip(forOp.getInitArgs(),
1027 forOp.getRegionIterArgs(),
1029 forOp.getYieldedValues()
1036 bool forwarded = (arg == yielded) || (init == yielded) ||
1037 (arg.use_empty() &&
result.use_empty());
1039 canonicalize =
true;
1040 keepMask.push_back(
false);
1041 newBlockTransferArgs.push_back(init);
1042 newResultValues.push_back(init);
1048 if (
auto it = initYieldToArg.find({init, yielded});
1049 it != initYieldToArg.end()) {
1050 canonicalize =
true;
1051 keepMask.push_back(
false);
1052 auto [sameArg, sameResult] = it->second;
1056 newBlockTransferArgs.push_back(init);
1057 newResultValues.push_back(init);
1062 initYieldToArg.insert({{init, yielded}, {arg,
result}});
1063 keepMask.push_back(
true);
1064 newIterArgs.push_back(init);
1065 newYieldValues.push_back(yielded);
1066 newBlockTransferArgs.push_back(Value());
1067 newResultValues.push_back(Value());
1073 scf::ForOp newForOp =
1074 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
1075 forOp.getUpperBound(), forOp.getStep(), newIterArgs,
1076 nullptr, forOp.getUnsignedCmp());
1077 newForOp->setAttrs(forOp->getAttrs());
1078 Block &newBlock = newForOp.getRegion().front();
1081 newBlockTransferArgs[0] = newBlock.
getArgument(0);
1082 for (
unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
1084 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
1085 Value &newResultVal = newResultValues[idx];
1086 assert((blockTransferArg && newResultVal) ||
1087 (!blockTransferArg && !newResultVal));
1088 if (!blockTransferArg) {
1089 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
1090 newResultVal = newForOp.getResult(collapsedIdx++);
1094 Block &oldBlock = forOp.getRegion().front();
1096 "unexpected argument size mismatch");
1101 if (newIterArgs.empty()) {
1102 auto newYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1105 rewriter.
replaceOp(forOp, newResultValues);
1110 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1111 OpBuilder::InsertionGuard g(rewriter);
1113 SmallVector<Value, 4> filteredOperands;
1114 filteredOperands.reserve(newResultValues.size());
1115 for (
unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1117 filteredOperands.push_back(mergedTerminator.getOperand(idx));
1118 scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1122 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1123 auto mergedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
1124 cloneFilteredTerminator(mergedYieldOp);
1125 rewriter.
eraseOp(mergedYieldOp);
1126 rewriter.
replaceOp(forOp, newResultValues);
1134struct SimplifyTrivialLoops :
public OpRewritePattern<ForOp> {
1135 using OpRewritePattern<ForOp>::OpRewritePattern;
1137 LogicalResult matchAndRewrite(ForOp op,
1138 PatternRewriter &rewriter)
const override {
1139 std::optional<APInt> tripCount = op.getStaticTripCount();
1140 if (!tripCount.has_value())
1142 "can't compute constant trip count");
1144 if (tripCount->isZero()) {
1145 LDBG() <<
"SimplifyTrivialLoops tripCount is 0 for loop "
1146 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1147 rewriter.
replaceOp(op, op.getInitArgs());
1151 if (tripCount->getSExtValue() == 1) {
1152 LDBG() <<
"SimplifyTrivialLoops tripCount is 1 for loop "
1153 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1154 SmallVector<Value, 4> blockArgs;
1155 blockArgs.reserve(op.getInitArgs().size() + 1);
1156 blockArgs.push_back(op.getLowerBound());
1157 llvm::append_range(blockArgs, op.getInitArgs());
1164 if (!llvm::hasSingleElement(block))
1168 if (llvm::any_of(op.getYieldedValues(),
1169 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1171 LDBG() <<
"SimplifyTrivialLoops empty body loop allows replacement with "
1172 "yield operands for loop "
1173 << OpWithFlags(op, OpPrintingFlags().skipRegions());
1174 rewriter.
replaceOp(op, op.getYieldedValues());
1205struct ForOpTensorCastFolder :
public OpRewritePattern<ForOp> {
1206 using OpRewritePattern<ForOp>::OpRewritePattern;
1208 LogicalResult matchAndRewrite(ForOp op,
1209 PatternRewriter &rewriter)
const override {
1210 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1211 OpOperand &iterOpOperand = std::get<0>(it);
1213 if (!incomingCast ||
1214 incomingCast.getSource().getType() == incomingCast.getType())
1219 incomingCast.getDest().getType(),
1220 incomingCast.getSource().getType()))
1222 if (!std::get<1>(it).hasOneUse())
1228 rewriter, op, iterOpOperand, incomingCast.getSource(),
1229 [](OpBuilder &
b, Location loc, Type type, Value source) {
1230 return tensor::CastOp::create(b, loc, type, source);
1240void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1241 MLIRContext *context) {
1242 results.
add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1246std::optional<APInt> ForOp::getConstantStep() {
1249 return step.getValue();
1253std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1254 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1260 if (
auto constantStep = getConstantStep())
1261 if (*constantStep == 1)
1269std::optional<APInt> ForOp::getStaticTripCount() {
1278LogicalResult ForallOp::verify() {
1279 unsigned numLoops = getRank();
1281 if (getNumResults() != getOutputs().size())
1283 << getNumResults() <<
" results, but has only "
1284 << getOutputs().size() <<
" outputs";
1287 auto *body = getBody();
1289 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1290 for (int64_t i = 0; i < numLoops; ++i)
1293 << i <<
"-th block argument to be an index";
1294 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1297 << i <<
"-th output and corresponding block argument";
1298 if (getMapping().has_value() && !getMapping()->empty()) {
1299 if (getDeviceMappingAttrs().size() != numLoops)
1300 return emitOpError() <<
"mapping attribute size must match op rank";
1301 if (
failed(getDeviceMaskingAttr()))
1303 <<
" supports at most one device masking attribute";
1307 Operation *op = getOperation();
1309 getStaticLowerBound(),
1310 getDynamicLowerBound())))
1313 getStaticUpperBound(),
1314 getDynamicUpperBound())))
1317 getStaticStep(), getDynamicStep())))
1323void ForallOp::print(OpAsmPrinter &p) {
1324 Operation *op = getOperation();
1325 p <<
" (" << getInductionVars();
1326 if (isNormalized()) {
1347 if (!getRegionOutArgs().empty())
1348 p <<
"-> (" << getResultTypes() <<
") ";
1349 p.printRegion(getRegion(),
1351 getNumResults() > 0);
1352 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1353 getStaticLowerBoundAttrName(),
1354 getStaticUpperBoundAttrName(),
1355 getStaticStepAttrName()});
1358ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1360 auto indexType =
b.getIndexType();
1365 SmallVector<OpAsmParser::Argument, 4> ivs;
1370 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1380 unsigned numLoops = ivs.size();
1381 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1382 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1411 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1412 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1415 if (outOperands.size() !=
result.types.size())
1417 "mismatch between out operands and types");
1426 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1427 std::unique_ptr<Region> region = std::make_unique<Region>();
1428 for (
auto &iv : ivs) {
1429 iv.type =
b.getIndexType();
1430 regionArgs.push_back(iv);
1432 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1433 auto &out = it.value();
1434 out.type =
result.types[it.index()];
1435 regionArgs.push_back(out);
1441 ForallOp::ensureTerminator(*region,
b,
result.location);
1442 result.addRegion(std::move(region));
1448 result.addAttribute(
"staticLowerBound", staticLbs);
1449 result.addAttribute(
"staticUpperBound", staticUbs);
1450 result.addAttribute(
"staticStep", staticSteps);
1451 result.addAttribute(
"operandSegmentSizes",
1453 {static_cast<int32_t>(dynamicLbs.size()),
1454 static_cast<int32_t>(dynamicUbs.size()),
1455 static_cast<int32_t>(dynamicSteps.size()),
1456 static_cast<int32_t>(outOperands.size())}));
1461void ForallOp::build(
1462 mlir::OpBuilder &
b, mlir::OperationState &
result,
1463 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1464 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1465 std::optional<ArrayAttr> mapping,
1467 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1468 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1473 result.addOperands(dynamicLbs);
1474 result.addOperands(dynamicUbs);
1475 result.addOperands(dynamicSteps);
1476 result.addOperands(outputs);
1479 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1480 b.getDenseI64ArrayAttr(staticLbs));
1481 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1482 b.getDenseI64ArrayAttr(staticUbs));
1483 result.addAttribute(getStaticStepAttrName(
result.name),
1484 b.getDenseI64ArrayAttr(staticSteps));
1486 "operandSegmentSizes",
1487 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1488 static_cast<int32_t>(dynamicUbs.size()),
1489 static_cast<int32_t>(dynamicSteps.size()),
1490 static_cast<int32_t>(outputs.size())}));
1491 if (mapping.has_value()) {
1492 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
1496 Region *bodyRegion =
result.addRegion();
1497 OpBuilder::InsertionGuard g(
b);
1498 b.createBlock(bodyRegion);
1503 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1504 SmallVector<Location>(staticLbs.size(),
result.location));
1507 SmallVector<Location>(outputs.size(),
result.location));
1509 b.setInsertionPointToStart(&bodyBlock);
1510 if (!bodyBuilderFn) {
1511 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
1518void ForallOp::build(
1519 mlir::OpBuilder &
b, mlir::OperationState &
result,
1520 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1521 std::optional<ArrayAttr> mapping,
1523 unsigned numLoops = ubs.size();
1524 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
1525 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1526 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1530bool ForallOp::isNormalized() {
1531 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1532 return llvm::all_of(results, [&](OpFoldResult ofr) {
1534 return intValue.has_value() && intValue == val;
1537 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1540InParallelOp ForallOp::getTerminator() {
1541 return cast<InParallelOp>(getBody()->getTerminator());
1544SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1545 SmallVector<Operation *> storeOps;
1546 for (Operation *user : bbArg.
getUsers()) {
1547 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1548 storeOps.push_back(parallelOp);
1554SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1555 SmallVector<DeviceMappingAttrInterface> res;
1558 for (
auto attr : getMapping()->getValue()) {
1559 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1566FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1567 DeviceMaskingAttrInterface res;
1570 for (
auto attr : getMapping()->getValue()) {
1571 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1580bool ForallOp::usesLinearMapping() {
1581 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1584 return ifaces.front().isLinearMapping();
1587std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1588 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1592std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1594 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
1598std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1600 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1604std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1610 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1613 assert(tidxArg.getOwner() &&
"unlinked block argument");
1614 auto *containingOp = tidxArg.getOwner()->getParentOp();
1615 return dyn_cast<ForallOp>(containingOp);
1623 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1625 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1629 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1632 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1637class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1639 using OpRewritePattern<ForallOp>::OpRewritePattern;
1641 LogicalResult matchAndRewrite(ForallOp op,
1642 PatternRewriter &rewriter)
const override {
1643 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1644 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1645 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1652 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1653 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1656 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1657 op.setStaticLowerBound(staticLowerBound);
1661 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1662 op.setStaticUpperBound(staticUpperBound);
1665 op.getDynamicStepMutable().assign(dynamicStep);
1666 op.setStaticStep(staticStep);
1668 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1670 {static_cast<int32_t>(dynamicLowerBound.size()),
1671 static_cast<int32_t>(dynamicUpperBound.size()),
1672 static_cast<int32_t>(dynamicStep.size()),
1673 static_cast<int32_t>(op.getNumResults())}));
1752struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1753 using OpRewritePattern<ForallOp>::OpRewritePattern;
1755 LogicalResult matchAndRewrite(ForallOp forallOp,
1756 PatternRewriter &rewriter)
const final {
1772 SmallVector<Value> resultToReplace;
1773 SmallVector<Value> newOuts;
1774 for (OpResult
result : forallOp.getResults()) {
1775 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1776 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1777 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1778 resultToDelete.insert(
result);
1780 resultToReplace.push_back(
result);
1781 newOuts.push_back(opOperand->
get());
1787 if (resultToDelete.empty())
1795 for (OpResult
result : resultToDelete) {
1796 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1797 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1798 SmallVector<Operation *> combiningOps =
1799 forallOp.getCombiningOps(blockArg);
1800 for (Operation *combiningOp : combiningOps)
1801 rewriter.
eraseOp(combiningOp);
1806 auto newForallOp = scf::ForallOp::create(
1807 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1808 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1809 forallOp.getMapping(),
1814 Block *loopBody = forallOp.getBody();
1815 Block *newLoopBody = newForallOp.getBody();
1816 ArrayRef<BlockArgument> newBbArgs = newLoopBody->
getArguments();
1819 SmallVector<Value> newBlockArgs =
1820 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1821 [](BlockArgument
b) -> Value { return b; });
1827 for (OpResult
result : forallOp.getResults()) {
1828 if (resultToDelete.count(
result)) {
1829 newBlockArgs.push_back(forallOp.getTiedOpOperand(
result)->get());
1831 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1834 rewriter.
mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1838 for (
auto &&[oldResult, newResult] :
1839 llvm::zip(resultToReplace, newForallOp->getResults()))
1845 for (OpResult oldResult : resultToDelete)
1847 forallOp.getTiedOpOperand(oldResult)->get());
1852struct ForallOpSingleOrZeroIterationDimsFolder
1853 :
public OpRewritePattern<ForallOp> {
1854 using OpRewritePattern<ForallOp>::OpRewritePattern;
1856 LogicalResult matchAndRewrite(ForallOp op,
1857 PatternRewriter &rewriter)
const override {
1859 if (op.getMapping().has_value() && !op.getMapping()->empty())
1861 Location loc = op.getLoc();
1864 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1867 for (
auto [lb, ub, step, iv] :
1868 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1869 op.getMixedStep(), op.getInductionVars())) {
1870 auto numIterations =
1872 if (numIterations.has_value()) {
1874 if (*numIterations == 0) {
1875 rewriter.
replaceOp(op, op.getOutputs());
1880 if (*numIterations == 1) {
1885 newMixedLowerBounds.push_back(lb);
1886 newMixedUpperBounds.push_back(ub);
1887 newMixedSteps.push_back(step);
1891 if (newMixedLowerBounds.empty()) {
1897 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1899 op,
"no dimensions have 0 or 1 iterations");
1904 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1905 newMixedUpperBounds, newMixedSteps,
1906 op.getOutputs(), std::nullopt,
nullptr);
1907 newOp.getBodyRegion().getBlocks().clear();
1911 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1912 newOp.getStaticLowerBoundAttrName(),
1913 newOp.getStaticUpperBoundAttrName(),
1914 newOp.getStaticStepAttrName()};
1915 for (
const auto &namedAttr : op->getAttrs()) {
1916 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1919 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1923 newOp.getRegion().begin(), mapping);
1924 rewriter.
replaceOp(op, newOp.getResults());
1930struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1931 using OpRewritePattern<ForallOp>::OpRewritePattern;
1933 LogicalResult matchAndRewrite(ForallOp op,
1934 PatternRewriter &rewriter)
const override {
1935 Location loc = op.getLoc();
1937 for (
auto [lb, ub, step, iv] :
1938 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1939 op.getMixedStep(), op.getInductionVars())) {
1942 auto numIterations =
1944 if (!numIterations.has_value() || numIterations.value() != 1) {
1955struct FoldTensorCastOfOutputIntoForallOp
1956 :
public OpRewritePattern<scf::ForallOp> {
1957 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1964 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1965 PatternRewriter &rewriter)
const final {
1966 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1967 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1968 for (
auto en : llvm::enumerate(newOutputTensors)) {
1969 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1976 castOp.getSource().getType())) {
1980 tensorCastProducers[en.index()] =
1981 TypeCast{castOp.getSource().getType(), castOp.getType()};
1982 newOutputTensors[en.index()] = castOp.getSource();
1985 if (tensorCastProducers.empty())
1989 Location loc = forallOp.getLoc();
1990 auto newForallOp = ForallOp::create(
1991 rewriter, loc, forallOp.getMixedLowerBound(),
1992 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1993 newOutputTensors, forallOp.getMapping(),
1994 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
1995 auto castBlockArgs =
1996 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1997 for (auto [index, cast] : tensorCastProducers) {
1998 Value &oldTypeBBArg = castBlockArgs[index];
1999 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
2000 cast.dstType, oldTypeBBArg);
2004 SmallVector<Value> ivsBlockArgs =
2005 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
2006 ivsBlockArgs.append(castBlockArgs);
2008 bbArgs.front().getParentBlock(), ivsBlockArgs);
2014 auto terminator = newForallOp.getTerminator();
2015 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
2016 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
2017 if (
auto parallelCombingingOp =
2018 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
2019 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
2025 SmallVector<Value> castResults = newForallOp.getResults();
2026 for (
auto &item : tensorCastProducers) {
2027 Value &oldTypeResult = castResults[item.first];
2028 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
2031 rewriter.
replaceOp(forallOp, castResults);
2038void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
2039 MLIRContext *context) {
2040 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
2041 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
2042 ForallOpSingleOrZeroIterationDimsFolder,
2043 ForallOpReplaceConstantInductionVar>(context);
2051void ForallOp::getSuccessorRegions(RegionBranchPoint point,
2052 SmallVectorImpl<RegionSuccessor> ®ions) {
2057 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2060 RegionSuccessor(getOperation(), getOperation()->getResults()));
2068void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
2069 OpBuilder::InsertionGuard g(
b);
2070 Region *bodyRegion =
result.addRegion();
2071 b.createBlock(bodyRegion);
2074LogicalResult InParallelOp::verify() {
2075 scf::ForallOp forallOp =
2076 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
2078 return this->
emitOpError(
"expected forall op parent");
2080 for (Operation &op : getRegion().front().getOperations()) {
2081 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
2082 if (!parallelCombiningOp) {
2083 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
2088 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
2089 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
2090 for (OpOperand &dest : dests) {
2091 if (!llvm::is_contained(regionOutArgs, dest.get()))
2092 return op.emitOpError(
"may only insert into an output block argument");
2099void InParallelOp::print(OpAsmPrinter &p) {
2107ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2110 SmallVector<OpAsmParser::Argument, 8> regionOperands;
2111 std::unique_ptr<Region> region = std::make_unique<Region>();
2115 if (region->empty())
2116 OpBuilder(builder.
getContext()).createBlock(region.get());
2117 result.addRegion(std::move(region));
2125OpResult InParallelOp::getParentResult(int64_t idx) {
2126 return getOperation()->getParentOp()->getResult(idx);
2129SmallVector<BlockArgument> InParallelOp::getDests() {
2130 SmallVector<BlockArgument> updatedDests;
2131 for (Operation &yieldingOp : getYieldingOps()) {
2132 auto parallelCombiningOp =
2133 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2134 if (!parallelCombiningOp)
2136 for (OpOperand &updatedOperand :
2137 parallelCombiningOp.getUpdatedDestinations())
2138 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2140 return updatedDests;
2143llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2144 return getRegion().front().getOperations();
2152 assert(a &&
"expected non-empty operation");
2153 assert(
b &&
"expected non-empty operation");
2158 if (ifOp->isProperAncestor(
b))
2161 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2162 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
2164 ifOp = ifOp->getParentOfType<IfOp>();
2172IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2173 IfOp::Adaptor adaptor,
2175 if (adaptor.getRegions().empty())
2177 Region *r = &adaptor.getThenRegion();
2183 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
2186 TypeRange types = yieldOp.getOperandTypes();
2187 llvm::append_range(inferredReturnTypes, types);
2193 return build(builder,
result, resultTypes, cond,
false,
2197void IfOp::build(OpBuilder &builder, OperationState &
result,
2198 TypeRange resultTypes, Value cond,
bool addThenBlock,
2199 bool addElseBlock) {
2200 assert((!addElseBlock || addThenBlock) &&
2201 "must not create else block w/o then block");
2202 result.addTypes(resultTypes);
2203 result.addOperands(cond);
2206 OpBuilder::InsertionGuard guard(builder);
2207 Region *thenRegion =
result.addRegion();
2210 Region *elseRegion =
result.addRegion();
2215void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2216 bool withElseRegion) {
2220void IfOp::build(OpBuilder &builder, OperationState &
result,
2221 TypeRange resultTypes, Value cond,
bool withElseRegion) {
2222 result.addTypes(resultTypes);
2223 result.addOperands(cond);
2226 OpBuilder::InsertionGuard guard(builder);
2227 Region *thenRegion =
result.addRegion();
2229 if (resultTypes.empty())
2230 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
2233 Region *elseRegion =
result.addRegion();
2234 if (withElseRegion) {
2236 if (resultTypes.empty())
2237 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
2241void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2243 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
2244 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2245 result.addOperands(cond);
2248 OpBuilder::InsertionGuard guard(builder);
2249 Region *thenRegion =
result.addRegion();
2251 thenBuilder(builder,
result.location);
2254 Region *elseRegion =
result.addRegion();
2257 elseBuilder(builder,
result.location);
2261 SmallVector<Type> inferredReturnTypes;
2263 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
2264 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2266 inferredReturnTypes))) {
2267 result.addTypes(inferredReturnTypes);
2271LogicalResult IfOp::verify() {
2272 if (getNumResults() != 0 && getElseRegion().empty())
2273 return emitOpError(
"must have an else block if defining values");
2277ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
2279 result.regions.reserve(2);
2280 Region *thenRegion =
result.addRegion();
2281 Region *elseRegion =
result.addRegion();
2284 OpAsmParser::UnresolvedOperand cond;
2310void IfOp::print(OpAsmPrinter &p) {
2311 bool printBlockTerminators =
false;
2313 p <<
" " << getCondition();
2314 if (!getResults().empty()) {
2315 p <<
" -> (" << getResultTypes() <<
")";
2317 printBlockTerminators =
true;
2322 printBlockTerminators);
2325 auto &elseRegion = getElseRegion();
2326 if (!elseRegion.
empty()) {
2330 printBlockTerminators);
2336void IfOp::getSuccessorRegions(RegionBranchPoint point,
2337 SmallVectorImpl<RegionSuccessor> ®ions) {
2341 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2345 regions.push_back(RegionSuccessor(&getThenRegion()));
2348 Region *elseRegion = &this->getElseRegion();
2349 if (elseRegion->
empty())
2351 RegionSuccessor(getOperation(), getOperation()->getResults()));
2353 regions.push_back(RegionSuccessor(elseRegion));
2356void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2357 SmallVectorImpl<RegionSuccessor> ®ions) {
2358 FoldAdaptor adaptor(operands, *
this);
2359 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2360 if (!boolAttr || boolAttr.getValue())
2361 regions.emplace_back(&getThenRegion());
2364 if (!boolAttr || !boolAttr.getValue()) {
2365 if (!getElseRegion().empty())
2366 regions.emplace_back(&getElseRegion());
2368 regions.emplace_back(getOperation(), getResults());
2372LogicalResult IfOp::fold(FoldAdaptor adaptor,
2373 SmallVectorImpl<OpFoldResult> &results) {
2375 if (getElseRegion().empty())
2378 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2385 getConditionMutable().assign(xorStmt.getLhs());
2386 Block *thenBlock = &getThenRegion().front();
2389 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2390 getElseRegion().getBlocks());
2391 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2392 getThenRegion().getBlocks(), thenBlock);
2396void IfOp::getRegionInvocationBounds(
2397 ArrayRef<Attribute> operands,
2398 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2399 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2402 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2403 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2406 invocationBounds.assign(2, {0, 1});
2412struct RemoveUnusedResults :
public OpRewritePattern<IfOp> {
2413 using OpRewritePattern<IfOp>::OpRewritePattern;
2415 void transferBody(
Block *source,
Block *dest, ArrayRef<OpResult> usedResults,
2416 PatternRewriter &rewriter)
const {
2421 SmallVector<Value, 4> usedOperands;
2422 llvm::transform(usedResults, std::back_inserter(usedOperands),
2424 return yieldOp.getOperand(
result.getResultNumber());
2427 [&]() { yieldOp->setOperands(usedOperands); });
2430 LogicalResult matchAndRewrite(IfOp op,
2431 PatternRewriter &rewriter)
const override {
2433 SmallVector<OpResult, 4> usedResults;
2434 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2435 [](OpResult
result) { return !result.use_empty(); });
2438 if (usedResults.size() == op.getNumResults())
2442 SmallVector<Type, 4> newTypes;
2443 llvm::transform(usedResults, std::back_inserter(newTypes),
2448 IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2454 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2455 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2458 SmallVector<Value, 4> repResults(op.getNumResults());
2459 for (
const auto &en : llvm::enumerate(usedResults))
2460 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2466struct RemoveStaticCondition :
public OpRewritePattern<IfOp> {
2467 using OpRewritePattern<IfOp>::OpRewritePattern;
2469 LogicalResult matchAndRewrite(IfOp op,
2470 PatternRewriter &rewriter)
const override {
2477 else if (!op.getElseRegion().empty())
2488struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2489 using OpRewritePattern<IfOp>::OpRewritePattern;
2491 LogicalResult matchAndRewrite(IfOp op,
2492 PatternRewriter &rewriter)
const override {
2493 if (op->getNumResults() == 0)
2496 auto cond = op.getCondition();
2497 auto thenYieldArgs = op.thenYield().getOperands();
2498 auto elseYieldArgs = op.elseYield().getOperands();
2500 SmallVector<Type> nonHoistable;
2501 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2502 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2503 &op.getElseRegion() == falseVal.getParentRegion())
2504 nonHoistable.push_back(trueVal.getType());
2508 if (nonHoistable.size() == op->getNumResults())
2511 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2515 replacement.getThenRegion().takeBody(op.getThenRegion());
2516 replacement.getElseRegion().takeBody(op.getElseRegion());
2518 SmallVector<Value> results(op->getNumResults());
2519 assert(thenYieldArgs.size() == results.size());
2520 assert(elseYieldArgs.size() == results.size());
2522 SmallVector<Value> trueYields;
2523 SmallVector<Value> falseYields;
2525 for (
const auto &it :
2526 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2527 Value trueVal = std::get<0>(it.value());
2528 Value falseVal = std::get<1>(it.value());
2531 results[it.index()] =
replacement.getResult(trueYields.size());
2532 trueYields.push_back(trueVal);
2533 falseYields.push_back(falseVal);
2534 }
else if (trueVal == falseVal)
2535 results[it.index()] = trueVal;
2537 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2538 cond, trueVal, falseVal);
2565struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2566 using OpRewritePattern<IfOp>::OpRewritePattern;
2569 enum class Parent { Then, Else,
None };
2574 static Parent getParentType(Region *toCheck, IfOp op,
2576 Region *endRegion) {
2577 SmallVector<Region *> seen;
2578 while (toCheck != endRegion) {
2579 auto found = cache.find(toCheck);
2580 if (found != cache.end())
2581 return found->second;
2582 seen.push_back(toCheck);
2583 if (&op.getThenRegion() == toCheck) {
2584 for (Region *region : seen)
2585 cache[region] = Parent::Then;
2586 return Parent::Then;
2588 if (&op.getElseRegion() == toCheck) {
2589 for (Region *region : seen)
2590 cache[region] = Parent::Else;
2591 return Parent::Else;
2596 for (Region *region : seen)
2597 cache[region] = Parent::None;
2598 return Parent::None;
2601 LogicalResult matchAndRewrite(IfOp op,
2602 PatternRewriter &rewriter)
const override {
2613 Value constantTrue =
nullptr;
2614 Value constantFalse =
nullptr;
2617 for (OpOperand &use :
2618 llvm::make_early_inc_range(op.getCondition().getUses())) {
2621 case Parent::Then: {
2625 constantTrue = arith::ConstantOp::create(
2629 [&]() { use.set(constantTrue); });
2632 case Parent::Else: {
2636 constantFalse = arith::ConstantOp::create(
2640 [&]() { use.set(constantFalse); });
2688struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2689 using OpRewritePattern<IfOp>::OpRewritePattern;
2691 LogicalResult matchAndRewrite(IfOp op,
2692 PatternRewriter &rewriter)
const override {
2694 if (op.getNumResults() == 0)
2698 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2700 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2703 op.getOperation()->getIterator());
2706 for (
auto [trueResult, falseResult, opResult] :
2707 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2709 if (trueResult == falseResult) {
2710 if (!opResult.use_empty()) {
2711 opResult.replaceAllUsesWith(trueResult);
2717 BoolAttr trueYield, falseYield;
2722 bool trueVal = trueYield.
getValue();
2723 bool falseVal = falseYield.
getValue();
2724 if (!trueVal && falseVal) {
2725 if (!opResult.use_empty()) {
2726 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2727 Value notCond = arith::XOrIOp::create(
2728 rewriter, op.getLoc(), op.getCondition(),
2734 opResult.replaceAllUsesWith(notCond);
2738 if (trueVal && !falseVal) {
2739 if (!opResult.use_empty()) {
2740 opResult.replaceAllUsesWith(op.getCondition());
2770struct CombineIfs :
public OpRewritePattern<IfOp> {
2771 using OpRewritePattern<IfOp>::OpRewritePattern;
2773 LogicalResult matchAndRewrite(IfOp nextIf,
2774 PatternRewriter &rewriter)
const override {
2775 Block *parent = nextIf->getBlock();
2776 if (nextIf == &parent->
front())
2779 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2787 Block *nextThen =
nullptr;
2788 Block *nextElse =
nullptr;
2789 if (nextIf.getCondition() == prevIf.getCondition()) {
2790 nextThen = nextIf.thenBlock();
2791 if (!nextIf.getElseRegion().empty())
2792 nextElse = nextIf.elseBlock();
2794 if (arith::XOrIOp notv =
2795 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2796 if (notv.getLhs() == prevIf.getCondition() &&
2798 nextElse = nextIf.thenBlock();
2799 if (!nextIf.getElseRegion().empty())
2800 nextThen = nextIf.elseBlock();
2803 if (arith::XOrIOp notv =
2804 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2805 if (notv.getLhs() == nextIf.getCondition() &&
2807 nextElse = nextIf.thenBlock();
2808 if (!nextIf.getElseRegion().empty())
2809 nextThen = nextIf.elseBlock();
2813 if (!nextThen && !nextElse)
2816 SmallVector<Value> prevElseYielded;
2817 if (!prevIf.getElseRegion().empty())
2818 prevElseYielded = prevIf.elseYield().getOperands();
2821 for (
auto it : llvm::zip(prevIf.getResults(),
2822 prevIf.thenYield().getOperands(), prevElseYielded))
2823 for (OpOperand &use :
2824 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2828 use.
set(std::get<1>(it));
2833 use.
set(std::get<2>(it));
2838 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2839 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2841 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2842 prevIf.getCondition(),
false);
2843 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2846 combinedIf.getThenRegion(),
2847 combinedIf.getThenRegion().begin());
2850 YieldOp thenYield = combinedIf.thenYield();
2851 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2852 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2855 SmallVector<Value> mergedYields(thenYield.getOperands());
2856 llvm::append_range(mergedYields, thenYield2.getOperands());
2857 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2863 combinedIf.getElseRegion(),
2864 combinedIf.getElseRegion().begin());
2867 if (combinedIf.getElseRegion().empty()) {
2869 combinedIf.getElseRegion(),
2870 combinedIf.getElseRegion().
begin());
2872 YieldOp elseYield = combinedIf.elseYield();
2873 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2874 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2878 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2879 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2881 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2887 SmallVector<Value> prevValues;
2888 SmallVector<Value> nextValues;
2889 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2890 if (pair.index() < prevIf.getNumResults())
2891 prevValues.push_back(pair.value());
2893 nextValues.push_back(pair.value());
2902struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2903 using OpRewritePattern<IfOp>::OpRewritePattern;
2905 LogicalResult matchAndRewrite(IfOp ifOp,
2906 PatternRewriter &rewriter)
const override {
2908 if (ifOp.getNumResults())
2910 Block *elseBlock = ifOp.elseBlock();
2911 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2915 newIfOp.getThenRegion().begin());
2937struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2938 using OpRewritePattern<IfOp>::OpRewritePattern;
2940 LogicalResult matchAndRewrite(IfOp op,
2941 PatternRewriter &rewriter)
const override {
2942 auto nestedOps = op.thenBlock()->without_terminator();
2944 if (!llvm::hasSingleElement(nestedOps))
2948 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2951 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2955 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2958 SmallVector<Value> thenYield(op.thenYield().getOperands());
2959 SmallVector<Value> elseYield;
2961 llvm::append_range(elseYield, op.elseYield().getOperands());
2965 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2974 for (
const auto &tup : llvm::enumerate(thenYield)) {
2975 if (tup.value().getDefiningOp() == nestedIf) {
2976 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2977 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2978 elseYield[tup.index()]) {
2983 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2996 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2999 elseYieldsToUpgradeToSelect.push_back(tup.index());
3002 Location loc = op.getLoc();
3003 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
3004 nestedIf.getCondition());
3005 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
3008 SmallVector<Value> results;
3009 llvm::append_range(results, newIf.getResults());
3012 for (
auto idx : elseYieldsToUpgradeToSelect)
3014 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
3015 thenYield[idx], elseYield[idx]);
3017 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
3020 if (!elseYield.empty()) {
3023 YieldOp::create(rewriter, loc, elseYield);
3032void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3033 MLIRContext *context) {
3034 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
3035 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
3036 RemoveStaticCondition, RemoveUnusedResults,
3037 ReplaceIfYieldWithConditionOrValue>(context);
3040Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
3041YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
3042Block *IfOp::elseBlock() {
3043 Region &r = getElseRegion();
3048YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
3054void ParallelOp::build(
3059 result.addOperands(lowerBounds);
3060 result.addOperands(upperBounds);
3061 result.addOperands(steps);
3062 result.addOperands(initVals);
3064 ParallelOp::getOperandSegmentSizeAttr(),
3066 static_cast<int32_t>(upperBounds.size()),
3067 static_cast<int32_t>(steps.size()),
3068 static_cast<int32_t>(initVals.size())}));
3071 OpBuilder::InsertionGuard guard(builder);
3072 unsigned numIVs = steps.size();
3073 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
3074 SmallVector<Location, 8> argLocs(numIVs,
result.location);
3075 Region *bodyRegion =
result.addRegion();
3078 if (bodyBuilderFn) {
3080 bodyBuilderFn(builder,
result.location,
3085 if (initVals.empty())
3086 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
3089void ParallelOp::build(
3096 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
3099 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3103 wrapper = wrappedBuilderFn;
3109LogicalResult ParallelOp::verify() {
3114 if (stepValues.empty())
3116 "needs at least one tuple element for lowerBound, upperBound and step");
3119 for (Value stepValue : stepValues)
3122 return emitOpError(
"constant step operand must be positive");
3126 Block *body = getBody();
3128 return emitOpError() <<
"expects the same number of induction variables: "
3130 <<
" as bound and step values: " << stepValues.size();
3132 if (!arg.getType().isIndex())
3134 "expects arguments for the induction variable to be of index type");
3138 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
3143 auto resultsSize = getResults().size();
3144 auto reductionsSize = reduceOp.getReductions().size();
3145 auto initValsSize = getInitVals().size();
3146 if (resultsSize != reductionsSize)
3147 return emitOpError() <<
"expects number of results: " << resultsSize
3148 <<
" to be the same as number of reductions: "
3150 if (resultsSize != initValsSize)
3151 return emitOpError() <<
"expects number of results: " << resultsSize
3152 <<
" to be the same as number of initial values: "
3156 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3157 auto resultType = getOperation()->getResult(i).getType();
3158 auto reductionOperandType = reduceOp.getOperands()[i].getType();
3159 if (resultType != reductionOperandType)
3160 return reduceOp.emitOpError()
3161 <<
"expects type of " << i
3162 <<
"-th reduction operand: " << reductionOperandType
3163 <<
" to be the same as the " << i
3164 <<
"-th result type: " << resultType;
3169ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
3172 SmallVector<OpAsmParser::Argument, 4> ivs;
3177 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3184 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3192 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3200 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3211 Region *body =
result.addRegion();
3212 for (
auto &iv : ivs)
3219 ParallelOp::getOperandSegmentSizeAttr(),
3221 static_cast<int32_t>(upper.size()),
3222 static_cast<int32_t>(steps.size()),
3223 static_cast<int32_t>(initVals.size())}));
3232 ParallelOp::ensureTerminator(*body, builder,
result.location);
3236void ParallelOp::print(OpAsmPrinter &p) {
3237 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
3238 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
3239 if (!getInitVals().empty())
3240 p <<
" init (" << getInitVals() <<
")";
3245 (*this)->getAttrs(),
3246 ParallelOp::getOperandSegmentSizeAttr());
3249SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
3251std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3252 return SmallVector<Value>{getBody()->getArguments()};
3255std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3259std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3263std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3268 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3270 return ParallelOp();
3271 assert(ivArg.getOwner() &&
"unlinked block argument");
3272 auto *containingOp = ivArg.getOwner()->getParentOp();
3273 return dyn_cast<ParallelOp>(containingOp);
3278struct ParallelOpSingleOrZeroIterationDimsFolder
3282 LogicalResult matchAndRewrite(ParallelOp op,
3289 for (
auto [lb,
ub, step, iv] :
3290 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3291 op.getInductionVars())) {
3292 auto numIterations =
3294 if (numIterations.has_value()) {
3296 if (*numIterations == 0) {
3297 rewriter.
replaceOp(op, op.getInitVals());
3302 if (*numIterations == 1) {
3307 newLowerBounds.push_back(lb);
3308 newUpperBounds.push_back(ub);
3309 newSteps.push_back(step);
3312 if (newLowerBounds.size() == op.getLowerBound().size())
3315 if (newLowerBounds.empty()) {
3318 SmallVector<Value> results;
3319 results.reserve(op.getInitVals().size());
3320 for (
auto &bodyOp : op.getBody()->without_terminator())
3321 rewriter.
clone(bodyOp, mapping);
3322 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3323 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3324 Block &reduceBlock = reduceOp.getReductions()[i].front();
3325 auto initValIndex = results.size();
3326 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3330 rewriter.
clone(reduceBodyOp, mapping);
3333 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3334 results.push_back(
result);
3342 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3343 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3349 newOp.getRegion().begin(), mapping);
3350 rewriter.
replaceOp(op, newOp.getResults());
3355struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3356 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3358 LogicalResult matchAndRewrite(ParallelOp op,
3359 PatternRewriter &rewriter)
const override {
3360 Block &outerBody = *op.getBody();
3364 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3369 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3370 llvm::is_contained(innerOp.getUpperBound(), val) ||
3371 llvm::is_contained(innerOp.getStep(), val))
3375 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3378 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3380 Block &innerBody = *innerOp.getBody();
3381 assert(iterVals.size() ==
3389 builder.
clone(op, mapping);
3392 auto concatValues = [](
const auto &first,
const auto &second) {
3393 SmallVector<Value> ret;
3394 ret.reserve(first.size() + second.size());
3395 ret.assign(first.begin(), first.end());
3396 ret.append(second.begin(), second.end());
3400 auto newLowerBounds =
3401 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3402 auto newUpperBounds =
3403 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3404 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3415void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3416 MLIRContext *context) {
3418 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3427void ParallelOp::getSuccessorRegions(
3428 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
3432 regions.push_back(RegionSuccessor(&getRegion()));
3433 regions.push_back(RegionSuccessor(
3434 getOperation(), ResultRange{getResults().end(), getResults().end()}));
3441void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
3443void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3445 result.addOperands(operands);
3446 for (Value v : operands) {
3447 OpBuilder::InsertionGuard guard(builder);
3448 Region *bodyRegion =
result.addRegion();
3455LogicalResult ReduceOp::verifyRegions() {
3458 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3459 auto type = getOperands()[i].getType();
3462 return emitOpError() << i <<
"-th reduction has an empty body";
3464 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3465 return arg.getType() != type;
3467 return emitOpError() <<
"expected two block arguments with type " << type
3468 <<
" in the " << i <<
"-th reduction region";
3472 return emitOpError(
"reduction bodies must be terminated with an "
3473 "'scf.reduce.return' op");
3480ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3482 return MutableOperandRange(getOperation(), 0, 0);
3489LogicalResult ReduceReturnOp::verify() {
3492 Block *reductionBody = getOperation()->getBlock();
3494 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3496 if (expectedResultType != getResult().
getType())
3497 return emitOpError() <<
"must have type " << expectedResultType
3498 <<
" (the type of the reduction inputs)";
3506void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3507 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3508 ValueRange inits, BodyBuilderFn beforeBuilder,
3509 BodyBuilderFn afterBuilder) {
3513 OpBuilder::InsertionGuard guard(odsBuilder);
3516 SmallVector<Location, 4> beforeArgLocs;
3517 beforeArgLocs.reserve(inits.size());
3518 for (Value operand : inits) {
3519 beforeArgLocs.push_back(operand.getLoc());
3522 Region *beforeRegion = odsState.
addRegion();
3524 inits.getTypes(), beforeArgLocs);
3529 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3531 Region *afterRegion = odsState.
addRegion();
3533 resultTypes, afterArgLocs);
3539ConditionOp WhileOp::getConditionOp() {
3540 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3543YieldOp WhileOp::getYieldOp() {
3544 return cast<YieldOp>(getAfterBody()->getTerminator());
3547std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3548 return getYieldOp().getResultsMutable();
3552 return getBeforeBody()->getArguments();
3556 return getAfterBody()->getArguments();
3560 return getBeforeArguments();
3563OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3565 "WhileOp is expected to branch only to the first region");
3569void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3570 SmallVectorImpl<RegionSuccessor> ®ions) {
3573 regions.emplace_back(&getBefore(), getBefore().getArguments());
3577 assert(llvm::is_contained(
3578 {&getAfter(), &getBefore()},
3580 "there are only two regions in a WhileOp");
3584 regions.emplace_back(&getBefore(), getBefore().getArguments());
3588 regions.emplace_back(getOperation(), getResults());
3589 regions.emplace_back(&getAfter(), getAfter().getArguments());
3592SmallVector<Region *> WhileOp::getLoopRegions() {
3593 return {&getBefore(), &getAfter()};
3603ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3604 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3605 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3606 Region *before =
result.addRegion();
3607 Region *after =
result.addRegion();
3609 OptionalParseResult listResult =
3614 FunctionType functionType;
3619 result.addTypes(functionType.getResults());
3621 if (functionType.getNumInputs() != operands.size()) {
3623 <<
"expected as many input types as operands " <<
"(expected "
3624 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3634 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3635 regionArgs[i].type = functionType.getInput(i);
3637 return failure(parser.
parseRegion(*before, regionArgs) ||
3643void scf::WhileOp::print(OpAsmPrinter &p) {
3657template <
typename OpTy>
3660 if (left.size() != right.size())
3661 return op.emitOpError(
"expects the same number of ") << message;
3663 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3664 if (left[i] != right[i]) {
3667 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3668 <<
" and " << right[i];
3676LogicalResult scf::WhileOp::verify() {
3679 "expects the 'before' region to terminate with 'scf.condition'");
3680 if (!beforeTerminator)
3685 "expects the 'after' region to terminate with 'scf.yield'");
3686 return success(afterTerminator !=
nullptr);
3709struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3710 using OpRewritePattern<WhileOp>::OpRewritePattern;
3712 LogicalResult matchAndRewrite(WhileOp op,
3713 PatternRewriter &rewriter)
const override {
3714 auto term = op.getConditionOp();
3718 Value constantTrue =
nullptr;
3720 bool replaced =
false;
3721 for (
auto yieldedAndBlockArgs :
3722 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3723 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3724 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3726 constantTrue = arith::ConstantOp::create(
3727 rewriter, op.getLoc(), term.getCondition().getType(),
3788struct RemoveLoopInvariantArgsFromBeforeBlock
3789 :
public OpRewritePattern<WhileOp> {
3790 using OpRewritePattern<WhileOp>::OpRewritePattern;
3792 LogicalResult matchAndRewrite(WhileOp op,
3793 PatternRewriter &rewriter)
const override {
3794 Block &afterBlock = *op.getAfterBody();
3796 ConditionOp condOp = op.getConditionOp();
3797 OperandRange condOpArgs = condOp.getArgs();
3801 bool canSimplify =
false;
3802 for (
const auto &it :
3803 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3804 auto index =
static_cast<unsigned>(it.index());
3805 auto [initVal, yieldOpArg] = it.value();
3808 if (yieldOpArg == initVal) {
3817 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3818 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3819 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3820 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3830 SmallVector<Value> newInitArgs, newYieldOpArgs;
3832 SmallVector<Location> newBeforeBlockArgLocs;
3833 for (
const auto &it :
3834 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3835 auto index =
static_cast<unsigned>(it.index());
3836 auto [initVal, yieldOpArg] = it.value();
3840 if (yieldOpArg == initVal) {
3841 beforeBlockInitValMap.insert({index, initVal});
3849 auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3850 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3851 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3852 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3853 beforeBlockInitValMap.insert({index, initVal});
3858 newInitArgs.emplace_back(initVal);
3859 newYieldOpArgs.emplace_back(yieldOpArg);
3860 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3864 OpBuilder::InsertionGuard g(rewriter);
3869 auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3873 &newWhile.getBefore(), {},
3874 ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3876 Block &beforeBlock = *op.getBeforeBody();
3883 for (
unsigned i = 0, j = 0, n = beforeBlock.
getNumArguments(); i < n; i++) {
3886 if (beforeBlockInitValMap.count(i) != 0)
3887 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3889 newBeforeBlockArgs[i] = newBeforeBlock.
getArgument(j++);
3892 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3894 newWhile.getAfter().begin());
3896 rewriter.
replaceOp(op, newWhile.getResults());
3941struct RemoveLoopInvariantValueYielded :
public OpRewritePattern<WhileOp> {
3942 using OpRewritePattern<WhileOp>::OpRewritePattern;
3944 LogicalResult matchAndRewrite(WhileOp op,
3945 PatternRewriter &rewriter)
const override {
3946 Block &beforeBlock = *op.getBeforeBody();
3947 ConditionOp condOp = op.getConditionOp();
3948 OperandRange condOpArgs = condOp.getArgs();
3950 bool canSimplify =
false;
3951 for (Value condOpArg : condOpArgs) {
3966 SmallVector<Value> newCondOpArgs;
3967 SmallVector<Type> newAfterBlockType;
3969 SmallVector<Location> newAfterBlockArgLocs;
3970 for (
const auto &it : llvm::enumerate(condOpArgs)) {
3971 auto index =
static_cast<unsigned>(it.index());
3972 Value condOpArg = it.value();
3977 condOpInitValMap.insert({index, condOpArg});
3979 newCondOpArgs.emplace_back(condOpArg);
3980 newAfterBlockType.emplace_back(condOpArg.
getType());
3981 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3986 OpBuilder::InsertionGuard g(rewriter);
3992 auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3995 Block &newAfterBlock =
3997 newAfterBlockType, newAfterBlockArgLocs);
3999 Block &afterBlock = *op.getAfterBody();
4006 for (
unsigned i = 0, j = 0, n = afterBlock.
getNumArguments(); i < n; i++) {
4007 Value afterBlockArg,
result;
4010 if (condOpInitValMap.count(i) != 0) {
4011 afterBlockArg = condOpInitValMap[i];
4015 result = newWhile.getResult(j);
4018 newAfterBlockArgs[i] = afterBlockArg;
4019 newWhileResults[i] =
result;
4022 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4024 newWhile.getBefore().begin());
4026 rewriter.
replaceOp(op, newWhileResults);
4057struct WhileUnusedResult :
public OpRewritePattern<WhileOp> {
4058 using OpRewritePattern<WhileOp>::OpRewritePattern;
4060 LogicalResult matchAndRewrite(WhileOp op,
4061 PatternRewriter &rewriter)
const override {
4062 auto term = op.getConditionOp();
4063 auto afterArgs = op.getAfterArguments();
4064 auto termArgs = term.getArgs();
4067 SmallVector<unsigned> newResultsIndices;
4068 SmallVector<Type> newResultTypes;
4069 SmallVector<Value> newTermArgs;
4070 SmallVector<Location> newArgLocs;
4071 bool needUpdate =
false;
4072 for (
const auto &it :
4073 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
4074 auto i =
static_cast<unsigned>(it.index());
4075 Value
result = std::get<0>(it.value());
4076 Value afterArg = std::get<1>(it.value());
4077 Value termArg = std::get<2>(it.value());
4081 newResultsIndices.emplace_back(i);
4082 newTermArgs.emplace_back(termArg);
4083 newResultTypes.emplace_back(
result.getType());
4084 newArgLocs.emplace_back(
result.getLoc());
4092 OpBuilder::InsertionGuard g(rewriter);
4099 WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
4102 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
4106 SmallVector<Value> newResults(op.getNumResults());
4107 SmallVector<Value> newAfterBlockArgs(op.getNumResults());
4108 for (
const auto &it : llvm::enumerate(newResultsIndices)) {
4109 newResults[it.value()] = newWhile.getResult(it.index());
4110 newAfterBlockArgs[it.value()] = newAfterBlock.
getArgument(it.index());
4114 newWhile.getBefore().begin());
4116 Block &afterBlock = *op.getAfterBody();
4117 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
4146struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
4147 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4149 LogicalResult matchAndRewrite(scf::WhileOp op,
4150 PatternRewriter &rewriter)
const override {
4151 using namespace scf;
4152 auto cond = op.getConditionOp();
4153 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4157 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4158 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
4159 if (std::get<0>(tup) != cmp.getOperand(opIdx))
4162 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4163 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4167 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4170 if (cmp2.getPredicate() == cmp.getPredicate())
4171 samePredicate =
true;
4172 else if (cmp2.getPredicate() ==
4173 arith::invertPredicate(cmp.getPredicate()))
4174 samePredicate =
false;
4189struct WhileRemoveUnusedArgs :
public OpRewritePattern<WhileOp> {
4190 using OpRewritePattern<WhileOp>::OpRewritePattern;
4192 LogicalResult matchAndRewrite(WhileOp op,
4193 PatternRewriter &rewriter)
const override {
4195 if (!llvm::any_of(op.getBeforeArguments(),
4196 [](Value arg) { return arg.use_empty(); }))
4199 YieldOp yield = op.getYieldOp();
4202 SmallVector<Value> newYields;
4203 SmallVector<Value> newInits;
4204 llvm::BitVector argsToErase;
4206 size_t argsCount = op.getBeforeArguments().size();
4207 newYields.reserve(argsCount);
4208 newInits.reserve(argsCount);
4209 argsToErase.reserve(argsCount);
4210 for (
auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4211 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4212 if (beforeArg.use_empty()) {
4213 argsToErase.push_back(
true);
4215 argsToErase.push_back(
false);
4216 newYields.emplace_back(yieldValue);
4217 newInits.emplace_back(initValue);
4221 Block &beforeBlock = *op.getBeforeBody();
4222 Block &afterBlock = *op.getAfterBody();
4226 Location loc = op.getLoc();
4228 WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4230 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4231 Block &newAfterBlock = *newWhileOp.getAfterBody();
4233 OpBuilder::InsertionGuard g(rewriter);
4237 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4242 rewriter.
replaceOp(op, newWhileOp.getResults());
4248struct WhileRemoveDuplicatedResults :
public OpRewritePattern<WhileOp> {
4251 LogicalResult matchAndRewrite(WhileOp op,
4252 PatternRewriter &rewriter)
const override {
4253 ConditionOp condOp = op.getConditionOp();
4256 llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4258 if (argsSet.size() == condOpArgs.size())
4261 llvm::SmallDenseMap<Value, unsigned> argsMap;
4262 SmallVector<Value> newArgs;
4263 argsMap.reserve(condOpArgs.size());
4264 newArgs.reserve(condOpArgs.size());
4265 for (Value arg : condOpArgs) {
4266 if (!argsMap.count(arg)) {
4267 auto pos =
static_cast<unsigned>(argsMap.size());
4268 argsMap.insert({arg, pos});
4269 newArgs.emplace_back(arg);
4275 Location loc = op.getLoc();
4277 scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4280 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4281 Block &newAfterBlock = *newWhileOp.getAfterBody();
4283 SmallVector<Value> afterArgsMapping;
4284 SmallVector<Value> resultsMapping;
4285 for (
auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4286 auto it = argsMap.find(arg);
4287 assert(it != argsMap.end());
4288 auto pos = it->second;
4289 afterArgsMapping.emplace_back(newAfterBlock.
getArgument(pos));
4290 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4293 OpBuilder::InsertionGuard g(rewriter);
4298 Block &beforeBlock = *op.getBeforeBody();
4299 Block &afterBlock = *op.getAfterBody();
4301 rewriter.
mergeBlocks(&beforeBlock, &newBeforeBlock,
4303 rewriter.
mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4311static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
4313 if (args1.size() != args2.size())
4314 return std::nullopt;
4316 SmallVector<unsigned> ret(args1.size());
4317 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
4318 auto it = llvm::find(args2, arg1);
4319 if (it == args2.end())
4320 return std::nullopt;
4322 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
4329 llvm::SmallDenseSet<Value> set;
4330 for (Value arg : args) {
4331 if (!set.insert(arg).second)
4341struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
4344 LogicalResult matchAndRewrite(WhileOp loop,
4345 PatternRewriter &rewriter)
const override {
4346 auto oldBefore = loop.getBeforeBody();
4347 ConditionOp oldTerm = loop.getConditionOp();
4348 ValueRange beforeArgs = oldBefore->getArguments();
4350 if (beforeArgs == termArgs)
4353 if (hasDuplicates(termArgs))
4356 auto mapping = getArgsMapping(beforeArgs, termArgs);
4361 OpBuilder::InsertionGuard g(rewriter);
4367 auto oldAfter = loop.getAfterBody();
4369 SmallVector<Type> newResultTypes(beforeArgs.size());
4370 for (
auto &&[i, j] : llvm::enumerate(*mapping))
4371 newResultTypes[j] = loop.getResult(i).getType();
4373 auto newLoop = WhileOp::create(
4374 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4376 auto newBefore = newLoop.getBeforeBody();
4377 auto newAfter = newLoop.getAfterBody();
4379 SmallVector<Value> newResults(beforeArgs.size());
4380 SmallVector<Value> newAfterArgs(beforeArgs.size());
4381 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
4382 newResults[i] = newLoop.getResult(j);
4383 newAfterArgs[i] = newAfter->getArgument(j);
4387 newBefore->getArguments());
4397void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4398 MLIRContext *context) {
4399 results.
add<RemoveLoopInvariantArgsFromBeforeBlock,
4400 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4401 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4402 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4416 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
4419 caseValues.push_back(value);
4428 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
4430 p <<
"case " << value <<
' ';
4435LogicalResult scf::IndexSwitchOp::verify() {
4436 if (getCases().size() != getCaseRegions().size()) {
4438 << getCaseRegions().size() <<
" case regions but "
4439 << getCases().size() <<
" case values";
4443 for (int64_t value : getCases())
4444 if (!valueSet.insert(value).second)
4445 return emitOpError(
"has duplicate case value: ") << value;
4446 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
4447 auto yield = dyn_cast<YieldOp>(region.
front().
back());
4449 return emitOpError(
"expected region to end with scf.yield, but got ")
4452 if (yield.getNumOperands() != getNumResults()) {
4453 return (
emitOpError(
"expected each region to return ")
4454 << getNumResults() <<
" values, but " << name <<
" returns "
4455 << yield.getNumOperands())
4456 .attachNote(yield.getLoc())
4457 <<
"see yield operation here";
4459 for (
auto [idx,
result, operand] :
4460 llvm::enumerate(getResultTypes(), yield.getOperands())) {
4462 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
4463 if (
result == operand.getType())
4466 << idx <<
" of each region to be " <<
result)
4467 .attachNote(yield.getLoc())
4468 << name <<
" returns " << operand.getType() <<
" here";
4475 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4482unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
4484Block &scf::IndexSwitchOp::getDefaultBlock() {
4485 return getDefaultRegion().front();
4488Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
4489 assert(idx < getNumCases() &&
"case index out-of-bounds");
4490 return getCaseRegions()[idx].front();
4493void IndexSwitchOp::getSuccessorRegions(
4494 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4497 successors.emplace_back(getOperation(), getResults());
4501 llvm::append_range(successors, getRegions());
4504void IndexSwitchOp::getEntrySuccessorRegions(
4505 ArrayRef<Attribute> operands,
4506 SmallVectorImpl<RegionSuccessor> &successors) {
4507 FoldAdaptor adaptor(operands, *
this);
4510 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4512 llvm::append_range(successors, getRegions());
4518 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4519 if (caseValue == arg.getInt()) {
4520 successors.emplace_back(&caseRegion);
4524 successors.emplace_back(&getDefaultRegion());
4527void IndexSwitchOp::getRegionInvocationBounds(
4528 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4529 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4530 if (!operandValue) {
4532 bounds.append(getNumRegions(), InvocationBounds(0, 1));
4536 unsigned liveIndex = getNumRegions() - 1;
4537 const auto *it = llvm::find(getCases(), operandValue.getInt());
4538 if (it != getCases().end())
4539 liveIndex = std::distance(getCases().begin(), it);
4540 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
4541 bounds.emplace_back(0, i == liveIndex);
4552 if (!maybeCst.has_value())
4555 int64_t caseIdx, e = op.getNumCases();
4556 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4557 if (cst == op.getCases()[caseIdx])
4561 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4562 : op.getDefaultRegion();
4577void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4578 MLIRContext *context) {
4579 results.
add<FoldConstantCase>(context);
4586#define GET_OP_CLASSES
4587#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.