30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
47struct SCFInlinerInterface :
public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
52 IRMapping &valueMapping)
const final {
57 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
62 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
67 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
78void SCFDialect::initialize() {
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
83 addInterfaces<SCFInlinerInterface>();
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
95 scf::YieldOp::create(builder, loc);
100template <
typename TerminatorTy>
102 StringRef errorMessage) {
103 Operation *terminatorOperation =
nullptr;
105 terminatorOperation = ®ion.
front().
back();
106 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
118 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
146ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
174LogicalResult ExecuteRegionOp::verify() {
175 if (getRegion().empty())
176 return emitOpError(
"region needs to have at least one block");
177 if (getRegion().front().getNumArguments() > 0)
178 return emitOpError(
"region cannot have any arguments");
224 if (op.getNoInline())
226 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
229 Block *prevBlock = op->getBlock();
233 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
235 for (
Block &blk : op.getRegion()) {
236 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
238 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
239 yieldOp.getResults());
247 for (
auto res : op.getResults())
248 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
259 results, ExecuteRegionOp::getOperationName());
262 results, ExecuteRegionOp::getOperationName(),
264 return failure(cast<ExecuteRegionOp>(op).getNoInline());
268void ExecuteRegionOp::getSuccessorRegions(
293 "condition op can only exit the loop or branch to the after"
296 return getArgsMutable();
299void ConditionOp::getSuccessorRegions(
301 FoldAdaptor adaptor(operands, *
this);
303 WhileOp whileOp = getParentOp();
307 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
308 if (!boolAttr || boolAttr.getValue())
309 regions.emplace_back(&whileOp.getAfter());
310 if (!boolAttr || !boolAttr.getValue())
320 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
324 result.addAttribute(getUnsignedCmpAttrName(
result.name),
327 result.addOperands(initArgs);
328 for (
Value v : initArgs)
329 result.addTypes(v.getType());
334 for (
Value v : initArgs)
340 if (initArgs.empty() && !bodyBuilder) {
341 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
342 }
else if (bodyBuilder) {
350LogicalResult ForOp::verify() {
352 if (getInitArgs().size() != getNumResults())
354 "mismatch in number of loop-carried values and defined values");
359LogicalResult ForOp::verifyRegions() {
364 "expected induction variable to be same type as bounds and step");
366 if (getNumRegionIterArgs() != getNumResults())
368 "mismatch in number of basic block args and defined values");
370 auto initArgs = getInitArgs();
371 auto iterArgs = getRegionIterArgs();
372 auto opResults = getResults();
374 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
376 return emitOpError() <<
"types mismatch between " << i
377 <<
"th iter operand and defined value";
379 return emitOpError() <<
"types mismatch between " << i
380 <<
"th iter region arg and defined value";
387std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
391std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
395std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
399std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
403bool ForOp::isValidInductionVarType(
Type type) {
408 if (bounds.size() != 1)
410 if (
auto val = dyn_cast<Value>(bounds[0])) {
418 if (bounds.size() != 1)
420 if (
auto val = dyn_cast<Value>(bounds[0])) {
428 if (steps.size() != 1)
430 if (
auto val = dyn_cast<Value>(steps[0])) {
437std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
441LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
442 std::optional<APInt> tripCount = getStaticTripCount();
443 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
446 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
449 if (*tripCount == 0) {
456 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
463 llvm::append_range(bbArgReplacements, getInitArgs());
467 getOperation()->getIterator(), bbArgReplacements);
483 StringRef prefix =
"") {
484 assert(blocksArgs.size() == initializers.size() &&
485 "expected same length of arguments and initializers");
486 if (initializers.empty())
490 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
491 p << std::get<0>(it) <<
" = " << std::get<1>(it);
497 if (getUnsignedCmp())
500 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
504 if (!getInitArgs().empty())
505 p <<
" -> (" << getInitArgs().getTypes() <<
')';
508 p <<
" : " << t <<
' ';
511 !getInitArgs().empty());
513 getUnsignedCmpAttrName().strref());
524 result.addAttribute(getUnsignedCmpAttrName(
result.name),
538 regionArgs.push_back(inductionVariable);
548 if (regionArgs.size() !=
result.types.size() + 1)
551 "mismatch in number of loop-carried values and defined values");
560 regionArgs.front().type = type;
561 for (
auto [iterArg, type] :
562 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
569 ForOp::ensureTerminator(*body, builder,
result.location);
578 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
579 operands,
result.types)) {
580 Type type = std::get<2>(argOperandType);
581 std::get<0>(argOperandType).type = type;
598 return getBody()->getArguments().drop_front(getNumInductionVars());
602 return getInitArgsMutable();
605FailureOr<LoopLikeOpInterface>
606ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
608 bool replaceInitOperandUsesInLoop,
613 auto inits = llvm::to_vector(getInitArgs());
614 inits.append(newInitOperands.begin(), newInitOperands.end());
615 scf::ForOp newLoop = scf::ForOp::create(
621 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
623 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
628 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
629 assert(newInitOperands.size() == newYieldedValues.size() &&
630 "expected as many new yield values as new iter operands");
632 yieldOp.getResultsMutable().append(newYieldedValues);
638 newLoop.getBody()->getArguments().take_front(
639 getBody()->getNumArguments()));
641 if (replaceInitOperandUsesInLoop) {
644 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
655 newLoop->getResults().take_front(getNumResults()));
656 return cast<LoopLikeOpInterface>(newLoop.getOperation());
660 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
663 assert(ivArg.getOwner() &&
"unlinked block argument");
664 auto *containingOp = ivArg.getOwner()->getParentOp();
665 return dyn_cast_or_null<ForOp>(containingOp);
669 return getInitArgs();
674 if (std::optional<APInt> tripCount = getStaticTripCount()) {
677 if (*tripCount == 0) {
686 }
else if (*tripCount == 1) {
710LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
711 for (
auto [lb, ub, step] :
712 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
715 if (!tripCount.has_value() || *tripCount != 1)
724 return getBody()->getArguments().drop_front(getRank());
727MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
728 return getOutputsMutable();
734 scf::InParallelOp terminator = forallOp.getTerminator();
739 bbArgReplacements.append(forallOp.getOutputs().begin(),
740 forallOp.getOutputs().end());
744 forallOp->getIterator(), bbArgReplacements);
749 results.reserve(forallOp.getResults().size());
750 for (
auto &yieldingOp : terminator.getYieldingOps()) {
751 auto parallelInsertSliceOp =
752 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
753 if (!parallelInsertSliceOp)
756 Value dst = parallelInsertSliceOp.getDest();
757 Value src = parallelInsertSliceOp.getSource();
758 if (llvm::isa<TensorType>(src.
getType())) {
759 results.push_back(tensor::InsertSliceOp::create(
760 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
761 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
762 parallelInsertSliceOp.getStrides(),
763 parallelInsertSliceOp.getStaticOffsets(),
764 parallelInsertSliceOp.getStaticSizes(),
765 parallelInsertSliceOp.getStaticStrides()));
767 llvm_unreachable(
"unsupported terminator");
782 assert(lbs.size() == ubs.size() &&
783 "expected the same number of lower and upper bounds");
784 assert(lbs.size() == steps.size() &&
785 "expected the same number of lower bounds and steps");
790 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
792 assert(results.size() == iterArgs.size() &&
793 "loop nest body must return as many values as loop has iteration "
795 return LoopNest{{}, std::move(results)};
803 loops.reserve(lbs.size());
804 ivs.reserve(lbs.size());
807 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
808 auto loop = scf::ForOp::create(
809 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
815 currentIterArgs = args;
816 currentLoc = nestedLoc;
822 loops.push_back(loop);
826 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
828 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
835 ? bodyBuilder(builder, currentLoc, ivs,
836 loops.back().getRegionIterArgs())
838 assert(results.size() == iterArgs.size() &&
839 "loop nest body must return as many values as loop has iteration "
842 scf::YieldOp::create(builder, loc, results);
846 llvm::append_range(nestResults, loops.front().getResults());
847 return LoopNest{std::move(loops), std::move(nestResults)};
860 bodyBuilder(nestedBuilder, nestedLoc, ivs);
869 assert(operand.
getOwner() == forOp);
874 "expected an iter OpOperand");
876 "Expected a different type");
878 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
883 newIterOperands.push_back(opOperand.get());
887 scf::ForOp newForOp = scf::ForOp::create(
888 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
889 forOp.getStep(), newIterOperands,
nullptr,
890 forOp.getUnsignedCmp());
891 newForOp->setAttrs(forOp->getAttrs());
892 Block &newBlock = newForOp.getRegion().
front();
900 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
902 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
903 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
907 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
910 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
913 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
914 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
915 clonedYieldOp.getOperand(yieldIdx));
917 newYieldOperands[yieldIdx] = castOut;
918 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
919 rewriter.
eraseOp(clonedYieldOp);
924 newResults[yieldIdx] =
925 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
960 LogicalResult matchAndRewrite(ForOp op,
962 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
963 OpOperand &iterOpOperand = std::get<0>(it);
966 incomingCast.getSource().getType() == incomingCast.getType())
971 incomingCast.getDest().getType(),
972 incomingCast.getSource().getType()))
974 if (!std::get<1>(it).hasOneUse())
980 rewriter, op, iterOpOperand, incomingCast.getSource(),
982 return tensor::CastOp::create(b, loc, type, source);
991void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
992 MLIRContext *context) {
993 results.
add<ForOpTensorCastFolder>(context);
995 results, ForOp::getOperationName());
997 results, ForOp::getOperationName(),
998 [](OpBuilder &builder, Location loc, Value value) {
1002 auto blockArg = cast<BlockArgument>(value);
1003 assert(blockArg.getArgNumber() == 0 &&
"expected induction variable");
1004 auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
1005 return forOp.getLowerBound();
1009std::optional<APInt> ForOp::getConstantStep() {
1012 return step.getValue();
1016std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1017 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1023 if (
auto constantStep = getConstantStep())
1024 if (*constantStep == 1)
1032std::optional<APInt> ForOp::getStaticTripCount() {
1041LogicalResult ForallOp::verify() {
1042 unsigned numLoops = getRank();
1044 if (getNumResults() != getOutputs().size())
1046 << getNumResults() <<
" results, but has only "
1047 << getOutputs().size() <<
" outputs";
1050 auto *body = getBody();
1052 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1053 for (int64_t i = 0; i < numLoops; ++i)
1056 << i <<
"-th block argument to be an index";
1057 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1060 << i <<
"-th output and corresponding block argument";
1061 if (getMapping().has_value() && !getMapping()->empty()) {
1062 if (getDeviceMappingAttrs().size() != numLoops)
1063 return emitOpError() <<
"mapping attribute size must match op rank";
1064 if (
failed(getDeviceMaskingAttr()))
1066 <<
" supports at most one device masking attribute";
1070 Operation *op = getOperation();
1072 getStaticLowerBound(),
1073 getDynamicLowerBound())))
1076 getStaticUpperBound(),
1077 getDynamicUpperBound())))
1080 getStaticStep(), getDynamicStep())))
1086void ForallOp::print(OpAsmPrinter &p) {
1087 Operation *op = getOperation();
1088 p <<
" (" << getInductionVars();
1089 if (isNormalized()) {
1110 if (!getRegionOutArgs().empty())
1111 p <<
"-> (" << getResultTypes() <<
") ";
1112 p.printRegion(getRegion(),
1114 getNumResults() > 0);
1115 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1116 getStaticLowerBoundAttrName(),
1117 getStaticUpperBoundAttrName(),
1118 getStaticStepAttrName()});
1121ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1123 auto indexType =
b.getIndexType();
1128 SmallVector<OpAsmParser::Argument, 4> ivs;
1133 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1143 unsigned numLoops = ivs.size();
1144 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1145 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1174 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1175 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1178 if (outOperands.size() !=
result.types.size())
1180 "mismatch between out operands and types");
1189 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1190 std::unique_ptr<Region> region = std::make_unique<Region>();
1191 for (
auto &iv : ivs) {
1192 iv.type =
b.getIndexType();
1193 regionArgs.push_back(iv);
1195 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1196 auto &out = it.value();
1197 out.type =
result.types[it.index()];
1198 regionArgs.push_back(out);
1204 ForallOp::ensureTerminator(*region,
b,
result.location);
1205 result.addRegion(std::move(region));
1211 result.addAttribute(
"staticLowerBound", staticLbs);
1212 result.addAttribute(
"staticUpperBound", staticUbs);
1213 result.addAttribute(
"staticStep", staticSteps);
1214 result.addAttribute(
"operandSegmentSizes",
1216 {static_cast<int32_t>(dynamicLbs.size()),
1217 static_cast<int32_t>(dynamicUbs.size()),
1218 static_cast<int32_t>(dynamicSteps.size()),
1219 static_cast<int32_t>(outOperands.size())}));
1224void ForallOp::build(
1225 mlir::OpBuilder &
b, mlir::OperationState &
result,
1226 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1227 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1228 std::optional<ArrayAttr> mapping,
1230 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1231 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1236 result.addOperands(dynamicLbs);
1237 result.addOperands(dynamicUbs);
1238 result.addOperands(dynamicSteps);
1239 result.addOperands(outputs);
1242 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1243 b.getDenseI64ArrayAttr(staticLbs));
1244 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1245 b.getDenseI64ArrayAttr(staticUbs));
1246 result.addAttribute(getStaticStepAttrName(
result.name),
1247 b.getDenseI64ArrayAttr(staticSteps));
1249 "operandSegmentSizes",
1250 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1251 static_cast<int32_t>(dynamicUbs.size()),
1252 static_cast<int32_t>(dynamicSteps.size()),
1253 static_cast<int32_t>(outputs.size())}));
1254 if (mapping.has_value()) {
1255 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
1259 Region *bodyRegion =
result.addRegion();
1260 OpBuilder::InsertionGuard g(
b);
1261 b.createBlock(bodyRegion);
1266 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1267 SmallVector<Location>(staticLbs.size(),
result.location));
1270 SmallVector<Location>(outputs.size(),
result.location));
1272 b.setInsertionPointToStart(&bodyBlock);
1273 if (!bodyBuilderFn) {
1274 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
1281void ForallOp::build(
1282 mlir::OpBuilder &
b, mlir::OperationState &
result,
1283 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1284 std::optional<ArrayAttr> mapping,
1286 unsigned numLoops = ubs.size();
1287 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
1288 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1289 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1293bool ForallOp::isNormalized() {
1294 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1295 return llvm::all_of(results, [&](OpFoldResult ofr) {
1297 return intValue.has_value() && intValue == val;
1300 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1303InParallelOp ForallOp::getTerminator() {
1304 return cast<InParallelOp>(getBody()->getTerminator());
1307SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1308 SmallVector<Operation *> storeOps;
1309 for (Operation *user : bbArg.
getUsers()) {
1310 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1311 storeOps.push_back(parallelOp);
1317SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1318 SmallVector<DeviceMappingAttrInterface> res;
1321 for (
auto attr : getMapping()->getValue()) {
1322 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1329FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1330 DeviceMaskingAttrInterface res;
1333 for (
auto attr : getMapping()->getValue()) {
1334 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1343bool ForallOp::usesLinearMapping() {
1344 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1347 return ifaces.front().isLinearMapping();
1350std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1351 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1355std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1357 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
1361std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1363 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1367std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1373 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1376 assert(tidxArg.getOwner() &&
"unlinked block argument");
1377 auto *containingOp = tidxArg.getOwner()->getParentOp();
1378 return dyn_cast<ForallOp>(containingOp);
1386 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1388 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1392 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1395 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1400class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1402 using OpRewritePattern<ForallOp>::OpRewritePattern;
1404 LogicalResult matchAndRewrite(ForallOp op,
1405 PatternRewriter &rewriter)
const override {
1406 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1407 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1408 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1415 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1416 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1419 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1420 op.setStaticLowerBound(staticLowerBound);
1424 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1425 op.setStaticUpperBound(staticUpperBound);
1428 op.getDynamicStepMutable().assign(dynamicStep);
1429 op.setStaticStep(staticStep);
1431 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1433 {static_cast<int32_t>(dynamicLowerBound.size()),
1434 static_cast<int32_t>(dynamicUpperBound.size()),
1435 static_cast<int32_t>(dynamicStep.size()),
1436 static_cast<int32_t>(op.getNumResults())}));
1515struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1516 using OpRewritePattern<ForallOp>::OpRewritePattern;
1518 LogicalResult matchAndRewrite(ForallOp forallOp,
1519 PatternRewriter &rewriter)
const final {
1530 SmallVector<Value> resultsToDelete;
1531 SmallVector<Value> outsToDelete;
1532 SmallVector<BlockArgument> blockArgsToDelete;
1533 SmallVector<Value> newOuts;
1534 BitVector resultIndicesToDelete(forallOp.getNumResults(),
false);
1535 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
1537 for (OpResult
result : forallOp.getResults()) {
1538 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1539 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1540 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1541 resultsToDelete.push_back(
result);
1542 outsToDelete.push_back(opOperand->
get());
1543 blockArgsToDelete.push_back(blockArg);
1544 resultIndicesToDelete[
result.getResultNumber()] =
true;
1547 newOuts.push_back(opOperand->
get());
1553 if (resultsToDelete.empty())
1558 for (
auto blockArg : blockArgsToDelete) {
1559 SmallVector<Operation *> combiningOps =
1560 forallOp.getCombiningOps(blockArg);
1561 for (Operation *combiningOp : combiningOps)
1562 rewriter.
eraseOp(combiningOp);
1564 for (
auto [blockArg,
result, out] :
1565 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1571 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
1576 auto newForallOp = cast<scf::ForallOp>(
1578 newForallOp.getOutputsMutable().assign(newOuts);
1584struct ForallOpSingleOrZeroIterationDimsFolder
1585 :
public OpRewritePattern<ForallOp> {
1586 using OpRewritePattern<ForallOp>::OpRewritePattern;
1588 LogicalResult matchAndRewrite(ForallOp op,
1589 PatternRewriter &rewriter)
const override {
1591 if (op.getMapping().has_value() && !op.getMapping()->empty())
1593 Location loc = op.getLoc();
1596 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1599 for (
auto [lb, ub, step, iv] :
1600 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1601 op.getMixedStep(), op.getInductionVars())) {
1602 auto numIterations =
1604 if (numIterations.has_value()) {
1606 if (*numIterations == 0) {
1607 rewriter.
replaceOp(op, op.getOutputs());
1612 if (*numIterations == 1) {
1617 newMixedLowerBounds.push_back(lb);
1618 newMixedUpperBounds.push_back(ub);
1619 newMixedSteps.push_back(step);
1623 if (newMixedLowerBounds.empty()) {
1629 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1631 op,
"no dimensions have 0 or 1 iterations");
1636 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1637 newMixedUpperBounds, newMixedSteps,
1638 op.getOutputs(), std::nullopt,
nullptr);
1639 newOp.getBodyRegion().getBlocks().clear();
1643 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1644 newOp.getStaticLowerBoundAttrName(),
1645 newOp.getStaticUpperBoundAttrName(),
1646 newOp.getStaticStepAttrName()};
1647 for (
const auto &namedAttr : op->getAttrs()) {
1648 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1651 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1655 newOp.getRegion().begin(), mapping);
1656 rewriter.
replaceOp(op, newOp.getResults());
1662struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1663 using OpRewritePattern<ForallOp>::OpRewritePattern;
1665 LogicalResult matchAndRewrite(ForallOp op,
1666 PatternRewriter &rewriter)
const override {
1667 Location loc = op.getLoc();
1669 for (
auto [lb, ub, step, iv] :
1670 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1671 op.getMixedStep(), op.getInductionVars())) {
1674 auto numIterations =
1676 if (!numIterations.has_value() || numIterations.value() != 1) {
1687struct FoldTensorCastOfOutputIntoForallOp
1688 :
public OpRewritePattern<scf::ForallOp> {
1689 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1696 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1697 PatternRewriter &rewriter)
const final {
1698 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1699 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1700 for (
auto en : llvm::enumerate(newOutputTensors)) {
1701 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1708 castOp.getSource().getType())) {
1712 tensorCastProducers[en.index()] =
1713 TypeCast{castOp.getSource().getType(), castOp.getType()};
1714 newOutputTensors[en.index()] = castOp.getSource();
1717 if (tensorCastProducers.empty())
1721 Location loc = forallOp.getLoc();
1722 auto newForallOp = ForallOp::create(
1723 rewriter, loc, forallOp.getMixedLowerBound(),
1724 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1725 newOutputTensors, forallOp.getMapping(),
1726 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
1727 auto castBlockArgs =
1728 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1729 for (auto [index, cast] : tensorCastProducers) {
1730 Value &oldTypeBBArg = castBlockArgs[index];
1731 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1732 cast.dstType, oldTypeBBArg);
1736 SmallVector<Value> ivsBlockArgs =
1737 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1738 ivsBlockArgs.append(castBlockArgs);
1740 bbArgs.front().getParentBlock(), ivsBlockArgs);
1746 auto terminator = newForallOp.getTerminator();
1747 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1748 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1749 if (
auto parallelCombingingOp =
1750 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1751 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1757 SmallVector<Value> castResults = newForallOp.getResults();
1758 for (
auto &item : tensorCastProducers) {
1759 Value &oldTypeResult = castResults[item.first];
1760 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1763 rewriter.
replaceOp(forallOp, castResults);
1770void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1771 MLIRContext *context) {
1772 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1773 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1774 ForallOpSingleOrZeroIterationDimsFolder,
1775 ForallOpReplaceConstantInductionVar>(context);
1778void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1779 SmallVectorImpl<RegionSuccessor> ®ions) {
1786 regions.push_back(RegionSuccessor(&getRegion()));
1803void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
1804 OpBuilder::InsertionGuard g(
b);
1805 Region *bodyRegion =
result.addRegion();
1806 b.createBlock(bodyRegion);
1809LogicalResult InParallelOp::verify() {
1810 scf::ForallOp forallOp =
1811 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1813 return this->
emitOpError(
"expected forall op parent");
1815 for (Operation &op : getRegion().front().getOperations()) {
1816 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1817 if (!parallelCombiningOp) {
1818 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
1823 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1824 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1825 for (OpOperand &dest : dests) {
1826 if (!llvm::is_contained(regionOutArgs, dest.get()))
1827 return op.emitOpError(
"may only insert into an output block argument");
1834void InParallelOp::print(OpAsmPrinter &p) {
1842ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
1845 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1846 std::unique_ptr<Region> region = std::make_unique<Region>();
1850 if (region->empty())
1851 OpBuilder(builder.
getContext()).createBlock(region.get());
1852 result.addRegion(std::move(region));
1860OpResult InParallelOp::getParentResult(int64_t idx) {
1861 return getOperation()->getParentOp()->getResult(idx);
1864SmallVector<BlockArgument> InParallelOp::getDests() {
1865 SmallVector<BlockArgument> updatedDests;
1866 for (Operation &yieldingOp : getYieldingOps()) {
1867 auto parallelCombiningOp =
1868 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1869 if (!parallelCombiningOp)
1871 for (OpOperand &updatedOperand :
1872 parallelCombiningOp.getUpdatedDestinations())
1873 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1875 return updatedDests;
1878llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1879 return getRegion().front().getOperations();
1887 assert(a &&
"expected non-empty operation");
1888 assert(
b &&
"expected non-empty operation");
1893 if (ifOp->isProperAncestor(
b))
1896 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1897 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
1899 ifOp = ifOp->getParentOfType<IfOp>();
1907IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1908 IfOp::Adaptor adaptor,
1910 if (adaptor.getRegions().empty())
1912 Region *r = &adaptor.getThenRegion();
1918 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
1921 TypeRange types = yieldOp.getOperandTypes();
1922 llvm::append_range(inferredReturnTypes, types);
1928 return build(builder,
result, resultTypes, cond,
false,
1932void IfOp::build(OpBuilder &builder, OperationState &
result,
1933 TypeRange resultTypes, Value cond,
bool addThenBlock,
1934 bool addElseBlock) {
1935 assert((!addElseBlock || addThenBlock) &&
1936 "must not create else block w/o then block");
1937 result.addTypes(resultTypes);
1938 result.addOperands(cond);
1941 OpBuilder::InsertionGuard guard(builder);
1942 Region *thenRegion =
result.addRegion();
1945 Region *elseRegion =
result.addRegion();
1950void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
1951 bool withElseRegion) {
/// Builder taking explicit result types and a flag for the 'else' region.
/// Adds `resultTypes` and `cond` to `result` and always adds both regions.
/// When `resultTypes` is empty, IfOp::ensureTerminator is invoked on the
/// 'then' region, and on the 'else' region if `withElseRegion` is set.
// NOTE(review): this listing omits original lines 1959-1960, 1963, 1966-1967,
// 1970 and 1973+; presumably blocks are created on those lines - confirm
// against the full file.
1955void IfOp::build(OpBuilder &builder, OperationState &
result,
1956 TypeRange resultTypes, Value cond,
bool withElseRegion) {
1957 result.addTypes(resultTypes);
1958 result.addOperands(cond);
// Scoped guard over the builder's insertion point.
1961 OpBuilder::InsertionGuard guard(builder);
1962 Region *thenRegion =
result.addRegion();
1964 if (resultTypes.empty())
1965 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
// The 'else' region is added unconditionally; only its contents depend on
// `withElseRegion`.
1968 Region *elseRegion =
result.addRegion();
1969 if (withElseRegion) {
1971 if (resultTypes.empty())
1972 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
1976void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
1978 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
1979 assert(thenBuilder &&
"the builder callback for 'then' must be present");
1980 result.addOperands(cond);
1983 OpBuilder::InsertionGuard guard(builder);
1984 Region *thenRegion =
result.addRegion();
1986 thenBuilder(builder,
result.location);
1989 Region *elseRegion =
result.addRegion();
1992 elseBuilder(builder,
result.location);
1996 SmallVector<Type> inferredReturnTypes;
1998 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
1999 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2001 inferredReturnTypes))) {
2002 result.addTypes(inferredReturnTypes);
/// Verifier: an op that defines results must have a non-empty 'else' region;
/// otherwise an op error is emitted.
// NOTE(review): original lines 2009+ (the remaining return path) are not
// present in this listing.
2006LogicalResult IfOp::verify() {
2007 if (getNumResults() != 0 && getElseRegion().empty())
2008 return emitOpError(
"must have an else block if defining values");
2012ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
2014 result.regions.reserve(2);
2015 Region *thenRegion =
result.addRegion();
2016 Region *elseRegion =
result.addRegion();
2019 OpAsmParser::UnresolvedOperand cond;
/// Custom printer: prints the condition, then ` -> (types)` when the op has
/// results. `printBlockTerminators` starts false and is switched to true only
/// when results exist; it is the trailing argument of the two calls whose
/// openings are on lines omitted from this listing (2053-2056, 2062-2064).
2045void IfOp::print(OpAsmPrinter &p) {
2046 bool printBlockTerminators =
false;
2048 p <<
" " << getCondition();
2049 if (!getResults().empty()) {
2050 p <<
" -> (" << getResultTypes() <<
")";
// With results present the terminators are requested to be printed.
2052 printBlockTerminators =
true;
2057 printBlockTerminators);
// The 'else' region is only handled when it is non-empty.
2060 auto &elseRegion = getElseRegion();
2061 if (!elseRegion.
empty()) {
2065 printBlockTerminators);
/// Populates `regions` with successor regions of this op.
///
/// \param point   branch point being queried (its use is on lines omitted
///                from this listing).
/// \param regions output list that successors are appended to.
// NOTE(review): this listing omits original lines 2073-2079, 2081-2082,
// 2085-2086 and 2088+, so only part of the control flow is visible below.
// The parameter name was mis-encoded as `®ions` in the listing; it is
// restored to `&regions`, matching the uses of `regions` in the body.
2071void IfOp::getSuccessorRegions(RegionBranchPoint point,
2072 SmallVectorImpl<RegionSuccessor> &regions) {
// The 'then' region is recorded as a successor.
2080 regions.push_back(RegionSuccessor(&getThenRegion()));
2083 Region *elseRegion = &this->getElseRegion();
// The statement guarded by the empty-'else' check is on an omitted line.
2084 if (elseRegion->
empty())
2087 regions.push_back(RegionSuccessor(elseRegion));
2090ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
/// Populates `regions` with the regions that may be entered, given possibly
/// constant `operands`. The condition is read through a FoldAdaptor:
///  - the 'then' region is added unless the condition is a BoolAttr holding
///    false;
///  - a non-empty 'else' region is added unless the condition is a BoolAttr
///    holding true.
// NOTE(review): this listing omits original lines 2101-2102 and 2106+; the
// handling of an empty 'else' region is on those lines.
// The parameter name was mis-encoded as `®ions` in the listing; it is
// restored to `&regions`, matching the uses of `regions` in the body.
2095void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2096 SmallVectorImpl<RegionSuccessor> &regions) {
2097 FoldAdaptor adaptor(operands, *
this);
// A null result means the condition is not a known boolean constant.
2098 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2099 if (!boolAttr || boolAttr.getValue())
2100 regions.emplace_back(&getThenRegion());
2103 if (!boolAttr || !boolAttr.getValue()) {
2104 if (!getElseRegion().empty())
2105 regions.emplace_back(&getElseRegion());
/// Folder. The visible lines replace the condition with the left-hand side of
/// the arith.xori that defines it and then exchange the block lists of the
/// 'then' and 'else' regions.
// NOTE(review): this listing omits original lines 2113, 2115-2116, 2118-2123,
// 2126-2127 and 2132+. The checks that decide when the rewrite applies
// (presumably that the xor is a boolean negation - confirm) and all return
// statements are on those lines.
2111LogicalResult IfOp::fold(FoldAdaptor adaptor,
2112 SmallVectorImpl<OpFoldResult> &results) {
2114 if (getElseRegion().empty())
2117 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2124 getConditionMutable().assign(xorStmt.getLhs());
// Remember the original first 'then' block before the lists are spliced.
2125 Block *thenBlock = &getThenRegion().front();
// Move all 'else' blocks to the front of the 'then' region ...
2128 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2129 getElseRegion().getBlocks());
// ... then move the remembered 'then' block into the 'else' region.
2130 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2131 getThenRegion().getBlocks(), thenBlock);
/// Fills `invocationBounds` with two entries. When `operands[0]` is a
/// BoolAttr, the first entry is [0, 1] for true and [0, 0] for false, and the
/// second entry is the opposite. Line 2145 assigns [0, 1] to both entries.
// NOTE(review): presumably the first entry describes the 'then' region and
// the second the 'else' region, and line 2145 is the fallback for a
// non-constant condition - original lines 2139-2140, 2143-2144 and 2146+ are
// not present in this listing, so confirm against the full file.
2135void IfOp::getRegionInvocationBounds(
2136 ArrayRef<Attribute> operands,
2137 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2138 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2141 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2142 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2145 invocationBounds.assign(2, {0, 1});
2152struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2153 using OpRewritePattern<IfOp>::OpRewritePattern;
2155 LogicalResult matchAndRewrite(IfOp op,
2156 PatternRewriter &rewriter)
const override {
2157 if (op->getNumResults() == 0)
2160 auto cond = op.getCondition();
2161 auto thenYieldArgs = op.thenYield().getOperands();
2162 auto elseYieldArgs = op.elseYield().getOperands();
2164 SmallVector<Type> nonHoistable;
2165 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2166 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2167 &op.getElseRegion() == falseVal.getParentRegion())
2168 nonHoistable.push_back(trueVal.getType());
2172 if (nonHoistable.size() == op->getNumResults())
2175 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2179 replacement.getThenRegion().takeBody(op.getThenRegion());
2180 replacement.getElseRegion().takeBody(op.getElseRegion());
2182 SmallVector<Value> results(op->getNumResults());
2183 assert(thenYieldArgs.size() == results.size());
2184 assert(elseYieldArgs.size() == results.size());
2186 SmallVector<Value> trueYields;
2187 SmallVector<Value> falseYields;
2189 for (
const auto &it :
2190 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2191 Value trueVal = std::get<0>(it.value());
2192 Value falseVal = std::get<1>(it.value());
2195 results[it.index()] =
replacement.getResult(trueYields.size());
2196 trueYields.push_back(trueVal);
2197 falseYields.push_back(falseVal);
2198 }
else if (trueVal == falseVal)
2199 results[it.index()] = trueVal;
2201 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2202 cond, trueVal, falseVal);
2229struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2230 using OpRewritePattern<IfOp>::OpRewritePattern;
2233 enum class Parent { Then, Else,
None };
2238 static Parent getParentType(Region *toCheck, IfOp op,
2240 Region *endRegion) {
2241 SmallVector<Region *> seen;
2242 while (toCheck != endRegion) {
2243 auto found = cache.find(toCheck);
2244 if (found != cache.end())
2245 return found->second;
2246 seen.push_back(toCheck);
2247 if (&op.getThenRegion() == toCheck) {
2248 for (Region *region : seen)
2249 cache[region] = Parent::Then;
2250 return Parent::Then;
2252 if (&op.getElseRegion() == toCheck) {
2253 for (Region *region : seen)
2254 cache[region] = Parent::Else;
2255 return Parent::Else;
2260 for (Region *region : seen)
2261 cache[region] = Parent::None;
2262 return Parent::None;
2265 LogicalResult matchAndRewrite(IfOp op,
2266 PatternRewriter &rewriter)
const override {
2277 Value constantTrue =
nullptr;
2278 Value constantFalse =
nullptr;
2281 for (OpOperand &use :
2282 llvm::make_early_inc_range(op.getCondition().getUses())) {
2285 case Parent::Then: {
2289 constantTrue = arith::ConstantOp::create(
2293 [&]() { use.set(constantTrue); });
2296 case Parent::Else: {
2300 constantFalse = arith::ConstantOp::create(
2304 [&]() { use.set(constantFalse); });
2352struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2353 using OpRewritePattern<IfOp>::OpRewritePattern;
2355 LogicalResult matchAndRewrite(IfOp op,
2356 PatternRewriter &rewriter)
const override {
2358 if (op.getNumResults() == 0)
2362 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2364 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2367 op.getOperation()->getIterator());
2370 for (
auto [trueResult, falseResult, opResult] :
2371 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2373 if (trueResult == falseResult) {
2374 if (!opResult.use_empty()) {
2375 opResult.replaceAllUsesWith(trueResult);
2381 BoolAttr trueYield, falseYield;
2386 bool trueVal = trueYield.
getValue();
2387 bool falseVal = falseYield.
getValue();
2388 if (!trueVal && falseVal) {
2389 if (!opResult.use_empty()) {
2390 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2391 Value notCond = arith::XOrIOp::create(
2392 rewriter, op.getLoc(), op.getCondition(),
2398 opResult.replaceAllUsesWith(notCond);
2402 if (trueVal && !falseVal) {
2403 if (!opResult.use_empty()) {
2404 opResult.replaceAllUsesWith(op.getCondition());
2434struct CombineIfs :
public OpRewritePattern<IfOp> {
2435 using OpRewritePattern<IfOp>::OpRewritePattern;
2437 LogicalResult matchAndRewrite(IfOp nextIf,
2438 PatternRewriter &rewriter)
const override {
2439 Block *parent = nextIf->getBlock();
2440 if (nextIf == &parent->
front())
2443 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2451 Block *nextThen =
nullptr;
2452 Block *nextElse =
nullptr;
2453 if (nextIf.getCondition() == prevIf.getCondition()) {
2454 nextThen = nextIf.thenBlock();
2455 if (!nextIf.getElseRegion().empty())
2456 nextElse = nextIf.elseBlock();
2458 if (arith::XOrIOp notv =
2459 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2460 if (notv.getLhs() == prevIf.getCondition() &&
2462 nextElse = nextIf.thenBlock();
2463 if (!nextIf.getElseRegion().empty())
2464 nextThen = nextIf.elseBlock();
2467 if (arith::XOrIOp notv =
2468 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2469 if (notv.getLhs() == nextIf.getCondition() &&
2471 nextElse = nextIf.thenBlock();
2472 if (!nextIf.getElseRegion().empty())
2473 nextThen = nextIf.elseBlock();
2477 if (!nextThen && !nextElse)
2480 SmallVector<Value> prevElseYielded;
2481 if (!prevIf.getElseRegion().empty())
2482 prevElseYielded = prevIf.elseYield().getOperands();
2485 for (
auto it : llvm::zip(prevIf.getResults(),
2486 prevIf.thenYield().getOperands(), prevElseYielded))
2487 for (OpOperand &use :
2488 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2492 use.
set(std::get<1>(it));
2497 use.
set(std::get<2>(it));
2502 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2503 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2505 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2506 prevIf.getCondition(),
false);
2507 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2510 combinedIf.getThenRegion(),
2511 combinedIf.getThenRegion().begin());
2514 YieldOp thenYield = combinedIf.thenYield();
2515 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2516 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2519 SmallVector<Value> mergedYields(thenYield.getOperands());
2520 llvm::append_range(mergedYields, thenYield2.getOperands());
2521 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2527 combinedIf.getElseRegion(),
2528 combinedIf.getElseRegion().begin());
2531 if (combinedIf.getElseRegion().empty()) {
2533 combinedIf.getElseRegion(),
2534 combinedIf.getElseRegion().
begin());
2536 YieldOp elseYield = combinedIf.elseYield();
2537 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2538 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2542 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2543 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2545 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2551 SmallVector<Value> prevValues;
2552 SmallVector<Value> nextValues;
2553 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2554 if (pair.index() < prevIf.getNumResults())
2555 prevValues.push_back(pair.value());
2557 nextValues.push_back(pair.value());
/// Rewrite pattern on IfOp. The visible guards test whether the op has
/// results and whether its 'else' block exists and holds exactly one
/// operation.
// NOTE(review): this listing omits original lines 2568, 2571, 2573,
// 2576-2578 and 2580+. The bodies of the guards, the creation of `newIfOp`
// and the final replacement are on those lines.
2566struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2567 using OpRewritePattern<IfOp>::OpRewritePattern;
2569 LogicalResult matchAndRewrite(IfOp ifOp,
2570 PatternRewriter &rewriter)
const override {
2572 if (ifOp.getNumResults())
2574 Block *elseBlock = ifOp.elseBlock();
// Guard: no 'else' block, or an 'else' block with other than one operation.
2575 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2579 newIfOp.getThenRegion().begin());
2601struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2602 using OpRewritePattern<IfOp>::OpRewritePattern;
2604 LogicalResult matchAndRewrite(IfOp op,
2605 PatternRewriter &rewriter)
const override {
2606 auto nestedOps = op.thenBlock()->without_terminator();
2608 if (!llvm::hasSingleElement(nestedOps))
2612 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2615 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2619 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2622 SmallVector<Value> thenYield(op.thenYield().getOperands());
2623 SmallVector<Value> elseYield;
2625 llvm::append_range(elseYield, op.elseYield().getOperands());
2629 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2638 for (
const auto &tup : llvm::enumerate(thenYield)) {
2639 if (tup.value().getDefiningOp() == nestedIf) {
2640 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2641 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2642 elseYield[tup.index()]) {
2647 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2660 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2663 elseYieldsToUpgradeToSelect.push_back(tup.index());
2666 Location loc = op.getLoc();
2667 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2668 nestedIf.getCondition());
2669 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2672 SmallVector<Value> results;
2673 llvm::append_range(results, newIf.getResults());
2676 for (
auto idx : elseYieldsToUpgradeToSelect)
2678 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2679 thenYield[idx], elseYield[idx]);
2681 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2684 if (!elseYield.empty()) {
2687 YieldOp::create(rewriter, loc, elseYield);
/// Registers the IfOp canonicalization patterns in `results`: CombineIfs,
/// CombineNestedIfs, ConditionPropagation, ConvertTrivialIfToSelect,
/// RemoveEmptyElseBranch and ReplaceIfYieldWithConditionOrValue.
// NOTE(review): lines 2702 and 2704 are the tails of two further calls keyed
// on the IfOp operation name; their callees (original lines 2701, 2703) are
// not present in this listing.
2696void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2697 MLIRContext *context) {
2698 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2699 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2700 ReplaceIfYieldWithConditionOrValue>(context);
2702 results, IfOp::getOperationName());
2704 IfOp::getOperationName());
2707Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
2708YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
/// Accessor for the 'else' block; begins by fetching the 'else' region.
// NOTE(review): original lines 2711-2714 (the rest of the body) are not
// present in this listing. RemoveEmptyElseBranch elsewhere in this listing
// null-checks the result, so a null return is presumably possible - confirm.
2709Block *IfOp::elseBlock() {
2710 Region &r = getElseRegion();
2715YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2721void ParallelOp::build(
2726 result.addOperands(lowerBounds);
2727 result.addOperands(upperBounds);
2728 result.addOperands(steps);
2729 result.addOperands(initVals);
2731 ParallelOp::getOperandSegmentSizeAttr(),
2733 static_cast<int32_t>(upperBounds.size()),
2734 static_cast<int32_t>(steps.size()),
2735 static_cast<int32_t>(initVals.size())}));
2738 OpBuilder::InsertionGuard guard(builder);
2739 unsigned numIVs = steps.size();
2740 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
2741 SmallVector<Location, 8> argLocs(numIVs,
result.location);
2742 Region *bodyRegion =
result.addRegion();
2745 if (bodyBuilderFn) {
2747 bodyBuilderFn(builder,
result.location,
2752 if (initVals.empty())
2753 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
2756void ParallelOp::build(
2763 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2766 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2770 wrapper = wrappedBuilderFn;
2776LogicalResult ParallelOp::verify() {
2781 if (stepValues.empty())
2783 "needs at least one tuple element for lowerBound, upperBound and step");
2786 for (Value stepValue : stepValues)
2789 return emitOpError(
"constant step operand must be positive");
2793 Block *body = getBody();
2795 return emitOpError() <<
"expects the same number of induction variables: "
2797 <<
" as bound and step values: " << stepValues.size();
2799 if (!arg.getType().isIndex())
2801 "expects arguments for the induction variable to be of index type");
2805 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2810 auto resultsSize = getResults().size();
2811 auto reductionsSize = reduceOp.getReductions().size();
2812 auto initValsSize = getInitVals().size();
2813 if (resultsSize != reductionsSize)
2814 return emitOpError() <<
"expects number of results: " << resultsSize
2815 <<
" to be the same as number of reductions: "
2817 if (resultsSize != initValsSize)
2818 return emitOpError() <<
"expects number of results: " << resultsSize
2819 <<
" to be the same as number of initial values: "
2821 if (reduceOp.getNumOperands() != initValsSize)
2826 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2827 auto resultType = getOperation()->getResult(i).getType();
2828 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2829 if (resultType != reductionOperandType)
2830 return reduceOp.emitOpError()
2831 <<
"expects type of " << i
2832 <<
"-th reduction operand: " << reductionOperandType
2833 <<
" to be the same as the " << i
2834 <<
"-th result type: " << resultType;
2839ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2842 SmallVector<OpAsmParser::Argument, 4> ivs;
2847 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2854 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2862 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2870 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2881 Region *body =
result.addRegion();
2882 for (
auto &iv : ivs)
2889 ParallelOp::getOperandSegmentSizeAttr(),
2891 static_cast<int32_t>(upper.size()),
2892 static_cast<int32_t>(steps.size()),
2893 static_cast<int32_t>(initVals.size())}));
2902 ParallelOp::ensureTerminator(*body, builder,
result.location);
/// Custom printer. Emits
///   ` (<ivs>) = (<lower bounds>) to (<upper bounds>) step (<steps>)`
/// followed by ` init (<init values>)` when there are initial values.
// NOTE(review): lines 2915-2916 are the tail of a call that receives the op's
// attributes together with the operand-segment-size attribute name; the
// callee (original lines 2911-2914) is not present in this listing -
// presumably the attribute dictionary printer eliding that attribute.
2906void ParallelOp::print(OpAsmPrinter &p) {
2907 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2908 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2909 if (!getInitVals().empty())
2910 p <<
" init (" << getInitVals() <<
")";
2915 (*this)->getAttrs(),
2916 ParallelOp::getOperandSegmentSizeAttr());
2919SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
/// Returns the arguments of the body block as the loop induction variables.
2921std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2922 return SmallVector<Value>{getBody()->getArguments()};
2925std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
2929std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
2933std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
2938 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2940 return ParallelOp();
2941 assert(ivArg.getOwner() &&
"unlinked block argument");
2942 auto *containingOp = ivArg.getOwner()->getParentOp();
2943 return dyn_cast<ParallelOp>(containingOp);
2948struct ParallelOpSingleOrZeroIterationDimsFolder
2952 LogicalResult matchAndRewrite(ParallelOp op,
2959 for (
auto [lb,
ub, step, iv] :
2960 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2961 op.getInductionVars())) {
2962 auto numIterations =
2964 if (numIterations.has_value()) {
2966 if (*numIterations == 0) {
2967 rewriter.
replaceOp(op, op.getInitVals());
2972 if (*numIterations == 1) {
2977 newLowerBounds.push_back(lb);
2978 newUpperBounds.push_back(ub);
2979 newSteps.push_back(step);
2982 if (newLowerBounds.size() == op.getLowerBound().size())
2985 if (newLowerBounds.empty()) {
2988 SmallVector<Value> results;
2989 results.reserve(op.getInitVals().size());
2990 for (
auto &bodyOp : op.getBody()->without_terminator())
2991 rewriter.
clone(bodyOp, mapping);
2992 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
2993 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
2994 Block &reduceBlock = reduceOp.getReductions()[i].front();
2995 auto initValIndex = results.size();
2996 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3000 rewriter.
clone(reduceBodyOp, mapping);
3003 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3004 results.push_back(
result);
3012 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3013 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3019 newOp.getRegion().begin(), mapping);
3020 rewriter.
replaceOp(op, newOp.getResults());
3025struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3026 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3028 LogicalResult matchAndRewrite(ParallelOp op,
3029 PatternRewriter &rewriter)
const override {
3030 Block &outerBody = *op.getBody();
3034 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3039 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3040 llvm::is_contained(innerOp.getUpperBound(), val) ||
3041 llvm::is_contained(innerOp.getStep(), val))
3045 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3048 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3050 Block &innerBody = *innerOp.getBody();
3051 assert(iterVals.size() ==
3059 builder.
clone(op, mapping);
3062 auto concatValues = [](
const auto &first,
const auto &second) {
3063 SmallVector<Value> ret;
3064 ret.reserve(first.size() + second.size());
3065 ret.assign(first.begin(), first.end());
3066 ret.append(second.begin(), second.end());
3070 auto newLowerBounds =
3071 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3072 auto newUpperBounds =
3073 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3074 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3085void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3086 MLIRContext *context) {
3088 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
/// Populates `regions` with successor regions of this op; the visible line
/// records the op's own region as a successor.
///
/// \param point   branch point being queried (its use is on lines omitted
///                from this listing).
/// \param regions output list that successors are appended to.
// NOTE(review): original lines 3099-3101 and 3103+ are not present in this
// listing.
// The parameter name was mis-encoded as `®ions` in the listing; it is
// restored to `&regions`, matching the use of `regions` in the body.
3097void ParallelOp::getSuccessorRegions(
3098 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3102 regions.push_back(RegionSuccessor(&getRegion()));
/// Builder overload with no operands: intentionally leaves `result` untouched.
3110void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
3112void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3114 result.addOperands(operands);
3115 for (Value v : operands) {
3116 OpBuilder::InsertionGuard guard(builder);
3117 Region *bodyRegion =
result.addRegion();
3119 ArrayRef<Type>{v.getType(), v.getType()},
/// Region verifier. Checks that the number of reduction regions equals the
/// number of operands, then, for each reduction `i`, inspects the first block
/// of the region against the type of the i-th operand. Diagnostics visible
/// here cover: an empty reduction body, block arguments whose type differs
/// from the operand type, and a body not terminated by 'scf.reduce.return'.
// NOTE(review): this listing omits original lines 3130-3131, 3135, 3137,
// 3140, 3143-3145 and 3148+, so the conditions guarding several of the
// diagnostics below are only partly visible.
3124LogicalResult ReduceOp::verifyRegions() {
3125 if (getReductions().size() != getOperands().size())
3126 return emitOpError() <<
"expects number of reduction regions: "
3127 << getReductions().size()
3128 <<
" to be the same as number of reduction operands: "
3129 << getOperands().size();
3132 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
// Type every block argument of this reduction is compared against.
3133 auto type = getOperands()[i].getType();
3134 Block &block = getReductions()[i].front();
3136 return emitOpError() << i <<
"-th reduction has an empty body";
// True when any block argument has a type other than `type`.
3138 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3139 return arg.getType() != type;
3141 return emitOpError() <<
"expected two block arguments with type " << type
3142 <<
" in the " << i <<
"-th reduction region";
3146 return emitOpError(
"reduction bodies must be terminated with an "
3147 "'scf.reduce.return' op");
3154ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3156 return MutableOperandRange(getOperation(), 0, 0);
/// Verifier: the type of the returned value must equal `expectedResultType`;
/// otherwise an op error naming the expected type is emitted.
// NOTE(review): `expectedResultType` is computed on a line not present in
// this listing (original lines 3164-3165, 3167, 3169 and 3173+ are omitted);
// per the diagnostic text it is the type of the reduction inputs.
3163LogicalResult ReduceReturnOp::verify() {
3166 Block *reductionBody = getOperation()->getBlock();
// The enclosing op is asserted (not diagnosed) to be scf.reduce.
3168 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3170 if (expectedResultType != getResult().
getType())
3171 return emitOpError() <<
"must have type " << expectedResultType
3172 <<
" (the type of the reduction inputs)";
/// Builder taking result types, initial values and two body-builder
/// callbacks. Adds two regions to the state - 'before' first, 'after'
/// second - and prepares per-argument locations for each: the 'before'
/// locations are those of the corresponding `inits` values, the 'after'
/// locations are all the state's own location.
// NOTE(review): this listing omits original lines 3184-3186, 3188-3189,
// 3194-3195, 3197, 3199-3202, 3204, 3206 and 3208+. Lines 3198 and 3207 are
// the tails of calls (presumably block creation) that receive the argument
// types and locations; the invocation of the callbacks is not visible.
3180void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3181 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3182 ValueRange inits, BodyBuilderFn beforeBuilder,
3183 BodyBuilderFn afterBuilder) {
// Scoped guard over the builder's insertion point.
3187 OpBuilder::InsertionGuard guard(odsBuilder);
3190 SmallVector<Location, 4> beforeArgLocs;
3191 beforeArgLocs.reserve(inits.size());
3192 for (Value operand : inits) {
3193 beforeArgLocs.push_back(operand.getLoc());
3196 Region *beforeRegion = odsState.
addRegion();
3198 inits.getTypes(), beforeArgLocs);
3203 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3205 Region *afterRegion = odsState.
addRegion();
3207 resultTypes, afterArgLocs);
/// Returns the terminator of the 'before' body as a ConditionOp (the cast
/// asserts on any other op kind).
3213ConditionOp WhileOp::getConditionOp() {
3214 return cast<ConditionOp>(getBeforeBody()->getTerminator());
/// Returns the terminator of the 'after' body as a YieldOp (the cast asserts
/// on any other op kind).
3217YieldOp WhileOp::getYieldOp() {
3218 return cast<YieldOp>(getAfterBody()->getTerminator());
/// Returns the mutable operands of the 'after' region's yield op.
3221std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3222 return getYieldOp().getResultsMutable();
3226 return getBeforeBody()->getArguments();
3230 return getAfterBody()->getArguments();
3234 return getBeforeArguments();
3237OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3239 "WhileOp is expected to branch only to the first region");
/// Populates `regions` with successor regions of this op. The visible lines
/// add the 'before' region in two places and the 'after' region in one, and
/// assert that the queried region is one of those two.
///
/// \param point   branch point being queried (its use is on lines omitted
///                from this listing).
/// \param regions output list that successors are appended to.
// NOTE(review): this listing omits original lines 3245-3246, 3248-3250, 3253,
// 3255-3257, 3259-3262 and 3264+; the conditions that select between the
// visible `emplace_back` calls are on those lines.
// The parameter name was mis-encoded as `®ions` in the listing; it is
// restored to `&regions`, matching the uses of `regions` in the body.
3243void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3244 SmallVectorImpl<RegionSuccessor> &regions) {
3247 regions.emplace_back(&getBefore());
3251 assert(llvm::is_contained(
3252 {&getAfter(), &getBefore()},
3254 "there are only two regions in a WhileOp");
3258 regions.emplace_back(&getBefore());
3263 regions.emplace_back(&getAfter());
/// Returns the values that receive control-flow inputs for `successor`:
/// the 'before' region's arguments when the successor is the 'before' region,
/// the 'after' region's arguments when it is the 'after' region. Any other
/// region successor reaching the end is treated as unreachable.
// NOTE(review): original line 3267, which guards the return of the op's
// results on line 3268 (presumably the parent-successor case - confirm), is
// not present in this listing.
3266ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
3268 return getOperation()->getResults();
3269 if (successor == &getBefore())
3270 return getBefore().getArguments();
3271 if (successor == &getAfter())
3272 return getAfter().getArguments();
3273 llvm_unreachable(
"invalid region successor");
/// Returns both regions of the loop: 'before' first, 'after' second.
3276SmallVector<Region *> WhileOp::getLoopRegions() {
3277 return {&getBefore(), &getAfter()};
3287ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3288 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3289 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3290 Region *before =
result.addRegion();
3291 Region *after =
result.addRegion();
3293 OptionalParseResult listResult =
3298 FunctionType functionType;
3303 result.addTypes(functionType.getResults());
3305 if (functionType.getNumInputs() != operands.size()) {
3307 <<
"expected as many input types as operands " <<
"(expected "
3308 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3318 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3319 regionArgs[i].type = functionType.getInput(i);
3321 return failure(parser.
parseRegion(*before, regionArgs) ||
3327void scf::WhileOp::print(OpAsmPrinter &p) {
3341template <
typename OpTy>
3344 if (left.size() != right.size())
3345 return op.emitOpError(
"expects the same number of ") << message;
3347 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3348 if (left[i] != right[i]) {
3351 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3352 <<
" and " << right[i];
/// Verifier. The visible lines carry the error messages used when looking up
/// the terminators of the 'before' ('scf.condition') and 'after'
/// ('scf.yield') regions; the final result is success exactly when
/// `afterTerminator` is non-null.
// NOTE(review): this listing omits original lines 3361-3362, 3365-3368 and
// 3371+; the calls that produce `beforeTerminator` / `afterTerminator` and
// the statement guarded by line 3364 are on those lines.
3360LogicalResult scf::WhileOp::verify() {
3363 "expects the 'before' region to terminate with 'scf.condition'");
3364 if (!beforeTerminator)
3369 "expects the 'after' region to terminate with 'scf.yield'");
3370 return success(afterTerminator !=
nullptr);
3407struct WhileMoveIfDown :
public OpRewritePattern<scf::WhileOp> {
3408 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3410 LogicalResult matchAndRewrite(scf::WhileOp op,
3411 PatternRewriter &rewriter)
const override {
3412 auto conditionOp = op.getConditionOp();
3420 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3426 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3427 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3430 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3431 *ifOp->user_begin() == conditionOp)) &&
3432 "ifOp has unexpected uses");
3434 Location loc = op.getLoc();
3438 for (
auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3439 auto it = llvm::find(ifOp->getResults(), arg);
3440 if (it != ifOp->getResults().end()) {
3441 size_t ifOpIdx = it.getIndex();
3442 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3443 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3453 if (&op.getBefore() == operand->get().getParentRegion())
3454 additionalUsedValuesSet.insert(operand->get());
3458 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3459 auto additionalValueTypes = llvm::map_to_vector(
3460 additionalUsedValues, [](Value val) {
return val.
getType(); });
3461 size_t additionalValueSize = additionalUsedValues.size();
3462 SmallVector<Type> newResultTypes(op.getResultTypes());
3463 newResultTypes.append(additionalValueTypes);
3466 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3469 newWhileOp.getBefore().takeBody(op.getBefore());
3470 newWhileOp.getAfter().takeBody(op.getAfter());
3471 newWhileOp.getAfter().addArguments(
3472 additionalValueTypes,
3473 SmallVector<Location>(additionalValueSize, loc));
3477 conditionOp.getArgsMutable().append(additionalUsedValues);
3483 additionalUsedValues,
3484 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3485 [&](OpOperand &use) {
3486 return ifOp.getThenRegion().isAncestor(
3487 use.getOwner()->getParentRegion());
3491 rewriter.
eraseOp(ifOp.thenYield());
3493 newWhileOp.getAfterBody()->begin());
3496 newWhileOp->getResults().drop_back(additionalValueSize));
3520struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3521 using OpRewritePattern<WhileOp>::OpRewritePattern;
3523 LogicalResult matchAndRewrite(WhileOp op,
3524 PatternRewriter &rewriter)
const override {
3525 auto term = op.getConditionOp();
3529 Value constantTrue =
nullptr;
3531 bool replaced =
false;
3532 for (
auto yieldedAndBlockArgs :
3533 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3534 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3535 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3537 constantTrue = arith::ConstantOp::create(
3538 rewriter, op.getLoc(), term.getCondition().getType(),
3573struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
3574 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3576 LogicalResult matchAndRewrite(scf::WhileOp op,
3577 PatternRewriter &rewriter)
const override {
3578 using namespace scf;
3579 auto cond = op.getConditionOp();
3580 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3584 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3585 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3586 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3589 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3590 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3594 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3597 if (cmp2.getPredicate() == cmp.getPredicate())
3598 samePredicate =
true;
3599 else if (cmp2.getPredicate() ==
3600 arith::invertPredicate(cmp.getPredicate()))
3601 samePredicate =
false;
/// Computes an index mapping between two value ranges of equal size. For the
/// value at index `i` of `args1`, found at position `j` in `args2`, the
/// mapping stores `ret[j] = i`. Yields std::nullopt when the sizes differ or
/// when a value of `args1` does not occur in `args2`.
// NOTE(review): original line 3618 (the declaration of `args2`) and lines
// 3621, 3627 and 3629+ (including the final return) are not present in this
// listing.
3617static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
3619 if (args1.size() != args2.size())
3620 return std::nullopt;
3622 SmallVector<unsigned> ret(args1.size());
3623 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
// llvm::find yields the first occurrence of `arg1` in `args2`.
3624 auto it = llvm::find(args2, arg1);
3625 if (it == args2.end())
3626 return std::nullopt;
3628 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
3635 llvm::SmallDenseSet<Value> set;
3636 for (Value arg : args) {
3637 if (!set.insert(arg).second)
3647struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
3650 LogicalResult matchAndRewrite(WhileOp loop,
3651 PatternRewriter &rewriter)
const override {
3652 auto *oldBefore = loop.getBeforeBody();
3653 ConditionOp oldTerm = loop.getConditionOp();
3654 ValueRange beforeArgs = oldBefore->getArguments();
3656 if (beforeArgs == termArgs)
3659 if (hasDuplicates(termArgs))
3662 auto mapping = getArgsMapping(beforeArgs, termArgs);
3667 OpBuilder::InsertionGuard g(rewriter);
3673 auto *oldAfter = loop.getAfterBody();
3675 SmallVector<Type> newResultTypes(beforeArgs.size());
3676 for (
auto &&[i, j] : llvm::enumerate(*mapping))
3677 newResultTypes[j] = loop.getResult(i).getType();
3679 auto newLoop = WhileOp::create(
3680 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3682 auto *newBefore = newLoop.getBeforeBody();
3683 auto *newAfter = newLoop.getAfterBody();
3685 SmallVector<Value> newResults(beforeArgs.size());
3686 SmallVector<Value> newAfterArgs(beforeArgs.size());
3687 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
3688 newResults[i] = newLoop.getResult(j);
3689 newAfterArgs[i] = newAfter->getArgument(j);
3693 newBefore->getArguments());
/// Registers the WhileOp canonicalization patterns in `results`:
/// WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs and
/// WhileMoveIfDown.
// NOTE(review): lines 3708 and 3710 are the tails of two further calls keyed
// on the WhileOp operation name; their callees (original lines 3707, 3709)
// are not present in this listing.
3703void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3704 MLIRContext *context) {
3705 results.
add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
3706 WhileMoveIfDown>(context);
3708 results, WhileOp::getOperationName());
3710 WhileOp::getOperationName());
3724 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
3727 caseValues.push_back(value);
3736 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
3738 p <<
"case " << value <<
' ';
3743LogicalResult scf::IndexSwitchOp::verify() {
3744 if (getCases().size() != getCaseRegions().size()) {
3746 << getCaseRegions().size() <<
" case regions but "
3747 << getCases().size() <<
" case values";
3751 for (int64_t value : getCases())
3752 if (!valueSet.insert(value).second)
3753 return emitOpError(
"has duplicate case value: ") << value;
3754 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
3755 auto yield = dyn_cast<YieldOp>(region.
front().
back());
3757 return emitOpError(
"expected region to end with scf.yield, but got ")
3760 if (yield.getNumOperands() != getNumResults()) {
3761 return (
emitOpError(
"expected each region to return ")
3762 << getNumResults() <<
" values, but " << name <<
" returns "
3763 << yield.getNumOperands())
3764 .attachNote(yield.getLoc())
3765 <<
"see yield operation here";
3767 for (
auto [idx,
result, operand] :
3768 llvm::enumerate(getResultTypes(), yield.getOperands())) {
3770 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
3771 if (
result == operand.getType())
3774 << idx <<
" of each region to be " <<
result)
3775 .attachNote(yield.getLoc())
3776 << name <<
" returns " << operand.getType() <<
" here";
3783 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
/// Returns the number of cases of this switch op, i.e. the number of entries
/// in the case-value list returned by getCases().
3790unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
3792Block &scf::IndexSwitchOp::getDefaultBlock() {
3793 return getDefaultRegion().front();
3796Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
3797 assert(idx < getNumCases() &&
"case index out-of-bounds");
3798 return getCaseRegions()[idx].front();
3801void IndexSwitchOp::getSuccessorRegions(
3802 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3809 llvm::append_range(successors, getRegions());
3812ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
3817void IndexSwitchOp::getEntrySuccessorRegions(
3818 ArrayRef<Attribute> operands,
3819 SmallVectorImpl<RegionSuccessor> &successors) {
3820 FoldAdaptor adaptor(operands, *
this);
3823 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3825 llvm::append_range(successors, getRegions());
3831 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3832 if (caseValue == arg.getInt()) {
3833 successors.emplace_back(&caseRegion);
3837 successors.emplace_back(&getDefaultRegion());
3840void IndexSwitchOp::getRegionInvocationBounds(
3841 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3842 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3843 if (!operandValue) {
3845 bounds.append(getNumRegions(), InvocationBounds(0, 1));
3849 unsigned liveIndex = getNumRegions() - 1;
3850 const auto *it = llvm::find(getCases(), operandValue.getInt());
3851 if (it != getCases().end())
3852 liveIndex = std::distance(getCases().begin(), it);
3853 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
3854 bounds.emplace_back(0, i == liveIndex);
3857void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3858 MLIRContext *context) {
3860 results, IndexSwitchOp::getOperationName());
3862 results, IndexSwitchOp::getOperationName());
3869#define GET_OP_CLASSES
3870#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
static Value defaultReplBuilderFn(OpBuilder &builder, Location loc, Value value)
Default implementation of the non-successor-input replacement builder function.
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that `values` has as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.