30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
47struct SCFInlinerInterface :
public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
52 IRMapping &valueMapping)
const final {
57 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
62 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
67 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
/// Dialect initialization hook for the SCF dialect. The visible statements
/// pull in the generated op definitions file, register the inliner interface,
/// and declare a set of promised interfaces for the dialect and its ops.
78void SCFDialect::initialize() {
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
// Register the SCF-specific inliner interface with this dialect.
83 addInterfaces<SCFInlinerInterface>();
// Declare promised interfaces: EmitC conversion patterns on the dialect,
// buffer deallocation and bufferization on the listed ops, and value bounds
// on ForOp.
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
95 scf::YieldOp::create(builder, loc);
100template <
typename TerminatorTy>
102 StringRef errorMessage) {
103 Operation *terminatorOperation =
nullptr;
105 terminatorOperation = ®ion.
front().
back();
106 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
118 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
139 assert(region.
hasOneBlock() &&
"expected single-block region");
159ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
/// Verifier for scf.execute_region. Emits an op error when the region has no
/// blocks, or when the region's first block declares any block arguments.
187LogicalResult ExecuteRegionOp::verify() {
188 if (getRegion().empty())
189 return emitOpError(
"region needs to have at least one block");
// The entry block of the region must not take arguments.
190 if (getRegion().front().getNumArguments() > 0)
191 return emitOpError(
"region cannot have any arguments");
214 if (!op.getRegion().hasOneBlock() || op.getNoInline())
263 if (op.getNoInline())
265 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
268 Block *prevBlock = op->getBlock();
272 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
274 for (
Block &blk : op.getRegion()) {
275 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
277 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278 yieldOp.getResults());
286 for (
auto res : op.getResults())
287 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
298 results, ExecuteRegionOp::getOperationName());
301void ExecuteRegionOp::getSuccessorRegions(
321 "condition op can only exit the loop or branch to the after"
324 return getArgsMutable();
327void ConditionOp::getSuccessorRegions(
329 FoldAdaptor adaptor(operands, *
this);
331 WhileOp whileOp = getParentOp();
335 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
336 if (!boolAttr || boolAttr.getValue())
337 regions.emplace_back(&whileOp.getAfter(),
338 whileOp.getAfter().getArguments());
339 if (!boolAttr || !boolAttr.getValue())
349 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
353 result.addAttribute(getUnsignedCmpAttrName(
result.name),
356 result.addOperands(initArgs);
357 for (
Value v : initArgs)
358 result.addTypes(v.getType());
363 for (
Value v : initArgs)
369 if (initArgs.empty() && !bodyBuilder) {
370 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
371 }
else if (bodyBuilder) {
/// Verifier for scf.for. The visible check compares the number of init args
/// with the number of op results and reports a mismatch between loop-carried
/// values and defined values.
379LogicalResult ForOp::verify() {
381 if (getInitArgs().size() != getNumResults())
383 "mismatch in number of loop-carried values and defined values");
388LogicalResult ForOp::verifyRegions() {
393 "expected induction variable to be same type as bounds and step");
395 if (getNumRegionIterArgs() != getNumResults())
397 "mismatch in number of basic block args and defined values");
399 auto initArgs = getInitArgs();
400 auto iterArgs = getRegionIterArgs();
401 auto opResults = getResults();
403 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
405 return emitOpError() <<
"types mismatch between " << i
406 <<
"th iter operand and defined value";
408 return emitOpError() <<
"types mismatch between " << i
409 <<
"th iter region arg and defined value";
416std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
420std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
424std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
428std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
432bool ForOp::isValidInductionVarType(
Type type) {
437 if (bounds.size() != 1)
439 if (
auto val = dyn_cast<Value>(bounds[0])) {
447 if (bounds.size() != 1)
449 if (
auto val = dyn_cast<Value>(bounds[0])) {
457 if (steps.size() != 1)
459 if (
auto val = dyn_cast<Value>(steps[0])) {
/// Returns the results of this loop op, wrapped in an optional. This simply
/// forwards the op's own `getResults()`.
466std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
470LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
471 std::optional<APInt> tripCount = getStaticTripCount();
472 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
475 if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
478 if (*tripCount == 0) {
485 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
492 llvm::append_range(bbArgReplacements, getInitArgs());
496 getOperation()->getIterator(), bbArgReplacements);
512 StringRef prefix =
"") {
513 assert(blocksArgs.size() == initializers.size() &&
514 "expected same length of arguments and initializers");
515 if (initializers.empty())
519 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
520 p << std::get<0>(it) <<
" = " << std::get<1>(it);
526 if (getUnsignedCmp())
529 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
533 if (!getInitArgs().empty())
534 p <<
" -> (" << getInitArgs().getTypes() <<
')';
537 p <<
" : " << t <<
' ';
540 !getInitArgs().empty());
542 getUnsignedCmpAttrName().strref());
553 result.addAttribute(getUnsignedCmpAttrName(
result.name),
567 regionArgs.push_back(inductionVariable);
577 if (regionArgs.size() !=
result.types.size() + 1)
580 "mismatch in number of loop-carried values and defined values");
589 regionArgs.front().type = type;
590 for (
auto [iterArg, type] :
591 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
598 ForOp::ensureTerminator(*body, builder,
result.location);
607 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
608 operands,
result.types)) {
609 Type type = std::get<2>(argOperandType);
610 std::get<0>(argOperandType).type = type;
627 return getBody()->getArguments().drop_front(getNumInductionVars());
631 return getInitArgsMutable();
634FailureOr<LoopLikeOpInterface>
635ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
637 bool replaceInitOperandUsesInLoop,
642 auto inits = llvm::to_vector(getInitArgs());
643 inits.append(newInitOperands.begin(), newInitOperands.end());
644 scf::ForOp newLoop = scf::ForOp::create(
650 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
652 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
657 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
658 assert(newInitOperands.size() == newYieldedValues.size() &&
659 "expected as many new yield values as new iter operands");
661 yieldOp.getResultsMutable().append(newYieldedValues);
667 newLoop.getBody()->getArguments().take_front(
668 getBody()->getNumArguments()));
670 if (replaceInitOperandUsesInLoop) {
673 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
684 newLoop->getResults().take_front(getNumResults()));
685 return cast<LoopLikeOpInterface>(newLoop.getOperation());
689 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
692 assert(ivArg.getOwner() &&
"unlinked block argument");
693 auto *containingOp = ivArg.getOwner()->getParentOp();
694 return dyn_cast_or_null<ForOp>(containingOp);
698 return getInitArgs();
714LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
715 for (
auto [lb, ub, step] :
716 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
719 if (!tripCount.has_value() || *tripCount != 1)
728 return getBody()->getArguments().drop_front(getRank());
/// Returns the mutable init operands of the loop. For scf.forall these are
/// exactly the op's mutable `outputs` operands.
731MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
732 return getOutputsMutable();
738 scf::InParallelOp terminator = forallOp.getTerminator();
743 bbArgReplacements.append(forallOp.getOutputs().begin(),
744 forallOp.getOutputs().end());
748 forallOp->getIterator(), bbArgReplacements);
753 results.reserve(forallOp.getResults().size());
754 for (
auto &yieldingOp : terminator.getYieldingOps()) {
755 auto parallelInsertSliceOp =
756 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
757 if (!parallelInsertSliceOp)
760 Value dst = parallelInsertSliceOp.getDest();
761 Value src = parallelInsertSliceOp.getSource();
762 if (llvm::isa<TensorType>(src.
getType())) {
763 results.push_back(tensor::InsertSliceOp::create(
764 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
765 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
766 parallelInsertSliceOp.getStrides(),
767 parallelInsertSliceOp.getStaticOffsets(),
768 parallelInsertSliceOp.getStaticSizes(),
769 parallelInsertSliceOp.getStaticStrides()));
771 llvm_unreachable(
"unsupported terminator");
786 assert(lbs.size() == ubs.size() &&
787 "expected the same number of lower and upper bounds");
788 assert(lbs.size() == steps.size() &&
789 "expected the same number of lower bounds and steps");
794 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
796 assert(results.size() == iterArgs.size() &&
797 "loop nest body must return as many values as loop has iteration "
799 return LoopNest{{}, std::move(results)};
807 loops.reserve(lbs.size());
808 ivs.reserve(lbs.size());
811 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
812 auto loop = scf::ForOp::create(
813 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
819 currentIterArgs = args;
820 currentLoc = nestedLoc;
826 loops.push_back(loop);
830 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
832 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
839 ? bodyBuilder(builder, currentLoc, ivs,
840 loops.back().getRegionIterArgs())
842 assert(results.size() == iterArgs.size() &&
843 "loop nest body must return as many values as loop has iteration "
846 scf::YieldOp::create(builder, loc, results);
850 llvm::append_range(nestResults, loops.front().getResults());
851 return LoopNest{std::move(loops), std::move(nestResults)};
864 bodyBuilder(nestedBuilder, nestedLoc, ivs);
873 assert(operand.
getOwner() == forOp);
878 "expected an iter OpOperand");
880 "Expected a different type");
882 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
887 newIterOperands.push_back(opOperand.get());
891 scf::ForOp newForOp = scf::ForOp::create(
892 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
893 forOp.getStep(), newIterOperands,
nullptr,
894 forOp.getUnsignedCmp());
895 newForOp->setAttrs(forOp->getAttrs());
896 Block &newBlock = newForOp.getRegion().
front();
904 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
906 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
907 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
911 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
914 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
917 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
918 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
919 clonedYieldOp.getOperand(yieldIdx));
921 newYieldOperands[yieldIdx] = castOut;
922 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
923 rewriter.
eraseOp(clonedYieldOp);
928 newResults[yieldIdx] =
929 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
941 LogicalResult matchAndRewrite(ForOp op,
943 std::optional<APInt> tripCount = op.getStaticTripCount();
944 if (!tripCount.has_value())
946 "can't compute constant trip count");
948 if (tripCount->isZero()) {
949 LDBG() <<
"SimplifyTrivialLoops tripCount is 0 for loop "
951 rewriter.
replaceOp(op, op.getInitArgs());
955 if (tripCount->getSExtValue() == 1) {
956 LDBG() <<
"SimplifyTrivialLoops tripCount is 1 for loop "
959 blockArgs.reserve(op.getInitArgs().size() + 1);
960 blockArgs.push_back(op.getLowerBound());
961 llvm::append_range(blockArgs, op.getInitArgs());
968 if (!llvm::hasSingleElement(block))
972 if (llvm::any_of(op.getYieldedValues(),
973 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
975 LDBG() <<
"SimplifyTrivialLoops empty body loop allows replacement with "
976 "yield operands for loop "
977 << OpWithFlags(op, OpPrintingFlags().skipRegions());
978 rewriter.
replaceOp(op, op.getYieldedValues());
1009struct ForOpTensorCastFolder :
public OpRewritePattern<ForOp> {
1010 using OpRewritePattern<ForOp>::OpRewritePattern;
1012 LogicalResult matchAndRewrite(ForOp op,
1013 PatternRewriter &rewriter)
const override {
1014 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1015 OpOperand &iterOpOperand = std::get<0>(it);
1017 if (!incomingCast ||
1018 incomingCast.getSource().getType() == incomingCast.getType())
1023 incomingCast.getDest().getType(),
1024 incomingCast.getSource().getType()))
1026 if (!std::get<1>(it).hasOneUse())
1032 rewriter, op, iterOpOperand, incomingCast.getSource(),
1033 [](OpBuilder &
b, Location loc, Type type, Value source) {
1034 return tensor::CastOp::create(b, loc, type, source);
/// Populates `results` with the canonicalization patterns for scf.for. The
/// visible statement adds SimplifyTrivialLoops and ForOpTensorCastFolder,
/// both constructed with `context`.
/// NOTE(review): the trailing call that receives `results` and
/// `ForOp::getOperationName()` has its callee on a line not shown here —
/// confirm which helper it invokes.
1043void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1044 MLIRContext *context) {
1045 results.
add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
1047 results, ForOp::getOperationName());
1050std::optional<APInt> ForOp::getConstantStep() {
1053 return step.getValue();
/// Returns the mutable operands yielded by the loop body. The body terminator
/// is cast to scf::YieldOp (asserting via `cast`), and its mutable results
/// range is returned.
1057std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1058 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1064 if (
auto constantStep = getConstantStep())
1065 if (*constantStep == 1)
1073std::optional<APInt> ForOp::getStaticTripCount() {
1082LogicalResult ForallOp::verify() {
1083 unsigned numLoops = getRank();
1085 if (getNumResults() != getOutputs().size())
1087 << getNumResults() <<
" results, but has only "
1088 << getOutputs().size() <<
" outputs";
1091 auto *body = getBody();
1093 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1094 for (int64_t i = 0; i < numLoops; ++i)
1097 << i <<
"-th block argument to be an index";
1098 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1101 << i <<
"-th output and corresponding block argument";
1102 if (getMapping().has_value() && !getMapping()->empty()) {
1103 if (getDeviceMappingAttrs().size() != numLoops)
1104 return emitOpError() <<
"mapping attribute size must match op rank";
1105 if (
failed(getDeviceMaskingAttr()))
1107 <<
" supports at most one device masking attribute";
1111 Operation *op = getOperation();
1113 getStaticLowerBound(),
1114 getDynamicLowerBound())))
1117 getStaticUpperBound(),
1118 getDynamicUpperBound())))
1121 getStaticStep(), getDynamicStep())))
1127void ForallOp::print(OpAsmPrinter &p) {
1128 Operation *op = getOperation();
1129 p <<
" (" << getInductionVars();
1130 if (isNormalized()) {
1151 if (!getRegionOutArgs().empty())
1152 p <<
"-> (" << getResultTypes() <<
") ";
1153 p.printRegion(getRegion(),
1155 getNumResults() > 0);
1156 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1157 getStaticLowerBoundAttrName(),
1158 getStaticUpperBoundAttrName(),
1159 getStaticStepAttrName()});
1162ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1164 auto indexType =
b.getIndexType();
1169 SmallVector<OpAsmParser::Argument, 4> ivs;
1174 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1184 unsigned numLoops = ivs.size();
1185 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1186 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1215 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1216 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1219 if (outOperands.size() !=
result.types.size())
1221 "mismatch between out operands and types");
1230 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1231 std::unique_ptr<Region> region = std::make_unique<Region>();
1232 for (
auto &iv : ivs) {
1233 iv.type =
b.getIndexType();
1234 regionArgs.push_back(iv);
1236 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1237 auto &out = it.value();
1238 out.type =
result.types[it.index()];
1239 regionArgs.push_back(out);
1245 ForallOp::ensureTerminator(*region,
b,
result.location);
1246 result.addRegion(std::move(region));
1252 result.addAttribute(
"staticLowerBound", staticLbs);
1253 result.addAttribute(
"staticUpperBound", staticUbs);
1254 result.addAttribute(
"staticStep", staticSteps);
1255 result.addAttribute(
"operandSegmentSizes",
1257 {static_cast<int32_t>(dynamicLbs.size()),
1258 static_cast<int32_t>(dynamicUbs.size()),
1259 static_cast<int32_t>(dynamicSteps.size()),
1260 static_cast<int32_t>(outOperands.size())}));
1265void ForallOp::build(
1266 mlir::OpBuilder &
b, mlir::OperationState &
result,
1267 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1268 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1269 std::optional<ArrayAttr> mapping,
1271 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1272 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1277 result.addOperands(dynamicLbs);
1278 result.addOperands(dynamicUbs);
1279 result.addOperands(dynamicSteps);
1280 result.addOperands(outputs);
1283 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1284 b.getDenseI64ArrayAttr(staticLbs));
1285 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1286 b.getDenseI64ArrayAttr(staticUbs));
1287 result.addAttribute(getStaticStepAttrName(
result.name),
1288 b.getDenseI64ArrayAttr(staticSteps));
1290 "operandSegmentSizes",
1291 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1292 static_cast<int32_t>(dynamicUbs.size()),
1293 static_cast<int32_t>(dynamicSteps.size()),
1294 static_cast<int32_t>(outputs.size())}));
1295 if (mapping.has_value()) {
1296 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
1300 Region *bodyRegion =
result.addRegion();
1301 OpBuilder::InsertionGuard g(
b);
1302 b.createBlock(bodyRegion);
1307 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1308 SmallVector<Location>(staticLbs.size(),
result.location));
1311 SmallVector<Location>(outputs.size(),
result.location));
1313 b.setInsertionPointToStart(&bodyBlock);
1314 if (!bodyBuilderFn) {
1315 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
/// Convenience builder that takes only upper bounds. It creates one lower
/// bound of index attribute 0 and one step of index attribute 1 per entry in
/// `ubs`, then delegates to the builder overload that takes explicit lower
/// bounds, upper bounds and steps, forwarding `outputs`, `mapping` and
/// `bodyBuilderFn` unchanged.
/// NOTE(review): the declaration of the `bodyBuilderFn` parameter is on a
/// line not shown here — confirm its type against the header.
1322void ForallOp::build(
1323 mlir::OpBuilder &
b, mlir::OperationState &
result,
1324 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1325 std::optional<ArrayAttr> mapping,
1327 unsigned numLoops = ubs.size();
1328 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
1329 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1330 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
/// Returns true when all mixed lower bounds compare equal to 0 and all mixed
/// steps compare equal to 1, using the local `allEqual` helper.
/// NOTE(review): `intValue` is computed from each OpFoldResult on a line not
/// shown here — presumably a constant-integer extraction; confirm.
1334bool ForallOp::isNormalized() {
1335 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1336 return llvm::all_of(results, [&](OpFoldResult ofr) {
1338 return intValue.has_value() && intValue == val;
1341 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
/// Returns the terminator of the loop body as an InParallelOp. Uses `cast`,
/// so the body terminator is required to be an InParallelOp.
1344InParallelOp ForallOp::getTerminator() {
1345 return cast<InParallelOp>(getBody()->getTerminator());
/// Collects the users of `bbArg` that implement ParallelCombiningOpInterface
/// into `storeOps`. Users that do not implement the interface are not added.
1348SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1349 SmallVector<Operation *> storeOps;
1350 for (Operation *user : bbArg.
getUsers()) {
1351 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1352 storeOps.push_back(parallelOp);
1358SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1359 SmallVector<DeviceMappingAttrInterface> res;
1362 for (
auto attr : getMapping()->getValue()) {
1363 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1370FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1371 DeviceMaskingAttrInterface res;
1374 for (
auto attr : getMapping()->getValue()) {
1375 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
/// Reports whether the first device mapping attribute of this op is a linear
/// mapping, as answered by `isLinearMapping()` on that attribute.
/// NOTE(review): `front()` requires a non-empty list; the guard for the empty
/// case is presumably on the lines not shown here — confirm.
1384bool ForallOp::usesLinearMapping() {
1385 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1388 return ifaces.front().isLinearMapping();
/// Returns the loop induction variables: the first `getRank()` block
/// arguments of the body, copied into a SmallVector<Value>.
1391std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1392 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
/// Returns the loop lower bounds as OpFoldResults, produced by
/// `getMixedValues` from the static and dynamic lower-bound operands.
/// NOTE(review): `b` is declared on a line not shown here — presumably a
/// Builder for this op's context; confirm.
1396std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1398 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
/// Returns the loop upper bounds as OpFoldResults, produced by
/// `getMixedValues` from the static and dynamic upper-bound operands.
/// NOTE(review): `b` is declared on a line not shown here — presumably a
/// Builder for this op's context; confirm.
1402std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1404 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1408std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1414 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1417 assert(tidxArg.getOwner() &&
"unlinked block argument");
1418 auto *containingOp = tidxArg.getOwner()->getParentOp();
1419 return dyn_cast<ForallOp>(containingOp);
1427 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1429 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1433 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1436 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1441class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1443 using OpRewritePattern<ForallOp>::OpRewritePattern;
1445 LogicalResult matchAndRewrite(ForallOp op,
1446 PatternRewriter &rewriter)
const override {
1447 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1448 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1449 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1456 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1457 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1460 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1461 op.setStaticLowerBound(staticLowerBound);
1465 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1466 op.setStaticUpperBound(staticUpperBound);
1469 op.getDynamicStepMutable().assign(dynamicStep);
1470 op.setStaticStep(staticStep);
1472 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1474 {static_cast<int32_t>(dynamicLowerBound.size()),
1475 static_cast<int32_t>(dynamicUpperBound.size()),
1476 static_cast<int32_t>(dynamicStep.size()),
1477 static_cast<int32_t>(op.getNumResults())}));
1556struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1557 using OpRewritePattern<ForallOp>::OpRewritePattern;
1559 LogicalResult matchAndRewrite(ForallOp forallOp,
1560 PatternRewriter &rewriter)
const final {
1571 SmallVector<Value> resultsToDelete;
1572 SmallVector<Value> outsToDelete;
1573 SmallVector<BlockArgument> blockArgsToDelete;
1574 SmallVector<Value> newOuts;
1575 BitVector resultIndicesToDelete(forallOp.getNumResults(),
false);
1576 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
1578 for (OpResult
result : forallOp.getResults()) {
1579 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1580 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1581 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1582 resultsToDelete.push_back(
result);
1583 outsToDelete.push_back(opOperand->
get());
1584 blockArgsToDelete.push_back(blockArg);
1585 resultIndicesToDelete[
result.getResultNumber()] =
true;
1588 newOuts.push_back(opOperand->
get());
1594 if (resultsToDelete.empty())
1599 for (
auto blockArg : blockArgsToDelete) {
1600 SmallVector<Operation *> combiningOps =
1601 forallOp.getCombiningOps(blockArg);
1602 for (Operation *combiningOp : combiningOps)
1603 rewriter.
eraseOp(combiningOp);
1605 for (
auto [blockArg,
result, out] :
1606 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1612 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
1617 auto newForallOp = cast<scf::ForallOp>(
1619 newForallOp.getOutputsMutable().assign(newOuts);
1625struct ForallOpSingleOrZeroIterationDimsFolder
1626 :
public OpRewritePattern<ForallOp> {
1627 using OpRewritePattern<ForallOp>::OpRewritePattern;
1629 LogicalResult matchAndRewrite(ForallOp op,
1630 PatternRewriter &rewriter)
const override {
1632 if (op.getMapping().has_value() && !op.getMapping()->empty())
1634 Location loc = op.getLoc();
1637 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1640 for (
auto [lb, ub, step, iv] :
1641 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1642 op.getMixedStep(), op.getInductionVars())) {
1643 auto numIterations =
1645 if (numIterations.has_value()) {
1647 if (*numIterations == 0) {
1648 rewriter.
replaceOp(op, op.getOutputs());
1653 if (*numIterations == 1) {
1658 newMixedLowerBounds.push_back(lb);
1659 newMixedUpperBounds.push_back(ub);
1660 newMixedSteps.push_back(step);
1664 if (newMixedLowerBounds.empty()) {
1670 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1672 op,
"no dimensions have 0 or 1 iterations");
1677 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1678 newMixedUpperBounds, newMixedSteps,
1679 op.getOutputs(), std::nullopt,
nullptr);
1680 newOp.getBodyRegion().getBlocks().clear();
1684 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1685 newOp.getStaticLowerBoundAttrName(),
1686 newOp.getStaticUpperBoundAttrName(),
1687 newOp.getStaticStepAttrName()};
1688 for (
const auto &namedAttr : op->getAttrs()) {
1689 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1692 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1696 newOp.getRegion().begin(), mapping);
1697 rewriter.
replaceOp(op, newOp.getResults());
1703struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1704 using OpRewritePattern<ForallOp>::OpRewritePattern;
1706 LogicalResult matchAndRewrite(ForallOp op,
1707 PatternRewriter &rewriter)
const override {
1708 Location loc = op.getLoc();
1710 for (
auto [lb, ub, step, iv] :
1711 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1712 op.getMixedStep(), op.getInductionVars())) {
1715 auto numIterations =
1717 if (!numIterations.has_value() || numIterations.value() != 1) {
1728struct FoldTensorCastOfOutputIntoForallOp
1729 :
public OpRewritePattern<scf::ForallOp> {
1730 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1737 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1738 PatternRewriter &rewriter)
const final {
1739 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1740 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1741 for (
auto en : llvm::enumerate(newOutputTensors)) {
1742 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1749 castOp.getSource().getType())) {
1753 tensorCastProducers[en.index()] =
1754 TypeCast{castOp.getSource().getType(), castOp.getType()};
1755 newOutputTensors[en.index()] = castOp.getSource();
1758 if (tensorCastProducers.empty())
1762 Location loc = forallOp.getLoc();
1763 auto newForallOp = ForallOp::create(
1764 rewriter, loc, forallOp.getMixedLowerBound(),
1765 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1766 newOutputTensors, forallOp.getMapping(),
1767 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
1768 auto castBlockArgs =
1769 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1770 for (auto [index, cast] : tensorCastProducers) {
1771 Value &oldTypeBBArg = castBlockArgs[index];
1772 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1773 cast.dstType, oldTypeBBArg);
1777 SmallVector<Value> ivsBlockArgs =
1778 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1779 ivsBlockArgs.append(castBlockArgs);
1781 bbArgs.front().getParentBlock(), ivsBlockArgs);
1787 auto terminator = newForallOp.getTerminator();
1788 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1789 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1790 if (
auto parallelCombingingOp =
1791 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1792 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1798 SmallVector<Value> castResults = newForallOp.getResults();
1799 for (
auto &item : tensorCastProducers) {
1800 Value &oldTypeResult = castResults[item.first];
1801 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1804 rewriter.
replaceOp(forallOp, castResults);
/// Populates `results` with the canonicalization patterns for scf.forall.
/// All six listed patterns are added in a single `add<...>` call and are
/// constructed with `context`.
1811void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1812 MLIRContext *context) {
1813 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1814 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1815 ForallOpSingleOrZeroIterationDimsFolder,
1816 ForallOpReplaceConstantInductionVar>(context);
1819void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1820 SmallVectorImpl<RegionSuccessor> ®ions) {
1827 regions.push_back(RegionSuccessor(&getRegion()));
1831 ResultRange{getResults().end(), getResults().end()}));
1837 ResultRange{getResults().end(), getResults().end()}));
/// Builds an InParallelOp with a single region containing one newly created
/// block. An insertion guard is taken on `b` before the block is created.
1846void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
1847 OpBuilder::InsertionGuard g(
b);
1848 Region *bodyRegion =
result.addRegion();
1849 b.createBlock(bodyRegion);
/// Verifier for InParallelOp. The visible diagnostics are:
///  - "expected forall op parent", tied to the dyn_cast of the parent op to
///    scf::ForallOp;
///  - "expected only ParallelCombiningOpInterface", for an op in the region's
///    first block that does not implement that interface;
///  - "may only insert into an output block argument", for an updated
///    destination that is not one of the parent's region out args.
/// NOTE(review): the condition guarding the first diagnostic is on a line not
/// shown here — presumably a null check on `forallOp`; confirm.
1852LogicalResult InParallelOp::verify() {
1853 scf::ForallOp forallOp =
1854 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1856 return this->
emitOpError(
"expected forall op parent");
// Inspect every operation in the first block of this op's region.
1858 for (Operation &op : getRegion().front().getOperations()) {
1859 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1860 if (!parallelCombiningOp) {
1861 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
1866 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1867 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1868 for (OpOperand &dest : dests) {
1869 if (!llvm::is_contained(regionOutArgs, dest.get()))
1870 return op.emitOpError(
"may only insert into an output block argument");
1877void InParallelOp::print(OpAsmPrinter &p) {
1885ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
1888 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1889 std::unique_ptr<Region> region = std::make_unique<Region>();
1893 if (region->empty())
1894 OpBuilder(builder.
getContext()).createBlock(region.get());
1895 result.addRegion(std::move(region));
/// Returns result number `idx` of the operation that contains this op.
1903OpResult InParallelOp::getParentResult(int64_t idx) {
1904 return getOperation()->getParentOp()->getResult(idx);
/// Gathers the updated destinations of the yielding ops that implement
/// ParallelCombiningOpInterface. Each destination operand's value is cast to
/// BlockArgument (asserting via `cast`) and appended to the returned vector.
/// NOTE(review): the statement taken when an op does not implement the
/// interface is on a line not shown here — presumably it skips that op;
/// confirm.
1907SmallVector<BlockArgument> InParallelOp::getDests() {
1908 SmallVector<BlockArgument> updatedDests;
1909 for (Operation &yieldingOp : getYieldingOps()) {
1910 auto parallelCombiningOp =
1911 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1912 if (!parallelCombiningOp)
1914 for (OpOperand &updatedOperand :
1915 parallelCombiningOp.getUpdatedDestinations())
1916 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1918 return updatedDests;
/// Returns the range of operations contained in the first block of this op's
/// region.
1921llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1922 return getRegion().front().getOperations();
1930 assert(a &&
"expected non-empty operation");
1931 assert(
b &&
"expected non-empty operation");
1936 if (ifOp->isProperAncestor(
b))
1939 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1940 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
1942 ifOp = ifOp->getParentOfType<IfOp>();
1950IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1951 IfOp::Adaptor adaptor,
1953 if (adaptor.getRegions().empty())
1955 Region *r = &adaptor.getThenRegion();
1961 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
1964 TypeRange types = yieldOp.getOperandTypes();
1965 llvm::append_range(inferredReturnTypes, types);
1971 return build(builder,
result, resultTypes, cond,
false,
// Builds an scf.if with explicit result types and a condition operand. Both a
// 'then' and an 'else' region are added to `result`.
// Asking for an else block without a then block is rejected by the assertion.
// NOTE(review): the lines that act on addThenBlock/addElseBlock are missing
// from this extract.
1975void IfOp::build(OpBuilder &builder, OperationState &
result,
1976 TypeRange resultTypes, Value cond,
bool addThenBlock,
1977 bool addElseBlock) {
1978 assert((!addElseBlock || addThenBlock) &&
1979 "must not create else block w/o then block");
1980 result.addTypes(resultTypes);
1981 result.addOperands(cond);
// The guard restores the builder's insertion point when it goes out of scope.
1984 OpBuilder::InsertionGuard guard(builder);
1985 Region *thenRegion =
result.addRegion();
1988 Region *elseRegion =
result.addRegion();
1993void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
1994 bool withElseRegion) {
// Builds an scf.if with explicit result types; `withElseRegion` selects
// whether the else region is populated. Both regions are always added to
// `result`.
// NOTE(review): the block-creation lines are missing from this extract.
1998void IfOp::build(OpBuilder &builder, OperationState &
result,
1999 TypeRange resultTypes, Value cond,
bool withElseRegion) {
2000 result.addTypes(resultTypes);
2001 result.addOperands(cond);
2004 OpBuilder::InsertionGuard guard(builder);
2005 Region *thenRegion =
result.addRegion();
// A terminator is only ensured when the op has no results.
2007 if (resultTypes.empty())
2008 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
2011 Region *elseRegion =
result.addRegion();
2012 if (withElseRegion) {
2014 if (resultTypes.empty())
2015 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
// Builds an scf.if whose region bodies are produced by callbacks. The 'then'
// callback is mandatory (asserted). The result types are not passed in; they
// are obtained from inferReturnTypes and added when inference succeeds.
// NOTE(review): the `thenBuilder` parameter declaration, the block creation
// and the condition guarding the `elseBuilder` call are on listing lines
// missing from this extract.
2019void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2021 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
2022 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2023 result.addOperands(cond);
2026 OpBuilder::InsertionGuard guard(builder);
2027 Region *thenRegion =
result.addRegion();
2029 thenBuilder(builder,
result.location);
2032 Region *elseRegion =
result.addRegion();
2035 elseBuilder(builder,
result.location);
// Infer the result types from the operands and attributes gathered so far.
2039 SmallVector<Type> inferredReturnTypes;
2041 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
2042 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2044 inferredReturnTypes))) {
2045 result.addTypes(inferredReturnTypes);
// Verifier: an scf.if that defines results but has an empty else region is
// rejected with the diagnostic below.
// NOTE(review): the remainder of the function is missing from this extract.
2049LogicalResult IfOp::verify() {
2050 if (getNumResults() != 0 && getElseRegion().empty())
2051 return emitOpError(
"must have an else block if defining values");
// Custom parser for scf.if. Space for two regions is reserved and the 'then'
// and 'else' regions are added to `result` before anything else visible here;
// `cond` holds the not-yet-resolved condition operand.
// NOTE(review): the actual parsing steps are missing from this extract.
2055ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
2057 result.regions.reserve(2);
2058 Region *thenRegion =
result.addRegion();
2059 Region *elseRegion =
result.addRegion();
2062 OpAsmParser::UnresolvedOperand cond;
// Custom printer for scf.if: prints the condition and, when the op has
// results, ` -> (<result types>)`.
// Block terminators are only requested for printing when the op has results.
// NOTE(review): the region-printing calls themselves are on listing lines
// missing from this extract; only their trailing argument is visible.
2088void IfOp::print(OpAsmPrinter &p) {
2089 bool printBlockTerminators =
false;
2091 p <<
" " << getCondition();
2092 if (!getResults().empty()) {
2093 p <<
" -> (" << getResultTypes() <<
")";
2095 printBlockTerminators =
true;
2100 printBlockTerminators);
// The else region gets its own handling only when it is non-empty.
2103 auto &elseRegion = getElseRegion();
2104 if (!elseRegion.
empty()) {
2108 printBlockTerminators);
// Reports the regions control flow may enter: the 'then' region and the
// 'else' region are both pushed as successors on the visible lines.
// NOTE(review): the parameter name reads `®ions` in this extract while the
// body uses `regions`; this looks like a mangled `&regions` -- confirm against
// the original file. The handling of `point` and of an empty else region is on
// listing lines missing from this extract.
2114void IfOp::getSuccessorRegions(RegionBranchPoint point,
2115 SmallVectorImpl<RegionSuccessor> ®ions) {
2123 regions.push_back(RegionSuccessor(&getThenRegion()));
2126 Region *elseRegion = &this->getElseRegion();
2127 if (elseRegion->
empty())
2130 regions.push_back(RegionSuccessor(elseRegion));
// Narrows the entry successors using a constant condition when one is known:
// the 'then' region is a successor unless the condition is a known `false`;
// the 'else' side is considered unless the condition is a known `true`, and
// the else region itself is only added when it is non-empty.
// NOTE(review): `®ions` looks like a mangled `&regions` (the body uses
// `regions`); the branch taken for an empty else region is missing from this
// extract.
2133void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2134 SmallVectorImpl<RegionSuccessor> ®ions) {
2135 FoldAdaptor adaptor(operands, *
this);
// A null attribute means the condition is not a known constant.
2136 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2137 if (!boolAttr || boolAttr.getValue())
2138 regions.emplace_back(&getThenRegion());
2141 if (!boolAttr || !boolAttr.getValue()) {
2142 if (!getElseRegion().empty())
2143 regions.emplace_back(&getElseRegion());
// Folder working on an scf.if whose condition is produced by arith.xori: the
// condition is replaced by the xor's left-hand side and the blocks of the
// 'then' and 'else' regions are exchanged.
// NOTE(review): the checks performed on the xor (listing lines 2156-2161) and
// all return statements are missing from this extract.
2149LogicalResult IfOp::fold(FoldAdaptor adaptor,
2150 SmallVectorImpl<OpFoldResult> &results) {
2152 if (getElseRegion().empty())
2155 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2162 getConditionMutable().assign(xorStmt.getLhs());
// Remember the original then block so it can be told apart after the splice.
2163 Block *thenBlock = &getThenRegion().front();
// Move every else block to the front of the then region, then move the
// original then block into the else region.
2166 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2167 getElseRegion().getBlocks());
2168 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2169 getThenRegion().getBlocks(), thenBlock);
// Reports how many times each region may be invoked. With a constant boolean
// condition, two bounds are emitted: the first has upper bound 1 when the
// condition is true (0 otherwise), the second the opposite. Without a constant
// condition both entries are {0, 1}.
// NOTE(review): presumably the first entry is the 'then' region and the second
// the 'else' region -- confirm.
2173void IfOp::getRegionInvocationBounds(
2174 ArrayRef<Attribute> operands,
2175 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2176 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2179 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2180 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2183 invocationBounds.assign(2, {0, 1});
// Rewrite pattern on scf.if (inherits the OpRewritePattern<IfOp>
// constructors).
// NOTE(review): nearly all of matchAndRewrite is missing from this extract;
// the only visible logic is a branch taken when the else region is non-empty.
2188struct RemoveStaticCondition :
public OpRewritePattern<IfOp> {
2189 using OpRewritePattern<IfOp>::OpRewritePattern;
2191 LogicalResult matchAndRewrite(IfOp op,
2192 PatternRewriter &rewriter)
const override {
2199 else if (!op.getElseRegion().empty())
// Rewrite pattern on scf.if that replaces results with arith.select where the
// yielded values allow it. Yield pairs where either value lives in the
// 'then'/'else' region are counted as non-hoistable and stay results of a
// replacement scf.if.
// NOTE(review): return statements, some conditions and the tail of the rewrite
// are on listing lines missing from this extract.
2210struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2211 using OpRewritePattern<IfOp>::OpRewritePattern;
2213 LogicalResult matchAndRewrite(IfOp op,
2214 PatternRewriter &rewriter)
const override {
2215 if (op->getNumResults() == 0)
2218 auto cond = op.getCondition();
2219 auto thenYieldArgs = op.thenYield().getOperands();
2220 auto elseYieldArgs = op.elseYield().getOperands();
// Collect the types of yield pairs defined inside either branch region.
2222 SmallVector<Type> nonHoistable;
2223 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2224 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2225 &op.getElseRegion() == falseVal.getParentRegion())
2226 nonHoistable.push_back(trueVal.getType());
// Every result is non-hoistable when the counts match.
2230 if (nonHoistable.size() == op->getNumResults())
2233 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
// The replacement takes over the bodies of the original regions.
2237 replacement.getThenRegion().takeBody(op.getThenRegion());
2238 replacement.getElseRegion().takeBody(op.getElseRegion());
2240 SmallVector<Value> results(op->getNumResults());
2241 assert(thenYieldArgs.size() == results.size());
2242 assert(elseYieldArgs.size() == results.size());
2244 SmallVector<Value> trueYields;
2245 SmallVector<Value> falseYields;
// Per result: keep it on the replacement op, reuse the value yielded by both
// branches when they agree, or create an arith.select on the condition.
2247 for (
const auto &it :
2248 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2249 Value trueVal = std::get<0>(it.value());
2250 Value falseVal = std::get<1>(it.value());
2253 results[it.index()] =
replacement.getResult(trueYields.size());
2254 trueYields.push_back(trueVal);
2255 falseYields.push_back(falseVal);
2256 }
else if (trueVal == falseVal)
2257 results[it.index()] = trueVal;
2259 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2260 cond, trueVal, falseVal);
// Rewrite pattern on scf.if that replaces uses of the condition value with
// constants: a use classified as Parent::Then is set to `constantTrue`, a use
// classified as Parent::Else to `constantFalse`.
// NOTE(review): several listing lines (the `cache` parameter declaration, the
// step that advances `toCheck`, the switch header and the constant operands)
// are missing from this extract.
2287struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2288 using OpRewritePattern<IfOp>::OpRewritePattern;
// Classification of a region relative to the scf.if being rewritten.
2291 enum class Parent { Then, Else,
None };
// Classifies `toCheck` as being the then region, the else region, or neither,
// stopping at `endRegion`. Every region visited on the way is recorded in
// `cache` with the final answer so later queries can return early.
2296 static Parent getParentType(Region *toCheck, IfOp op,
2298 Region *endRegion) {
2299 SmallVector<Region *> seen;
2300 while (toCheck != endRegion) {
2301 auto found = cache.find(toCheck);
2302 if (found != cache.end())
2303 return found->second;
2304 seen.push_back(toCheck);
2305 if (&op.getThenRegion() == toCheck) {
2306 for (Region *region : seen)
2307 cache[region] = Parent::Then;
2308 return Parent::Then;
2310 if (&op.getElseRegion() == toCheck) {
2311 for (Region *region : seen)
2312 cache[region] = Parent::Else;
2313 return Parent::Else;
2318 for (Region *region : seen)
2319 cache[region] = Parent::None;
2320 return Parent::None;
2323 LogicalResult matchAndRewrite(IfOp op,
2324 PatternRewriter &rewriter)
const override {
// The constants start out null and are assigned on the visible create calls.
2335 Value constantTrue =
nullptr;
2336 Value constantFalse =
nullptr;
// Early-increment iteration: the loop body rewrites the use being visited.
2339 for (OpOperand &use :
2340 llvm::make_early_inc_range(op.getCondition().getUses())) {
2343 case Parent::Then: {
2347 constantTrue = arith::ConstantOp::create(
2351 [&]() { use.set(constantTrue); });
2354 case Parent::Else: {
2358 constantFalse = arith::ConstantOp::create(
2362 [&]() { use.set(constantFalse); });
// Rewrite pattern on scf.if that replaces uses of a result when it can be
// expressed without the op:
//  * both branches yield the same value -> that value;
//  * branches yield constant (false, true) -> an arith.xori built from the
//    condition;
//  * branches yield constant (true, false) -> the condition itself.
// NOTE(review): the constant matching, the xor's second operand and all return
// statements are on listing lines missing from this extract.
2410struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2411 using OpRewritePattern<IfOp>::OpRewritePattern;
2413 LogicalResult matchAndRewrite(IfOp op,
2414 PatternRewriter &rewriter)
const override {
2416 if (op.getNumResults() == 0)
2420 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2422 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2425 op.getOperation()->getIterator());
2428 for (
auto [trueResult, falseResult, opResult] :
2429 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2431 if (trueResult == falseResult) {
2432 if (!opResult.use_empty()) {
2433 opResult.replaceAllUsesWith(trueResult);
// These BoolAttr locals reuse the names `trueYield`/`falseYield` that the
// zip on listing lines 2428-2429 also refers to.
2439 BoolAttr trueYield, falseYield;
2444 bool trueVal = trueYield.
getValue();
2445 bool falseVal = falseYield.
getValue();
2446 if (!trueVal && falseVal) {
2447 if (!opResult.use_empty()) {
2448 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2449 Value notCond = arith::XOrIOp::create(
2450 rewriter, op.getLoc(), op.getCondition(),
2456 opResult.replaceAllUsesWith(notCond);
2460 if (trueVal && !falseVal) {
2461 if (!opResult.use_empty()) {
2462 opResult.replaceAllUsesWith(op.getCondition());
// Rewrite pattern that merges an scf.if (`nextIf`) with the scf.if directly
// preceding it in the same block (`prevIf`) when their conditions are the same
// value or one is an arith.xori of the other. The combined op carries the
// result types of both ops and yields the concatenation of their yielded
// values.
// NOTE(review): many listing lines (early returns, the second half of the xor
// checks, region inlining calls, final replacement) are missing from this
// extract.
2492struct CombineIfs :
public OpRewritePattern<IfOp> {
2493 using OpRewritePattern<IfOp>::OpRewritePattern;
2495 LogicalResult matchAndRewrite(IfOp nextIf,
2496 PatternRewriter &rewriter)
const override {
// There is no previous op when `nextIf` is the first op of its block.
2497 Block *parent = nextIf->getBlock();
2498 if (nextIf == &parent->
front())
2501 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
// Decide which block of `nextIf` lines up with prevIf's then/else branch.
2509 Block *nextThen =
nullptr;
2510 Block *nextElse =
nullptr;
2511 if (nextIf.getCondition() == prevIf.getCondition()) {
2512 nextThen = nextIf.thenBlock();
2513 if (!nextIf.getElseRegion().empty())
2514 nextElse = nextIf.elseBlock();
// Xor-related conditions: the roles of nextIf's then and else blocks are
// swapped.
2516 if (arith::XOrIOp notv =
2517 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2518 if (notv.getLhs() == prevIf.getCondition() &&
2520 nextElse = nextIf.thenBlock();
2521 if (!nextIf.getElseRegion().empty())
2522 nextThen = nextIf.elseBlock();
2525 if (arith::XOrIOp notv =
2526 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2527 if (notv.getLhs() == nextIf.getCondition() &&
2529 nextElse = nextIf.thenBlock();
2530 if (!nextIf.getElseRegion().empty())
2531 nextThen = nextIf.elseBlock();
// Neither block could be matched up.
2535 if (!nextThen && !nextElse)
2538 SmallVector<Value> prevElseYielded;
2539 if (!prevIf.getElseRegion().empty())
2540 prevElseYielded = prevIf.elseYield().getOperands();
// Redirect uses of prevIf's results to the value yielded by the matching
// branch of prevIf.
2543 for (
auto it : llvm::zip(prevIf.getResults(),
2544 prevIf.thenYield().getOperands(), prevElseYielded))
2545 for (OpOperand &use :
2546 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2550 use.
set(std::get<1>(it));
2555 use.
set(std::get<2>(it));
// Result types of the combined op: prevIf's followed by nextIf's.
2560 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2561 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2563 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2564 prevIf.getCondition(),
false);
2565 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2568 combinedIf.getThenRegion(),
2569 combinedIf.getThenRegion().begin());
// Merge the then blocks and yield the operands of both original yields.
2572 YieldOp thenYield = combinedIf.thenYield();
2573 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2574 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2577 SmallVector<Value> mergedYields(thenYield.getOperands());
2578 llvm::append_range(mergedYields, thenYield2.getOperands());
2579 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2585 combinedIf.getElseRegion(),
2586 combinedIf.getElseRegion().begin());
2589 if (combinedIf.getElseRegion().empty()) {
2591 combinedIf.getElseRegion(),
2592 combinedIf.getElseRegion().
begin());
// Same merge for the else side.
2594 YieldOp elseYield = combinedIf.elseYield();
2595 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2596 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2600 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2601 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2603 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
// Split the combined results back into the part that belonged to prevIf and
// the remaining part.
2609 SmallVector<Value> prevValues;
2610 SmallVector<Value> nextValues;
2611 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2612 if (pair.index() < prevIf.getNumResults())
2613 prevValues.push_back(pair.value());
2615 nextValues.push_back(pair.value());
// Rewrite pattern on scf.if without results whose else block holds exactly one
// operation; ops with results, without an else block, or with more ops in the
// else block take the early-exit branches.
// NOTE(review): the early-exit bodies and the creation of `newIfOp` are on
// listing lines missing from this extract.
2624struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2625 using OpRewritePattern<IfOp>::OpRewritePattern;
2627 LogicalResult matchAndRewrite(IfOp ifOp,
2628 PatternRewriter &rewriter)
const override {
2630 if (ifOp.getNumResults())
2632 Block *elseBlock = ifOp.elseBlock();
2633 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2637 newIfOp.getThenRegion().begin());
// Rewrite pattern that fuses an scf.if with an scf.if nested in its 'then'
// block: the new op's condition is the arith.andi of both conditions and it
// has the outer op's result types.
// Visible preconditions: the outer then block holds exactly one non-terminator
// op, and any else block (outer or nested) holds a single operation.
// NOTE(review): early returns, several conditions and the final replacement
// are on listing lines missing from this extract.
2659struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2660 using OpRewritePattern<IfOp>::OpRewritePattern;
2662 LogicalResult matchAndRewrite(IfOp op,
2663 PatternRewriter &rewriter)
const override {
2664 auto nestedOps = op.thenBlock()->without_terminator();
2666 if (!llvm::hasSingleElement(nestedOps))
2670 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2673 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2677 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2680 SmallVector<Value> thenYield(op.thenYield().getOperands());
2681 SmallVector<Value> elseYield;
2683 llvm::append_range(elseYield, op.elseYield().getOperands());
// Indices of results that will be produced through an arith.select.
2687 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2696 for (
const auto &tup : llvm::enumerate(thenYield)) {
// Value yielded by the outer then block is a result of the nested if.
2697 if (tup.value().getDefiningOp() == nestedIf) {
2698 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2699 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2700 elseYield[tup.index()]) {
// Forward the value the nested if yields from its own then block.
2705 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2718 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2721 elseYieldsToUpgradeToSelect.push_back(tup.index());
2724 Location loc = op.getLoc();
2725 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2726 nestedIf.getCondition());
2727 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2730 SmallVector<Value> results;
2731 llvm::append_range(results, newIf.getResults());
// The select is keyed on the outer condition.
2734 for (
auto idx : elseYieldsToUpgradeToSelect)
2736 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2737 thenYield[idx], elseYield[idx]);
2739 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2742 if (!elseYield.empty()) {
2745 YieldOp::create(rewriter, loc, elseYield);
// Registers the canonicalization patterns of scf.if listed below on `results`.
// NOTE(review): the argument of `add` and the call that ends on listing line
// 2761 are only partly visible in this extract.
2754void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2755 MLIRContext *context) {
2756 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2757 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2758 RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
2761 results, IfOp::getOperationName());
2764Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
2765YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
// Returns a block of the 'else' region.
// NOTE(review): everything after fetching the region is missing from this
// extract; other code in this file tests the result against null, so
// presumably an empty else region yields nullptr -- confirm.
2766Block *IfOp::elseBlock() {
2767 Region &r = getElseRegion();
2772YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
// Builds an scf.parallel. Operands are added in the order lower bounds, upper
// bounds, steps, initial values, and the operand-segment-size attribute
// records the size of each group. The body gets one index-typed argument per
// step.
// NOTE(review): the parameter list, the block creation and the arguments
// passed to `bodyBuilderFn` are on listing lines missing from this extract.
2778void ParallelOp::build(
2783 result.addOperands(lowerBounds);
2784 result.addOperands(upperBounds);
2785 result.addOperands(steps);
2786 result.addOperands(initVals);
2788 ParallelOp::getOperandSegmentSizeAttr(),
2790 static_cast<int32_t>(upperBounds.size()),
2791 static_cast<int32_t>(steps.size()),
2792 static_cast<int32_t>(initVals.size())}));
2795 OpBuilder::InsertionGuard guard(builder);
// One induction variable per step value, each of index type and located at
// the op's location.
2796 unsigned numIVs = steps.size();
2797 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
2798 SmallVector<Location, 8> argLocs(numIVs,
result.location);
2799 Region *bodyRegion =
result.addRegion();
2802 if (bodyBuilderFn) {
2804 bodyBuilderFn(builder,
result.location,
// A terminator is only ensured when there are no initial values.
2809 if (initVals.empty())
2810 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
// Convenience builder that adapts a body callback taking (builder, location,
// ivs) through a wrapper lambda.
// NOTE(review): the parameter list, the lambda's remaining parameters and the
// delegating call are on listing lines missing from this extract.
2813void ParallelOp::build(
2820 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2823 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2827 wrapper = wrappedBuilderFn;
// Verifier for scf.parallel. The visible diagnostics cover:
//  * at least one (lowerBound, upperBound, step) tuple;
//  * constant steps must be positive;
//  * the body has as many induction variables as bound/step values, all of
//    index type;
//  * the body terminates with scf.reduce;
//  * results, reductions and initial values agree in number;
//  * each reduction operand type equals the corresponding result type.
// NOTE(review): several conditions and statements are on listing lines missing
// from this extract.
2833LogicalResult ParallelOp::verify() {
2838 if (stepValues.empty())
2840 "needs at least one tuple element for lowerBound, upperBound and step");
2843 for (Value stepValue : stepValues)
2846 return emitOpError(
"constant step operand must be positive");
2850 Block *body = getBody();
2852 return emitOpError() <<
"expects the same number of induction variables: "
2854 <<
" as bound and step values: " << stepValues.size();
2856 if (!arg.getType().isIndex())
2858 "expects arguments for the induction variable to be of index type");
2862 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
// Cross-check the counts of results, reductions and initial values.
2867 auto resultsSize = getResults().size();
2868 auto reductionsSize = reduceOp.getReductions().size();
2869 auto initValsSize = getInitVals().size();
2870 if (resultsSize != reductionsSize)
2871 return emitOpError() <<
"expects number of results: " << resultsSize
2872 <<
" to be the same as number of reductions: "
2874 if (resultsSize != initValsSize)
2875 return emitOpError() <<
"expects number of results: " << resultsSize
2876 <<
" to be the same as number of initial values: "
2878 if (reduceOp.getNumOperands() != initValsSize)
// Type of the i-th reduction operand must match the i-th result type.
2883 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2884 auto resultType = getOperation()->getResult(i).getType();
2885 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2886 if (resultType != reductionOperandType)
2887 return reduceOp.emitOpError()
2888 <<
"expects type of " << i
2889 <<
"-th reduction operand: " << reductionOperandType
2890 <<
" to be the same as the " << i
2891 <<
"-th result type: " << resultType;
// Custom parser for scf.parallel. Declares storage for the induction variable
// arguments and for the lower-bound, upper-bound, step and init-value
// operands, adds the body region, records the operand segment sizes and
// ensures the body has a terminator.
// NOTE(review): the parser calls themselves are on listing lines missing from
// this extract.
2896ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2899 SmallVector<OpAsmParser::Argument, 4> ivs;
2904 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2911 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2919 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2927 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2938 Region *body =
result.addRegion();
2939 for (
auto &iv : ivs)
2946 ParallelOp::getOperandSegmentSizeAttr(),
2948 static_cast<int32_t>(upper.size()),
2949 static_cast<int32_t>(steps.size()),
2950 static_cast<int32_t>(initVals.size())}));
2959 ParallelOp::ensureTerminator(*body, builder,
result.location);
// Custom printer for scf.parallel: emits
// ` (<ivs>) = (<lower bounds>) to (<upper bounds>) step (<steps>)` followed by
// ` init (<init values>)` when initial values are present.
// NOTE(review): the calls that end on listing lines 2972-2973 (which pass the
// op's attributes and the operand-segment-size attribute name) are only
// partly visible in this extract.
2963void ParallelOp::print(OpAsmPrinter &p) {
2964 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2965 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2966 if (!getInitVals().empty())
2967 p <<
" init (" << getInitVals() <<
")";
2972 (*this)->getAttrs(),
2973 ParallelOp::getOperandSegmentSizeAttr());
2976SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
// Returns the induction variables, i.e. the arguments of the body block,
// copied into a vector of values.
2978std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2979 return SmallVector<Value>{getBody()->getArguments()};
2982std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
2986std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
2990std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
2995 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2997 return ParallelOp();
2998 assert(ivArg.getOwner() &&
"unlinked block argument");
2999 auto *containingOp = ivArg.getOwner()->getParentOp();
3000 return dyn_cast<ParallelOp>(containingOp);
// Rewrite pattern on scf.parallel that looks at the iteration count of every
// dimension:
//  * a dimension with zero iterations replaces the whole op by its init
//    values;
//  * dimensions that are kept contribute to new bound/step lists;
//  * when no dimension remains, the body and the reduction bodies are cloned
//    in place of the loop;
//  * otherwise a new scf.parallel over the remaining dimensions is created.
// NOTE(review): the base class, how `numIterations` is computed, the
// single-iteration handling and all returns are on listing lines missing from
// this extract.
3005struct ParallelOpSingleOrZeroIterationDimsFolder
3009 LogicalResult matchAndRewrite(ParallelOp op,
3016 for (
auto [lb,
ub, step, iv] :
3017 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3018 op.getInductionVars())) {
3019 auto numIterations =
3021 if (numIterations.has_value()) {
3023 if (*numIterations == 0) {
3024 rewriter.
replaceOp(op, op.getInitVals());
3029 if (*numIterations == 1) {
3034 newLowerBounds.push_back(lb);
3035 newUpperBounds.push_back(ub);
3036 newSteps.push_back(step);
// No dimension was dropped.
3039 if (newLowerBounds.size() == op.getLowerBound().size())
// Every dimension was dropped.
3042 if (newLowerBounds.empty()) {
3045 SmallVector<Value> results;
3046 results.reserve(op.getInitVals().size());
3047 for (
auto &bodyOp : op.getBody()->without_terminator())
3048 rewriter.
clone(bodyOp, mapping);
3049 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3050 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3051 Block &reduceBlock = reduceOp.getReductions()[i].front();
// `results` grows by one per reduction, so its size selects the init value
// belonging to the current reduction.
3052 auto initValIndex = results.size();
3053 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3057 rewriter.
clone(reduceBodyOp, mapping);
3060 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3061 results.push_back(
result);
// New loop over the remaining dimensions; no body builder is passed.
3069 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3070 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3076 newOp.getRegion().begin(), mapping);
3077 rewriter.
replaceOp(op, newOp.getResults());
// Rewrite pattern on scf.parallel whose body starts with another scf.parallel:
// the outer and inner bounds and steps are concatenated (outer first).
// Visible checks: some value `val` is tested against the inner bounds/steps,
// and the presence of initial values on either loop is tested.
// NOTE(review): early returns, the loop over `val`, the rest of the body
// builder and the creation of the merged op are on listing lines missing from
// this extract.
3082struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3083 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3085 LogicalResult matchAndRewrite(ParallelOp op,
3086 PatternRewriter &rewriter)
const override {
3087 Block &outerBody = *op.getBody();
3091 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3096 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3097 llvm::is_contained(innerOp.getUpperBound(), val) ||
3098 llvm::is_contained(innerOp.getStep(), val))
3102 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3105 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3107 Block &innerBody = *innerOp.getBody();
3108 assert(iterVals.size() ==
3116 builder.
clone(op, mapping);
// Helper that returns `first` followed by `second` as one value vector.
3119 auto concatValues = [](
const auto &first,
const auto &second) {
3120 SmallVector<Value> ret;
3121 ret.reserve(first.size() + second.size());
3122 ret.assign(first.begin(), first.end());
3123 ret.append(second.begin(), second.end());
3127 auto newLowerBounds =
3128 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3129 auto newUpperBounds =
3130 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3131 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
// Registers the two scf.parallel canonicalization patterns named below.
// NOTE(review): the receiver of `.add` and its argument are on listing lines
// missing from this extract.
3142void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3143 MLIRContext *context) {
3145 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
// Reports the successors of scf.parallel: the body region, plus a successor
// built from an empty result range (both iterators are `getResults().end()`).
// NOTE(review): `®ions` looks like a mangled `&regions` (the body uses
// `regions`); the start of the second push is missing from this extract.
3154void ParallelOp::getSuccessorRegions(
3155 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
3159 regions.push_back(RegionSuccessor(&getRegion()));
3161 ResultRange{getResults().end(), getResults().end()}));
3168void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
// Builds an scf.reduce with the given operands and adds one region per
// operand.
// NOTE(review): the `operands` parameter declaration and the population of
// each region are on listing lines missing from this extract.
3170void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3172 result.addOperands(operands);
3173 for (Value v : operands) {
// The guard is scoped to one loop iteration.
3174 OpBuilder::InsertionGuard guard(builder);
3175 Region *bodyRegion =
result.addRegion();
// Region verifier for scf.reduce. The visible diagnostics cover:
//  * one reduction region per operand;
//  * no reduction region may have an empty body;
//  * the block arguments of the i-th region must have the type of the i-th
//    operand (the message speaks of two block arguments);
//  * each reduction body ends with scf.reduce.return.
// NOTE(review): several conditions are on listing lines missing from this
// extract.
3182LogicalResult ReduceOp::verifyRegions() {
3183 if (getReductions().size() != getOperands().size())
3184 return emitOpError() <<
"expects number of reduction regions: "
3185 << getReductions().size()
3186 <<
" to be the same as number of reduction operands: "
3187 << getOperands().size();
3190 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3191 auto type = getOperands()[i].getType();
3194 return emitOpError() << i <<
"-th reduction has an empty body";
3196 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3197 return arg.getType() != type;
3199 return emitOpError() <<
"expected two block arguments with type " << type
3200 <<
" in the " << i <<
"-th reduction region";
3204 return emitOpError(
"reduction bodies must be terminated with an "
3205 "'scf.reduce.return' op");
// Returns an operand range starting at index 0 with length 0, i.e. no
// operands are forwarded to the successor.
// NOTE(review): the return type is on a listing line missing from this
// extract.
3212ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3214 return MutableOperandRange(getOperation(), 0, 0);
// Verifier for scf.reduce.return: the returned value must have the expected
// type, described by the diagnostic as the type of the reduction inputs. The
// enclosing op is asserted to be scf.reduce.
// NOTE(review): the computation of `expectedResultType` is on a listing line
// missing from this extract.
3221LogicalResult ReduceReturnOp::verify() {
3224 Block *reductionBody = getOperation()->getBlock();
3226 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3228 if (expectedResultType != getResult().
getType())
3229 return emitOpError() <<
"must have type " << expectedResultType
3230 <<
" (the type of the reduction inputs)";
// Builds an scf.while with a 'before' and an 'after' region. The argument
// locations for the before region are taken from the init operands; those for
// the after region all use the op's location, one per result type.
// NOTE(review): operand/type registration, block creation and the callback
// invocations are on listing lines missing from this extract.
3238void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3239 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3240 ValueRange inits, BodyBuilderFn beforeBuilder,
3241 BodyBuilderFn afterBuilder) {
3245 OpBuilder::InsertionGuard guard(odsBuilder);
3248 SmallVector<Location, 4> beforeArgLocs;
3249 beforeArgLocs.reserve(inits.size());
3250 for (Value operand : inits) {
3251 beforeArgLocs.push_back(operand.getLoc());
3254 Region *beforeRegion = odsState.
addRegion();
3256 inits.getTypes(), beforeArgLocs);
3261 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3263 Region *afterRegion = odsState.
addRegion();
3265 resultTypes, afterArgLocs);
// Returns the terminator of the 'before' body as a ConditionOp; `cast` asserts
// that it really is one.
3271ConditionOp WhileOp::getConditionOp() {
3272 return cast<ConditionOp>(getBeforeBody()->getTerminator());
// Returns the terminator of the 'after' body as a YieldOp; `cast` asserts that
// it really is one.
3275YieldOp WhileOp::getYieldOp() {
3276 return cast<YieldOp>(getAfterBody()->getTerminator());
// Returns the mutable operands of the yield that terminates the 'after' body.
3279std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3280 return getYieldOp().getResultsMutable();
3284 return getBeforeBody()->getArguments();
3288 return getAfterBody()->getArguments();
3292 return getBeforeArguments();
// Returns the operands forwarded when the op is entered; per the assertion
// message, entry is expected to branch only to the first region.
// NOTE(review): the assertion condition and the return statement are on
// listing lines missing from this extract.
3295OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3297 "WhileOp is expected to branch only to the first region");
// Reports the possible successors of scf.while. The visible lines add the
// 'before' region (with its arguments) in two places and the 'after' region
// (with its arguments) in one; the assertion states that the op has only these
// two regions.
// NOTE(review): `®ions` looks like a mangled `&regions` (the body uses
// `regions`); the conditions on `point` that select between the pushes are on
// listing lines missing from this extract.
3301void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3302 SmallVectorImpl<RegionSuccessor> ®ions) {
3305 regions.emplace_back(&getBefore(), getBefore().getArguments());
3309 assert(llvm::is_contained(
3310 {&getAfter(), &getBefore()},
3312 "there are only two regions in a WhileOp");
3316 regions.emplace_back(&getBefore(), getBefore().getArguments());
3321 regions.emplace_back(&getAfter(), getAfter().getArguments());
// Returns both loop regions, 'before' first and 'after' second.
3324SmallVector<Region *> WhileOp::getLoopRegions() {
3325 return {&getBefore(), &getAfter()};
// Custom parser for scf.while. The op's result types come from the results of
// a parsed function type, whose inputs must match the operands in number and
// provide the types of the region arguments.
// NOTE(review): the parser calls, the start of the diagnostic and the second
// half of the final `failure(...)` expression are on listing lines missing
// from this extract.
3335ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3336 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3337 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3338 Region *before =
result.addRegion();
3339 Region *after =
result.addRegion();
3341 OptionalParseResult listResult =
3346 FunctionType functionType;
3351 result.addTypes(functionType.getResults());
3353 if (functionType.getNumInputs() != operands.size()) {
3355 <<
"expected as many input types as operands " <<
"(expected "
3356 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
// Propagate the function type's input types onto the region arguments.
3366 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3367 regionArgs[i].type = functionType.getInput(i);
3369 return failure(parser.
parseRegion(*before, regionArgs) ||
3375void scf::WhileOp::print(OpAsmPrinter &p) {
3389template <
typename OpTy>
3392 if (left.size() != right.size())
3393 return op.emitOpError(
"expects the same number of ") << message;
3395 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3396 if (left[i] != right[i]) {
3399 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3400 <<
" and " << right[i];
// Verifier for scf.while: the 'before' region must terminate with
// scf.condition and the 'after' region with scf.yield. The result is success
// exactly when the 'after' terminator was found.
// NOTE(review): the lookups that produce `beforeTerminator` and
// `afterTerminator` are on listing lines missing from this extract.
3408LogicalResult scf::WhileOp::verify() {
3411 "expects the 'before' region to terminate with 'scf.condition'");
3412 if (!beforeTerminator)
3417 "expects the 'after' region to terminate with 'scf.yield'");
3418 return success(afterTerminator !=
nullptr);
// Rewrite pattern on scf.while that handles an scf.if placed directly before
// the scf.condition terminator of the 'before' region and guarded by the same
// condition value. A new scf.while is created whose results and 'after'
// arguments are extended by the values from the 'before' region that the if
// uses; the old regions are moved into the new op.
// NOTE(review): several listing lines (early returns, the collection loop
// header, the calls that replace uses and inline the then-block) are missing
// from this extract.
3455struct WhileMoveIfDown :
public OpRewritePattern<scf::WhileOp> {
3456 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3458 LogicalResult matchAndRewrite(scf::WhileOp op,
3459 PatternRewriter &rewriter)
const override {
3460 auto conditionOp = op.getConditionOp();
// Null when the condition op has no predecessor or it is not an scf.if.
3468 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
// The if must test the same condition and its else block, if present, must
// hold nothing but the terminator.
3474 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3475 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3478 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3479 *ifOp->user_begin() == conditionOp)) &&
3480 "ifOp has unexpected uses");
3482 Location loc = op.getLoc();
// Condition arguments that are results of the if are traced back to the
// values yielded by the if's branches.
3486 for (
auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3487 auto it = llvm::find(ifOp->getResults(), arg);
3488 if (it != ifOp->getResults().end()) {
3489 size_t ifOpIdx = it.getIndex();
3490 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3491 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
// Only values whose parent region is the 'before' region are collected.
3501 if (&op.getBefore() == operand->get().getParentRegion())
3502 additionalUsedValuesSet.insert(operand->get());
3506 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3507 auto additionalValueTypes = llvm::map_to_vector(
3508 additionalUsedValues, [](Value val) {
return val.
getType(); });
3509 size_t additionalValueSize = additionalUsedValues.size();
3510 SmallVector<Type> newResultTypes(op.getResultTypes());
3511 newResultTypes.append(additionalValueTypes);
3514 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3517 newWhileOp.getBefore().takeBody(op.getBefore());
3518 newWhileOp.getAfter().takeBody(op.getAfter());
3519 newWhileOp.getAfter().addArguments(
3520 additionalValueTypes,
3521 SmallVector<Location>(additionalValueSize, loc));
// The extra values are passed through the condition op.
3525 conditionOp.getArgsMutable().append(additionalUsedValues);
// Predicate: only uses located inside the if's then region qualify.
3531 additionalUsedValues,
3532 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3533 [&](OpOperand &use) {
3534 return ifOp.getThenRegion().isAncestor(
3535 use.getOwner()->getParentRegion());
3539 rewriter.
eraseOp(ifOp.thenYield());
3541 newWhileOp.getAfterBody()->begin());
// The extra trailing results are dropped here.
3544 newWhileOp->getResults().drop_back(additionalValueSize));
// Rewrite pattern on scf.while: when the scf.condition forwards its own
// condition value as an argument, the matching 'after' block argument (if it
// has uses) is handled with a constant of the condition's type.
// NOTE(review): the constant's value, the use replacement and the returns are
// on listing lines missing from this extract.
3568struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3569 using OpRewritePattern<WhileOp>::OpRewritePattern;
3571 LogicalResult matchAndRewrite(WhileOp op,
3572 PatternRewriter &rewriter)
const override {
3573 auto term = op.getConditionOp();
3577 Value constantTrue =
nullptr;
3579 bool replaced =
false;
// Pairs each forwarded value with the 'after' argument it feeds.
3580 for (
auto yieldedAndBlockArgs :
3581 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3582 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3583 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3585 constantTrue = arith::ConstantOp::create(
3586 rewriter, op.getLoc(), term.getCondition().getType(),
// Rewrite pattern on scf.while whose condition is an arith.cmpi. For every
// value the condition op forwards that is also an operand of that cmpi, uses
// of the matching 'after' argument are scanned for arith.cmpi ops with the
// same other operand; such a compare either has the same predicate or the
// inverted one.
// NOTE(review): the `continue`/return statements, the declaration of
// `samePredicate` and the replacement of the matched compare are on listing
// lines missing from this extract.
3621struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
3622 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3624 LogicalResult matchAndRewrite(scf::WhileOp op,
3625 PatternRewriter &rewriter)
const override {
3626 using namespace scf;
3627 auto cond = op.getConditionOp();
3628 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3632 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
// Try both operand positions of the compare.
3633 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3634 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3637 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3638 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
// The other operand must be the same value as in the loop condition.
3642 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3645 if (cmp2.getPredicate() == cmp.getPredicate())
3646 samePredicate =
true;
3647 else if (cmp2.getPredicate() ==
3648 arith::invertPredicate(cmp.getPredicate()))
3649 samePredicate =
false;
// Computes an index mapping between two value ranges: for each value at index
// `i` of `args1` found at position `p` of `args2`, `ret[p] = i`.
// Returns std::nullopt when the ranges differ in size or a value of `args1`
// does not occur in `args2`.
// NOTE(review): the `args2` parameter declaration and the final return are on
// listing lines missing from this extract.
3665static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
3667 if (args1.size() != args2.size())
3668 return std::nullopt;
3670 SmallVector<unsigned> ret(args1.size());
3671 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
3672 auto it = llvm::find(args2, arg1);
3673 if (it == args2.end())
3674 return std::nullopt;
3676 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
3683 llvm::SmallDenseSet<Value> set;
3684 for (Value arg : args) {
3685 if (!set.insert(arg).second)
// Rewrite pattern on scf.while that works with the permutation between the
// 'before' block arguments and the arguments forwarded by scf.condition. It is
// skipped when the two ranges are already equal, when the forwarded arguments
// contain duplicates, or (presumably) when no mapping exists. A new scf.while
// with permuted result types is created and the old results and 'after'
// arguments are mapped onto the new ones.
// NOTE(review): the definition of `termArgs`, the early returns, the body
// builders passed to the new loop and the final merge/replace calls are on
// listing lines missing from this extract.
3695struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
3698 LogicalResult matchAndRewrite(WhileOp loop,
3699 PatternRewriter &rewriter)
const override {
3700 auto *oldBefore = loop.getBeforeBody();
3701 ConditionOp oldTerm = loop.getConditionOp();
3702 ValueRange beforeArgs = oldBefore->getArguments();
3704 if (beforeArgs == termArgs)
3707 if (hasDuplicates(termArgs))
3710 auto mapping = getArgsMapping(beforeArgs, termArgs);
3715 OpBuilder::InsertionGuard g(rewriter);
3721 auto *oldAfter = loop.getAfterBody();
// Result i of the old loop becomes result j of the new loop.
3723 SmallVector<Type> newResultTypes(beforeArgs.size());
3724 for (
auto &&[i, j] : llvm::enumerate(*mapping))
3725 newResultTypes[j] = loop.getResult(i).getType();
3727 auto newLoop = WhileOp::create(
3728 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3730 auto *newBefore = newLoop.getBeforeBody();
3731 auto *newAfter = newLoop.getAfterBody();
// Replacement values listed in the old order, taken from the new positions.
3733 SmallVector<Value> newResults(beforeArgs.size());
3734 SmallVector<Value> newAfterArgs(beforeArgs.size());
3735 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
3736 newResults[i] = newLoop.getResult(j);
3737 newAfterArgs[i] = newAfter->getArgument(j);
3741 newBefore->getArguments());
3751void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3752 MLIRContext *context) {
3753 results.
add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
3754 WhileMoveIfDown>(context);
3756 results, WhileOp::getOperationName());
3770 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
3773 caseValues.push_back(value);
3782 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
3784 p <<
"case " << value <<
' ';
3789LogicalResult scf::IndexSwitchOp::verify() {
3790 if (getCases().size() != getCaseRegions().size()) {
3792 << getCaseRegions().size() <<
" case regions but "
3793 << getCases().size() <<
" case values";
3797 for (int64_t value : getCases())
3798 if (!valueSet.insert(value).second)
3799 return emitOpError(
"has duplicate case value: ") << value;
3800 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
3801 auto yield = dyn_cast<YieldOp>(region.
front().
back());
3803 return emitOpError(
"expected region to end with scf.yield, but got ")
3806 if (yield.getNumOperands() != getNumResults()) {
3807 return (
emitOpError(
"expected each region to return ")
3808 << getNumResults() <<
" values, but " << name <<
" returns "
3809 << yield.getNumOperands())
3810 .attachNote(yield.getLoc())
3811 <<
"see yield operation here";
3813 for (
auto [idx,
result, operand] :
3814 llvm::enumerate(getResultTypes(), yield.getOperands())) {
3816 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
3817 if (
result == operand.getType())
3820 << idx <<
" of each region to be " <<
result)
3821 .attachNote(yield.getLoc())
3822 << name <<
" returns " << operand.getType() <<
" here";
3829 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
/// Returns the number of case values of this switch, i.e. the size of the
/// array returned by getCases().
3836unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
3838Block &scf::IndexSwitchOp::getDefaultBlock() {
3839 return getDefaultRegion().front();
3842Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
3843 assert(idx < getNumCases() &&
"case index out-of-bounds");
3844 return getCaseRegions()[idx].front();
3847void IndexSwitchOp::getSuccessorRegions(
3848 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3855 llvm::append_range(successors, getRegions());
3858void IndexSwitchOp::getEntrySuccessorRegions(
3859 ArrayRef<Attribute> operands,
3860 SmallVectorImpl<RegionSuccessor> &successors) {
3861 FoldAdaptor adaptor(operands, *
this);
3864 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3866 llvm::append_range(successors, getRegions());
3872 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3873 if (caseValue == arg.getInt()) {
3874 successors.emplace_back(&caseRegion);
3878 successors.emplace_back(&getDefaultRegion());
3881void IndexSwitchOp::getRegionInvocationBounds(
3882 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3883 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3884 if (!operandValue) {
3886 bounds.append(getNumRegions(), InvocationBounds(0, 1));
3890 unsigned liveIndex = getNumRegions() - 1;
3891 const auto *it = llvm::find(getCases(), operandValue.getInt());
3892 if (it != getCases().end())
3893 liveIndex = std::distance(getCases().begin(), it);
3894 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
3895 bounds.emplace_back(0, i == liveIndex);
3906 if (!maybeCst.has_value())
3909 int64_t caseIdx, e = op.getNumCases();
3910 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
3911 if (cst == op.getCases()[caseIdx])
3915 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
3916 : op.getDefaultRegion();
3931void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3932 MLIRContext *context) {
3933 results.
add<FoldConstantCase>(context);
3935 results, IndexSwitchOp::getOperationName());
3942#define GET_OP_CLASSES
3943#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
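Returns failure if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in copyPlacementBlock specify the insertion points where the incoming copies and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, the output argument `nEnd` is set to the new end.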
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent(Operation::result_range results)
Initialize a successor that branches back to/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values has as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.