30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
47struct SCFInlinerInterface :
public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
52 IRMapping &valueMapping)
const final {
57 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
62 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
67 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
// Dialect initialization hook. The visible statements register the inliner
// interface defined in this file and declare promised interfaces for a set of
// SCF ops.
78void SCFDialect::initialize() {
// NOTE(review): the statement that wraps this generated include is on lines
// not shown here; presumably it is the op registration list -- confirm.
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
// Attach the SCF-specific inliner interface to the dialect.
83 addInterfaces<SCFInlinerInterface>();
// Promised interface declared on the dialect itself.
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
// BufferDeallocationOpInterface is promised for InParallelOp and
// ReduceReturnOp.
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
// BufferizableOpInterface is promised for the listed region/terminator ops.
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
// ValueBoundsOpInterface is promised for ForOp only.
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
95 scf::YieldOp::create(builder, loc);
100template <
typename TerminatorTy>
102 StringRef errorMessage) {
103 Operation *terminatorOperation =
nullptr;
105 terminatorOperation = ®ion.
front().
back();
106 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
118 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
146ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
// Verifier for ExecuteRegionOp. The two visible checks reject an op whose
// region has no blocks and an op whose first block declares arguments.
// NOTE(review): the success path and closing brace are on lines not shown.
174LogicalResult ExecuteRegionOp::verify() {
// The region must contain at least one block.
175 if (getRegion().empty())
176 return emitOpError(
"region needs to have at least one block");
// The first block of the region must not take any block arguments.
177 if (getRegion().front().getNumArguments() > 0)
178 return emitOpError(
"region cannot have any arguments");
224 if (op.getNoInline())
226 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
229 Block *prevBlock = op->getBlock();
233 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
235 for (
Block &blk : op.getRegion()) {
236 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
238 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
239 yieldOp.getResults());
247 for (
auto res : op.getResults())
248 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
259 results, ExecuteRegionOp::getOperationName());
262 results, ExecuteRegionOp::getOperationName(),
264 return failure(cast<ExecuteRegionOp>(op).getNoInline());
268void ExecuteRegionOp::getSuccessorRegions(
293 "condition op can only exit the loop or branch to the after"
296 return getArgsMutable();
299void ConditionOp::getSuccessorRegions(
301 FoldAdaptor adaptor(operands, *
this);
303 WhileOp whileOp = getParentOp();
307 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
308 if (!boolAttr || boolAttr.getValue())
309 regions.emplace_back(&whileOp.getAfter());
310 if (!boolAttr || !boolAttr.getValue())
320 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
324 result.addAttribute(getUnsignedCmpAttrName(
result.name),
327 result.addOperands(initArgs);
328 for (
Value v : initArgs)
329 result.addTypes(v.getType());
334 for (
Value v : initArgs)
340 if (initArgs.empty() && !bodyBuilder) {
341 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
342 }
else if (bodyBuilder) {
// Verifier for ForOp. Visible checks: the body block has enough arguments to
// hold the induction variable(s), and the number of init args matches the
// number of op results.
350LogicalResult ForOp::verify() {
// The body needs at least getNumInductionVars() block arguments; the
// diagnostic reports both the expected minimum and the actual count.
355 if (getBody()->getNumArguments() < getNumInductionVars())
356 return emitOpError(
"expected body to have at least ")
357 << getNumInductionVars()
358 <<
" argument(s) for the induction variable, but got "
359 << getBody()->getNumArguments();
// Each init arg must correspond to exactly one op result.
// NOTE(review): the call that emits the message below is on a line not shown.
362 if (getInitArgs().size() != getNumResults())
364 "mismatch in number of loop-carried values and defined values");
// Region verifier for ForOp. Visible checks cover the body argument count,
// the region iter-arg count against the result count, and per-position type
// agreement between init args, region iter args and op results.
369LogicalResult ForOp::verifyRegions() {
// The body needs at least getNumInductionVars() block arguments.
371 if (getBody()->getNumArguments() < getNumInductionVars())
372 return emitOpError(
"expected body to have at least ")
373 << getNumInductionVars() <<
" argument(s) for the induction "
374 <<
"variable, but got " << getBody()->getNumArguments();
// NOTE(review): the condition guarding the next message is on lines not
// shown; from the text it concerns the induction variable type versus the
// bounds/step type -- confirm against the full file.
380 "expected induction variable to be same type as bounds and step");
// The number of region iter args must equal the number of op results.
382 if (getNumRegionIterArgs() != getNumResults())
384 "mismatch in number of basic block args and defined values");
386 auto initArgs = getInitArgs();
387 auto iterArgs = getRegionIterArgs();
388 auto opResults = getResults();
// Walk the three ranges in lockstep. The counter `i` used in the diagnostics
// and the type comparisons themselves are on lines not shown.
390 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
392 return emitOpError() <<
"types mismatch between " << i
393 <<
"th iter operand and defined value";
395 return emitOpError() <<
"types mismatch between " << i
396 <<
"th iter region arg and defined value";
403std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
407std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
411std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
415std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
419bool ForOp::isValidInductionVarType(
Type type) {
424 if (bounds.size() != 1)
426 if (
auto val = dyn_cast<Value>(bounds[0])) {
434 if (bounds.size() != 1)
436 if (
auto val = dyn_cast<Value>(bounds[0])) {
444 if (steps.size() != 1)
446 if (
auto val = dyn_cast<Value>(steps[0])) {
// Returns the op's results wrapped in an optional; this implementation always
// produces a value.
453std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
// Uses the static trip count to decide whether the loop can be promoted.
// NOTE(review): several lines of this function (the return statements and the
// call that consumes bbArgReplacements) are not shown; comments below describe
// only what is visible.
457LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
// The trip count is only available when it can be determined statically.
458 std::optional<APInt> tripCount = getStaticTripCount();
459 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
// Guard for an unknown trip count or a trip count greater than one
// (compared as a zero-extended value); the guarded statement is not shown.
462 if (!tripCount.has_value() || tripCount->getZExtValue() > 1)
// Zero-iteration loops are handled by a dedicated branch (body not shown).
465 if (*tripCount == 0) {
// The body terminator is required to be an scf::YieldOp (cast asserts).
472 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
// The init args are appended to the block-argument replacement list.
479 llvm::append_range(bbArgReplacements, getInitArgs());
483 getOperation()->getIterator(), bbArgReplacements);
499 StringRef prefix =
"") {
500 assert(blocksArgs.size() == initializers.size() &&
501 "expected same length of arguments and initializers");
502 if (initializers.empty())
506 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
507 p << std::get<0>(it) <<
" = " << std::get<1>(it);
513 if (getUnsignedCmp())
516 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
520 if (!getInitArgs().empty())
521 p <<
" -> (" << getInitArgs().getTypes() <<
')';
524 p <<
" : " << t <<
' ';
527 !getInitArgs().empty());
529 getUnsignedCmpAttrName().strref());
540 result.addAttribute(getUnsignedCmpAttrName(
result.name),
554 regionArgs.push_back(inductionVariable);
564 if (regionArgs.size() !=
result.types.size() + 1)
567 "mismatch in number of loop-carried values and defined values");
576 regionArgs.front().type = type;
577 for (
auto [iterArg, type] :
578 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
585 ForOp::ensureTerminator(*body, builder,
result.location);
594 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
595 operands,
result.types)) {
596 Type type = std::get<2>(argOperandType);
597 std::get<0>(argOperandType).type = type;
614 return getBody()->getArguments().drop_front(getNumInductionVars());
618 return getInitArgsMutable();
621FailureOr<LoopLikeOpInterface>
622ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
624 bool replaceInitOperandUsesInLoop,
629 auto inits = llvm::to_vector(getInitArgs());
630 inits.append(newInitOperands.begin(), newInitOperands.end());
631 scf::ForOp newLoop = scf::ForOp::create(
637 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
639 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
644 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
645 assert(newInitOperands.size() == newYieldedValues.size() &&
646 "expected as many new yield values as new iter operands");
648 yieldOp.getResultsMutable().append(newYieldedValues);
654 newLoop.getBody()->getArguments().take_front(
655 getBody()->getNumArguments()));
657 if (replaceInitOperandUsesInLoop) {
660 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
671 newLoop->getResults().take_front(getNumResults()));
672 return cast<LoopLikeOpInterface>(newLoop.getOperation());
676 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
679 assert(ivArg.getOwner() &&
"unlinked block argument");
680 auto *containingOp = ivArg.getOwner()->getParentOp();
681 return dyn_cast_or_null<ForOp>(containingOp);
685 return getInitArgs();
690 if (std::optional<APInt> tripCount = getStaticTripCount()) {
693 if (*tripCount == 0) {
702 }
else if (*tripCount == 1) {
726LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
727 for (
auto [lb, ub, step] :
728 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
731 if (!tripCount.has_value() || *tripCount != 1)
740 return getBody()->getArguments().drop_front(getRank());
// The mutable init operands of a ForallOp are its mutable output operands.
743MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
744 return getOutputsMutable();
750 scf::InParallelOp terminator = forallOp.getTerminator();
755 bbArgReplacements.append(forallOp.getOutputs().begin(),
756 forallOp.getOutputs().end());
760 forallOp->getIterator(), bbArgReplacements);
765 results.reserve(forallOp.getResults().size());
766 for (
auto &yieldingOp : terminator.getYieldingOps()) {
767 auto parallelInsertSliceOp =
768 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
769 if (!parallelInsertSliceOp)
772 Value dst = parallelInsertSliceOp.getDest();
773 Value src = parallelInsertSliceOp.getSource();
774 if (llvm::isa<TensorType>(src.
getType())) {
775 results.push_back(tensor::InsertSliceOp::create(
776 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
777 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
778 parallelInsertSliceOp.getStrides(),
779 parallelInsertSliceOp.getStaticOffsets(),
780 parallelInsertSliceOp.getStaticSizes(),
781 parallelInsertSliceOp.getStaticStrides()));
783 llvm_unreachable(
"unsupported terminator");
798 assert(lbs.size() == ubs.size() &&
799 "expected the same number of lower and upper bounds");
800 assert(lbs.size() == steps.size() &&
801 "expected the same number of lower bounds and steps");
806 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
808 assert(results.size() == iterArgs.size() &&
809 "loop nest body must return as many values as loop has iteration "
811 return LoopNest{{}, std::move(results)};
819 loops.reserve(lbs.size());
820 ivs.reserve(lbs.size());
823 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
824 auto loop = scf::ForOp::create(
825 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
831 currentIterArgs = args;
832 currentLoc = nestedLoc;
838 loops.push_back(loop);
842 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
844 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
851 ? bodyBuilder(builder, currentLoc, ivs,
852 loops.back().getRegionIterArgs())
854 assert(results.size() == iterArgs.size() &&
855 "loop nest body must return as many values as loop has iteration "
858 scf::YieldOp::create(builder, loc, results);
862 llvm::append_range(nestResults, loops.front().getResults());
863 return LoopNest{std::move(loops), std::move(nestResults)};
876 bodyBuilder(nestedBuilder, nestedLoc, ivs);
885 assert(operand.
getOwner() == forOp);
890 "expected an iter OpOperand");
892 "Expected a different type");
894 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
899 newIterOperands.push_back(opOperand.get());
903 scf::ForOp newForOp = scf::ForOp::create(
904 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
905 forOp.getStep(), newIterOperands,
nullptr,
906 forOp.getUnsignedCmp());
907 newForOp->setAttrs(forOp->getAttrs());
908 Block &newBlock = newForOp.getRegion().
front();
916 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
918 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
919 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
923 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
926 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
929 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
930 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
931 clonedYieldOp.getOperand(yieldIdx));
933 newYieldOperands[yieldIdx] = castOut;
934 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
935 rewriter.
eraseOp(clonedYieldOp);
940 newResults[yieldIdx] =
941 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
976 LogicalResult matchAndRewrite(ForOp op,
978 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
979 OpOperand &iterOpOperand = std::get<0>(it);
982 incomingCast.getSource().getType() == incomingCast.getType())
987 incomingCast.getDest().getType(),
988 incomingCast.getSource().getType()))
990 if (!std::get<1>(it).hasOneUse())
996 rewriter, op, iterOpOperand, incomingCast.getSource(),
998 return tensor::CastOp::create(b, loc, type, source);
// Populates `results` with ForOp canonicalization patterns.
1007void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1008 MLIRContext *context) {
// Register the tensor.cast folding pattern defined earlier in this file.
1009 results.
add<ForOpTensorCastFolder>(context);
// NOTE(review): the two helper calls that receive the argument lists below
// are on lines not shown; both are keyed on ForOp's operation name.
1011 results, ForOp::getOperationName());
1013 results, ForOp::getOperationName(),
// Callback handed to the second helper: given a value that must be the
// induction variable (block argument 0) of a ForOp body, it returns that
// loop's lower bound.
1014 [](OpBuilder &builder, Location loc, Value value) {
1018 auto blockArg = cast<BlockArgument>(value);
1019 assert(blockArg.getArgNumber() == 0 &&
"expected induction variable");
1020 auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
1021 return forOp.getLowerBound();
1025std::optional<APInt> ForOp::getConstantStep() {
1028 return step.getValue();
// Returns the mutable operands of the scf::YieldOp terminating the loop body.
// The cast asserts that the body terminator is an scf::YieldOp.
1032std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1033 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1039 if (
auto constantStep = getConstantStep())
1040 if (*constantStep == 1)
1048std::optional<APInt> ForOp::getStaticTripCount() {
// Verifier for ForallOp. Visible checks: result count versus output count,
// body block argument count and types, mapping attribute constraints, and the
// static/dynamic lower bound, upper bound and step lists.
// NOTE(review): many of the guarding conditions and emitOpError calls are on
// lines not shown; the comments below are based on the visible diagnostics.
1057LogicalResult ForallOp::verify() {
1058 unsigned numLoops = getRank();
// Each result must be backed by an output operand.
1060 if (getNumResults() != getOutputs().size())
1062 << getNumResults() <<
" results, but has only "
1063 << getOutputs().size() <<
" outputs";
1066 auto *body = getBody();
1068 return emitOpError(
"region expects ") << numLoops <<
" arguments";
// NOTE(review): `i` is int64_t here while `numLoops` is unsigned, so this is a
// mixed signed/unsigned comparison; the sibling loop below uses `unsigned i`.
1069 for (int64_t i = 0; i < numLoops; ++i)
1072 << i <<
"-th block argument to be an index";
// Per-output check against the corresponding block argument (condition not
// shown).
1073 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1076 << i <<
"-th output and corresponding block argument";
// Mapping constraints apply only when a non-empty mapping attribute exists.
1077 if (getMapping().has_value() && !getMapping()->empty()) {
// One device mapping attribute is required per loop dimension.
1078 if (getDeviceMappingAttrs().size() != numLoops)
1079 return emitOpError() <<
"mapping attribute size must match op rank";
// getDeviceMaskingAttr() fails when the masking attribute constraint is
// violated; the diagnostic states at most one is supported.
1080 if (
failed(getDeviceMaskingAttr()))
1082 <<
" supports at most one device masking attribute";
1086 Operation *op = getOperation();
// The three argument lists below pair each static array with its dynamic
// operands for lower bound, upper bound and step; the verifying call itself
// is on lines not shown.
1088 getStaticLowerBound(),
1089 getDynamicLowerBound())))
1092 getStaticUpperBound(),
1093 getDynamicUpperBound())))
1096 getStaticStep(), getDynamicStep())))
1102void ForallOp::print(OpAsmPrinter &p) {
1103 Operation *op = getOperation();
1104 p <<
" (" << getInductionVars();
1105 if (isNormalized()) {
1126 if (!getRegionOutArgs().empty())
1127 p <<
"-> (" << getResultTypes() <<
") ";
1128 p.printRegion(getRegion(),
1130 getNumResults() > 0);
1131 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1132 getStaticLowerBoundAttrName(),
1133 getStaticUpperBoundAttrName(),
1134 getStaticStepAttrName()});
1137ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1139 auto indexType =
b.getIndexType();
1144 SmallVector<OpAsmParser::Argument, 4> ivs;
1149 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1159 unsigned numLoops = ivs.size();
1160 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1161 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1190 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1191 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1194 if (outOperands.size() !=
result.types.size())
1196 "mismatch between out operands and types");
1205 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1206 std::unique_ptr<Region> region = std::make_unique<Region>();
1207 for (
auto &iv : ivs) {
1208 iv.type =
b.getIndexType();
1209 regionArgs.push_back(iv);
1211 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1212 auto &out = it.value();
1213 out.type =
result.types[it.index()];
1214 regionArgs.push_back(out);
1220 ForallOp::ensureTerminator(*region,
b,
result.location);
1221 result.addRegion(std::move(region));
1227 result.addAttribute(
"staticLowerBound", staticLbs);
1228 result.addAttribute(
"staticUpperBound", staticUbs);
1229 result.addAttribute(
"staticStep", staticSteps);
1230 result.addAttribute(
"operandSegmentSizes",
1232 {static_cast<int32_t>(dynamicLbs.size()),
1233 static_cast<int32_t>(dynamicUbs.size()),
1234 static_cast<int32_t>(dynamicSteps.size()),
1235 static_cast<int32_t>(outOperands.size())}));
1240void ForallOp::build(
1241 mlir::OpBuilder &
b, mlir::OperationState &
result,
1242 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1243 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1244 std::optional<ArrayAttr> mapping,
1246 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1247 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1252 result.addOperands(dynamicLbs);
1253 result.addOperands(dynamicUbs);
1254 result.addOperands(dynamicSteps);
1255 result.addOperands(outputs);
1258 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1259 b.getDenseI64ArrayAttr(staticLbs));
1260 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1261 b.getDenseI64ArrayAttr(staticUbs));
1262 result.addAttribute(getStaticStepAttrName(
result.name),
1263 b.getDenseI64ArrayAttr(staticSteps));
1265 "operandSegmentSizes",
1266 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1267 static_cast<int32_t>(dynamicUbs.size()),
1268 static_cast<int32_t>(dynamicSteps.size()),
1269 static_cast<int32_t>(outputs.size())}));
1270 if (mapping.has_value()) {
1271 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
1275 Region *bodyRegion =
result.addRegion();
1276 OpBuilder::InsertionGuard g(
b);
1277 b.createBlock(bodyRegion);
1282 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1283 SmallVector<Location>(staticLbs.size(),
result.location));
1286 SmallVector<Location>(outputs.size(),
result.location));
1288 b.setInsertionPointToStart(&bodyBlock);
1289 if (!bodyBuilderFn) {
1290 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
// Convenience builder that takes only upper bounds: it synthesizes a lower
// bound of 0 and a step of 1 for every dimension and forwards to the builder
// that takes explicit lower bounds, upper bounds and steps.
// NOTE(review): the declaration of the `bodyBuilderFn` parameter is on a line
// not shown.
1297void ForallOp::build(
1298 mlir::OpBuilder &
b, mlir::OperationState &
result,
1299 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1300 std::optional<ArrayAttr> mapping,
// One loop dimension per upper bound.
1302 unsigned numLoops = ubs.size();
// Index attribute 0 for every lower bound.
1303 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
// Index attribute 1 for every step.
1304 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1305 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
// Returns true when every mixed lower bound equals 0 and every mixed step
// equals 1.
1309bool ForallOp::isNormalized() {
// Helper: true if every entry of `results` has a known integer value equal to
// `val`.
1310 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1311 return llvm::all_of(results, [&](OpFoldResult ofr) {
// NOTE(review): `intValue` is defined on a line not shown; presumably the
// optional constant integer extracted from `ofr` -- confirm.
1313 return intValue.has_value() && intValue == val;
1316 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
// Returns the body terminator as an InParallelOp; the cast asserts that the
// terminator has that type.
1319InParallelOp ForallOp::getTerminator() {
1320 return cast<InParallelOp>(getBody()->getTerminator());
// Collects the users of `bbArg` that implement ParallelCombiningOpInterface.
// NOTE(review): the return statement is on a line not shown; presumably
// `storeOps` is returned.
1323SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1324 SmallVector<Operation *> storeOps;
// Visit every user of the block argument.
1325 for (Operation *user : bbArg.
getUsers()) {
// Keep only users that implement the parallel-combining interface.
1326 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1327 storeOps.push_back(parallelOp);
1333SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1334 SmallVector<DeviceMappingAttrInterface> res;
1337 for (
auto attr : getMapping()->getValue()) {
1338 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1345FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1346 DeviceMaskingAttrInterface res;
1349 for (
auto attr : getMapping()->getValue()) {
1350 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1359bool ForallOp::usesLinearMapping() {
1360 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1363 return ifaces.front().isLinearMapping();
1366std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1367 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1371std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1373 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
1377std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1379 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1383std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1389 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1392 assert(tidxArg.getOwner() &&
"unlinked block argument");
1393 auto *containingOp = tidxArg.getOwner()->getParentOp();
1394 return dyn_cast<ForallOp>(containingOp);
1402 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1404 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1408 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1411 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1416class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1418 using OpRewritePattern<ForallOp>::OpRewritePattern;
1420 LogicalResult matchAndRewrite(ForallOp op,
1421 PatternRewriter &rewriter)
const override {
1422 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1423 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1424 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1431 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1432 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1435 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1436 op.setStaticLowerBound(staticLowerBound);
1440 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1441 op.setStaticUpperBound(staticUpperBound);
1444 op.getDynamicStepMutable().assign(dynamicStep);
1445 op.setStaticStep(staticStep);
1447 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1449 {static_cast<int32_t>(dynamicLowerBound.size()),
1450 static_cast<int32_t>(dynamicUpperBound.size()),
1451 static_cast<int32_t>(dynamicStep.size()),
1452 static_cast<int32_t>(op.getNumResults())}));
1531struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1532 using OpRewritePattern<ForallOp>::OpRewritePattern;
1534 LogicalResult matchAndRewrite(ForallOp forallOp,
1535 PatternRewriter &rewriter)
const final {
1546 SmallVector<Value> resultsToDelete;
1547 SmallVector<Value> outsToDelete;
1548 SmallVector<BlockArgument> blockArgsToDelete;
1549 SmallVector<Value> newOuts;
1550 BitVector resultIndicesToDelete(forallOp.getNumResults(),
false);
1551 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
1553 for (OpResult
result : forallOp.getResults()) {
1554 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1555 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1556 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1557 resultsToDelete.push_back(
result);
1558 outsToDelete.push_back(opOperand->
get());
1559 blockArgsToDelete.push_back(blockArg);
1560 resultIndicesToDelete[
result.getResultNumber()] =
true;
1563 newOuts.push_back(opOperand->
get());
1569 if (resultsToDelete.empty())
1574 for (
auto blockArg : blockArgsToDelete) {
1575 SmallVector<Operation *> combiningOps =
1576 forallOp.getCombiningOps(blockArg);
1577 for (Operation *combiningOp : combiningOps)
1578 rewriter.
eraseOp(combiningOp);
1580 for (
auto [blockArg,
result, out] :
1581 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1587 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
1592 auto newForallOp = cast<scf::ForallOp>(
1594 newForallOp.getOutputsMutable().assign(newOuts);
1600struct ForallOpSingleOrZeroIterationDimsFolder
1601 :
public OpRewritePattern<ForallOp> {
1602 using OpRewritePattern<ForallOp>::OpRewritePattern;
1604 LogicalResult matchAndRewrite(ForallOp op,
1605 PatternRewriter &rewriter)
const override {
1607 if (op.getMapping().has_value() && !op.getMapping()->empty())
1609 Location loc = op.getLoc();
1612 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1615 for (
auto [lb, ub, step, iv] :
1616 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1617 op.getMixedStep(), op.getInductionVars())) {
1618 auto numIterations =
1620 if (numIterations.has_value()) {
1622 if (*numIterations == 0) {
1623 rewriter.
replaceOp(op, op.getOutputs());
1628 if (*numIterations == 1) {
1633 newMixedLowerBounds.push_back(lb);
1634 newMixedUpperBounds.push_back(ub);
1635 newMixedSteps.push_back(step);
1639 if (newMixedLowerBounds.empty()) {
1645 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1647 op,
"no dimensions have 0 or 1 iterations");
1652 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1653 newMixedUpperBounds, newMixedSteps,
1654 op.getOutputs(), std::nullopt,
nullptr);
1655 newOp.getBodyRegion().getBlocks().clear();
1659 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1660 newOp.getStaticLowerBoundAttrName(),
1661 newOp.getStaticUpperBoundAttrName(),
1662 newOp.getStaticStepAttrName()};
1663 for (
const auto &namedAttr : op->getAttrs()) {
1664 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1667 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1671 newOp.getRegion().begin(), mapping);
1672 rewriter.
replaceOp(op, newOp.getResults());
1678struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1679 using OpRewritePattern<ForallOp>::OpRewritePattern;
1681 LogicalResult matchAndRewrite(ForallOp op,
1682 PatternRewriter &rewriter)
const override {
1683 Location loc = op.getLoc();
1684 bool changed =
false;
1685 for (
auto [lb, ub, step, iv] :
1686 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1687 op.getMixedStep(), op.getInductionVars())) {
1690 auto numIterations =
1692 if (!numIterations.has_value() || numIterations.value() != 1) {
1703struct FoldTensorCastOfOutputIntoForallOp
1704 :
public OpRewritePattern<scf::ForallOp> {
1705 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1712 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1713 PatternRewriter &rewriter)
const final {
1714 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1715 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1716 for (
auto en : llvm::enumerate(newOutputTensors)) {
1717 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1724 castOp.getSource().getType())) {
1728 tensorCastProducers[en.index()] =
1729 TypeCast{castOp.getSource().getType(), castOp.getType()};
1730 newOutputTensors[en.index()] = castOp.getSource();
1733 if (tensorCastProducers.empty())
1737 Location loc = forallOp.getLoc();
1738 auto newForallOp = ForallOp::create(
1739 rewriter, loc, forallOp.getMixedLowerBound(),
1740 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1741 newOutputTensors, forallOp.getMapping(),
1742 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
1743 auto castBlockArgs =
1744 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1745 for (auto [index, cast] : tensorCastProducers) {
1746 Value &oldTypeBBArg = castBlockArgs[index];
1747 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1748 cast.dstType, oldTypeBBArg);
1752 SmallVector<Value> ivsBlockArgs =
1753 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1754 ivsBlockArgs.append(castBlockArgs);
1756 bbArgs.front().getParentBlock(), ivsBlockArgs);
1762 auto terminator = newForallOp.getTerminator();
1763 for (
auto [yieldingOp, outputBlockArg] : llvm::zip(
1764 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1765 if (
auto parallelCombingingOp =
1766 dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1767 parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
1773 SmallVector<Value> castResults = newForallOp.getResults();
1774 for (
auto &item : tensorCastProducers) {
1775 Value &oldTypeResult = castResults[item.first];
1776 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1779 rewriter.
replaceOp(forallOp, castResults);
1786void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1787 MLIRContext *context) {
1788 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1789 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1790 ForallOpSingleOrZeroIterationDimsFolder,
1791 ForallOpReplaceConstantInductionVar>(context);
1794void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1795 SmallVectorImpl<RegionSuccessor> ®ions) {
1802 regions.push_back(RegionSuccessor(&getRegion()));
// Builds an InParallelOp with a single region containing one newly created
// block.
1819void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
// The guard restores the builder's insertion point when it goes out of scope,
// since createBlock below moves it.
1820 OpBuilder::InsertionGuard g(
b);
1821 Region *bodyRegion =
result.addRegion();
1822 b.createBlock(bodyRegion);
// Verifier for InParallelOp. Visible checks: the parent op must be an
// scf::ForallOp, every op in the region's first block must implement
// ParallelCombiningOpInterface, and each updated destination must be one of
// the parent's region output arguments.
1825LogicalResult InParallelOp::verify() {
// dyn_cast yields a null op when the parent is not a ForallOp.
1826 scf::ForallOp forallOp =
1827 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
// NOTE(review): the null check guarding this return is on a line not shown.
1829 return this->
emitOpError(
"expected forall op parent");
// Inspect every operation in the first block of the region.
1831 for (Operation &op : getRegion().front().getOperations()) {
1832 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
// Ops that do not implement the interface are rejected.
1833 if (!parallelCombiningOp) {
1834 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
// Every destination the op updates must be a region output argument of the
// enclosing ForallOp.
1839 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1840 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1841 for (OpOperand &dest : dests) {
1842 if (!llvm::is_contained(regionOutArgs, dest.get()))
1843 return op.emitOpError(
"may only insert into an output block argument");
1850void InParallelOp::print(OpAsmPrinter &p) {
1858ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
1861 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1862 std::unique_ptr<Region> region = std::make_unique<Region>();
1866 if (region->empty())
1867 OpBuilder(builder.
getContext()).createBlock(region.get());
1868 result.addRegion(std::move(region));
// Returns result number `idx` of the operation that contains this op.
1876OpResult InParallelOp::getParentResult(int64_t idx) {
1877 return getOperation()->getParentOp()->getResult(idx);
// Gathers, as block arguments, the destinations updated by the yielding ops
// that implement ParallelCombiningOpInterface.
1880SmallVector<BlockArgument> InParallelOp::getDests() {
1881 SmallVector<BlockArgument> updatedDests;
1882 for (Operation &yieldingOp : getYieldingOps()) {
1883 auto parallelCombiningOp =
1884 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
// NOTE(review): the statement guarded by this check is on a line not shown;
// presumably ops without the interface are skipped -- confirm.
1885 if (!parallelCombiningOp)
// Each updated destination must be a BlockArgument (the cast asserts this).
1887 for (OpOperand &updatedOperand :
1888 parallelCombiningOp.getUpdatedDestinations())
1889 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1891 return updatedDests;
// Returns the operations of the first block of this op's region as an
// iterator range.
1894llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1895 return getRegion().front().getOperations();
1903 assert(a &&
"expected non-empty operation");
1904 assert(
b &&
"expected non-empty operation");
1909 if (ifOp->isProperAncestor(
b))
1912 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1913 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
1915 ifOp = ifOp->getParentOfType<IfOp>();
1923IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1924 IfOp::Adaptor adaptor,
1926 if (adaptor.getRegions().empty())
1928 Region *r = &adaptor.getThenRegion();
1934 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
1937 TypeRange types = yieldOp.getOperandTypes();
1938 llvm::append_range(inferredReturnTypes, types);
1944 return build(builder,
result, resultTypes, cond,
false,
1948void IfOp::build(OpBuilder &builder, OperationState &
result,
1949 TypeRange resultTypes, Value cond,
bool addThenBlock,
1950 bool addElseBlock) {
1951 assert((!addElseBlock || addThenBlock) &&
1952 "must not create else block w/o then block");
1953 result.addTypes(resultTypes);
1954 result.addOperands(cond);
1957 OpBuilder::InsertionGuard guard(builder);
1958 Region *thenRegion =
result.addRegion();
1961 Region *elseRegion =
result.addRegion();
1966void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
1967 bool withElseRegion) {
1971void IfOp::build(OpBuilder &builder, OperationState &
result,
1972 TypeRange resultTypes, Value cond,
bool withElseRegion) {
1973 result.addTypes(resultTypes);
1974 result.addOperands(cond);
1977 OpBuilder::InsertionGuard guard(builder);
1978 Region *thenRegion =
result.addRegion();
1980 if (resultTypes.empty())
1981 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
1984 Region *elseRegion =
result.addRegion();
1985 if (withElseRegion) {
1987 if (resultTypes.empty())
1988 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
1992void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
1994 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
1995 assert(thenBuilder &&
"the builder callback for 'then' must be present");
1996 result.addOperands(cond);
1999 OpBuilder::InsertionGuard guard(builder);
2000 Region *thenRegion =
result.addRegion();
2002 thenBuilder(builder,
result.location);
2005 Region *elseRegion =
result.addRegion();
2008 elseBuilder(builder,
result.location);
2012 SmallVector<Type> inferredReturnTypes;
2014 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
2015 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2017 inferredReturnTypes))) {
2018 result.addTypes(inferredReturnTypes);
/// Verifier for scf.if: an op that defines at least one result must have a
/// non-empty else region, otherwise an op error is emitted.
/// NOTE(review): the remainder of the function (the non-error return) is not
/// visible in this excerpt.
2022LogicalResult IfOp::verify() {
2023 if (getNumResults() != 0 && getElseRegion().empty())
2024 return emitOpError(
"must have an else block if defining values");
2028ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
2030 result.regions.reserve(2);
2031 Region *thenRegion =
result.addRegion();
2032 Region *elseRegion =
result.addRegion();
2035 OpAsmParser::UnresolvedOperand cond;
2061void IfOp::print(OpAsmPrinter &p) {
2062 bool printBlockTerminators =
false;
2064 p <<
" " << getCondition();
2065 if (!getResults().empty()) {
2066 p <<
" -> (" << getResultTypes() <<
")";
2068 printBlockTerminators =
true;
2073 printBlockTerminators);
2076 auto &elseRegion = getElseRegion();
2077 if (!elseRegion.
empty()) {
2081 printBlockTerminators);
2087void IfOp::getSuccessorRegions(RegionBranchPoint point,
2088 SmallVectorImpl<RegionSuccessor> ®ions) {
2096 regions.push_back(RegionSuccessor(&getThenRegion()));
2099 Region *elseRegion = &this->getElseRegion();
2100 if (elseRegion->
empty())
2103 regions.push_back(RegionSuccessor(elseRegion));
2106ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
2111void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2112 SmallVectorImpl<RegionSuccessor> ®ions) {
2113 FoldAdaptor adaptor(operands, *
this);
2114 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2115 if (!boolAttr || boolAttr.getValue())
2116 regions.emplace_back(&getThenRegion());
2119 if (!boolAttr || !boolAttr.getValue()) {
2120 if (!getElseRegion().empty())
2121 regions.emplace_back(&getElseRegion());
2127LogicalResult IfOp::fold(FoldAdaptor adaptor,
2128 SmallVectorImpl<OpFoldResult> &results) {
2130 if (getElseRegion().empty())
2133 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2140 getConditionMutable().assign(xorStmt.getLhs());
2141 Block *thenBlock = &getThenRegion().front();
2144 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2145 getElseRegion().getBlocks());
2146 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2147 getThenRegion().getBlocks(), thenBlock);
2151void IfOp::getRegionInvocationBounds(
2152 ArrayRef<Attribute> operands,
2153 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2154 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2157 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2158 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2161 invocationBounds.assign(2, {0, 1});
2168struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2169 using OpRewritePattern<IfOp>::OpRewritePattern;
2171 LogicalResult matchAndRewrite(IfOp op,
2172 PatternRewriter &rewriter)
const override {
2173 if (op->getNumResults() == 0)
2176 auto cond = op.getCondition();
2177 auto thenYieldArgs = op.thenYield().getOperands();
2178 auto elseYieldArgs = op.elseYield().getOperands();
2180 SmallVector<Type> nonHoistable;
2181 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2182 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2183 &op.getElseRegion() == falseVal.getParentRegion())
2184 nonHoistable.push_back(trueVal.getType());
2188 if (nonHoistable.size() == op->getNumResults())
2191 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2195 replacement.getThenRegion().takeBody(op.getThenRegion());
2196 replacement.getElseRegion().takeBody(op.getElseRegion());
2198 SmallVector<Value> results(op->getNumResults());
2199 assert(thenYieldArgs.size() == results.size());
2200 assert(elseYieldArgs.size() == results.size());
2202 SmallVector<Value> trueYields;
2203 SmallVector<Value> falseYields;
2205 for (
const auto &it :
2206 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2207 Value trueVal = std::get<0>(it.value());
2208 Value falseVal = std::get<1>(it.value());
2211 results[it.index()] =
replacement.getResult(trueYields.size());
2212 trueYields.push_back(trueVal);
2213 falseYields.push_back(falseVal);
2214 }
else if (trueVal == falseVal)
2215 results[it.index()] = trueVal;
2217 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2218 cond, trueVal, falseVal);
2245struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2246 using OpRewritePattern<IfOp>::OpRewritePattern;
2249 enum class Parent { Then, Else,
None };
2254 static Parent getParentType(Region *toCheck, IfOp op,
2256 Region *endRegion) {
2257 SmallVector<Region *> seen;
2258 while (toCheck != endRegion) {
2259 auto found = cache.find(toCheck);
2260 if (found != cache.end())
2261 return found->second;
2262 seen.push_back(toCheck);
2263 if (&op.getThenRegion() == toCheck) {
2264 for (Region *region : seen)
2265 cache[region] = Parent::Then;
2266 return Parent::Then;
2268 if (&op.getElseRegion() == toCheck) {
2269 for (Region *region : seen)
2270 cache[region] = Parent::Else;
2271 return Parent::Else;
2276 for (Region *region : seen)
2277 cache[region] = Parent::None;
2278 return Parent::None;
2281 LogicalResult matchAndRewrite(IfOp op,
2282 PatternRewriter &rewriter)
const override {
2288 bool changed =
false;
2293 Value constantTrue =
nullptr;
2294 Value constantFalse =
nullptr;
2297 for (OpOperand &use :
2298 llvm::make_early_inc_range(op.getCondition().getUses())) {
2301 case Parent::Then: {
2305 constantTrue = arith::ConstantOp::create(
2309 [&]() { use.set(constantTrue); });
2312 case Parent::Else: {
2316 constantFalse = arith::ConstantOp::create(
2320 [&]() { use.set(constantFalse); });
2368struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2369 using OpRewritePattern<IfOp>::OpRewritePattern;
2371 LogicalResult matchAndRewrite(IfOp op,
2372 PatternRewriter &rewriter)
const override {
2374 if (op.getNumResults() == 0)
2378 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2380 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2383 op.getOperation()->getIterator());
2384 bool changed =
false;
2386 for (
auto [trueResult, falseResult, opResult] :
2387 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2389 if (trueResult == falseResult) {
2390 if (!opResult.use_empty()) {
2391 opResult.replaceAllUsesWith(trueResult);
2397 BoolAttr trueYield, falseYield;
2402 bool trueVal = trueYield.
getValue();
2403 bool falseVal = falseYield.
getValue();
2404 if (!trueVal && falseVal) {
2405 if (!opResult.use_empty()) {
2406 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2407 Value notCond = arith::XOrIOp::create(
2408 rewriter, op.getLoc(), op.getCondition(),
2414 opResult.replaceAllUsesWith(notCond);
2418 if (trueVal && !falseVal) {
2419 if (!opResult.use_empty()) {
2420 opResult.replaceAllUsesWith(op.getCondition());
2450struct CombineIfs :
public OpRewritePattern<IfOp> {
2451 using OpRewritePattern<IfOp>::OpRewritePattern;
2453 LogicalResult matchAndRewrite(IfOp nextIf,
2454 PatternRewriter &rewriter)
const override {
2455 Block *parent = nextIf->getBlock();
2456 if (nextIf == &parent->
front())
2459 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2467 Block *nextThen =
nullptr;
2468 Block *nextElse =
nullptr;
2469 if (nextIf.getCondition() == prevIf.getCondition()) {
2470 nextThen = nextIf.thenBlock();
2471 if (!nextIf.getElseRegion().empty())
2472 nextElse = nextIf.elseBlock();
2474 if (arith::XOrIOp notv =
2475 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2476 if (notv.getLhs() == prevIf.getCondition() &&
2478 nextElse = nextIf.thenBlock();
2479 if (!nextIf.getElseRegion().empty())
2480 nextThen = nextIf.elseBlock();
2483 if (arith::XOrIOp notv =
2484 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2485 if (notv.getLhs() == nextIf.getCondition() &&
2487 nextElse = nextIf.thenBlock();
2488 if (!nextIf.getElseRegion().empty())
2489 nextThen = nextIf.elseBlock();
2493 if (!nextThen && !nextElse)
2496 SmallVector<Value> prevElseYielded;
2497 if (!prevIf.getElseRegion().empty())
2498 prevElseYielded = prevIf.elseYield().getOperands();
2501 for (
auto it : llvm::zip(prevIf.getResults(),
2502 prevIf.thenYield().getOperands(), prevElseYielded))
2503 for (OpOperand &use :
2504 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2508 use.
set(std::get<1>(it));
2513 use.
set(std::get<2>(it));
2518 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2519 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2521 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2522 prevIf.getCondition(),
false);
2523 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2526 combinedIf.getThenRegion(),
2527 combinedIf.getThenRegion().begin());
2530 YieldOp thenYield = combinedIf.thenYield();
2531 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2532 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2535 SmallVector<Value> mergedYields(thenYield.getOperands());
2536 llvm::append_range(mergedYields, thenYield2.getOperands());
2537 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2543 combinedIf.getElseRegion(),
2544 combinedIf.getElseRegion().begin());
2547 if (combinedIf.getElseRegion().empty()) {
2549 combinedIf.getElseRegion(),
2550 combinedIf.getElseRegion().
begin());
2552 YieldOp elseYield = combinedIf.elseYield();
2553 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2554 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2558 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2559 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2561 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2567 SmallVector<Value> prevValues;
2568 SmallVector<Value> nextValues;
2569 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2570 if (pair.index() < prevIf.getNumResults())
2571 prevValues.push_back(pair.value());
2573 nextValues.push_back(pair.value());
/// Rewrite pattern rooted on IfOp (registered in
/// IfOp::getCanonicalizationPatterns). The visible guards look at whether the
/// op has results and whether its else block exists and holds exactly one
/// operation.
/// NOTE(review): the bodies of the guards and most of the rewrite are not
/// visible in this excerpt; judging by the name, the pattern drops an else
/// branch that contains nothing but its terminator -- confirm against the
/// full source.
2582struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2583 using OpRewritePattern<IfOp>::OpRewritePattern;
2585 LogicalResult matchAndRewrite(IfOp ifOp,
2586 PatternRewriter &rewriter)
const override {
// Guard on the op producing results (guarded statement not visible here).
2588 if (ifOp.getNumResults())
// elseBlock() may be null, hence the explicit null test before dereferencing.
2590 Block *elseBlock = ifOp.elseBlock();
2591 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
// Tail of a call that targets the start of the new op's then region; the
// beginning of the call is not visible here.
2595 newIfOp.getThenRegion().begin());
2617struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2618 using OpRewritePattern<IfOp>::OpRewritePattern;
2620 LogicalResult matchAndRewrite(IfOp op,
2621 PatternRewriter &rewriter)
const override {
2622 auto nestedOps = op.thenBlock()->without_terminator();
2624 if (!llvm::hasSingleElement(nestedOps))
2628 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2631 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2635 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2638 SmallVector<Value> thenYield(op.thenYield().getOperands());
2639 SmallVector<Value> elseYield;
2641 llvm::append_range(elseYield, op.elseYield().getOperands());
2645 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2654 for (
const auto &tup : llvm::enumerate(thenYield)) {
2655 if (tup.value().getDefiningOp() == nestedIf) {
2656 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2657 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2658 elseYield[tup.index()]) {
2663 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2676 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2679 elseYieldsToUpgradeToSelect.push_back(tup.index());
2682 Location loc = op.getLoc();
2683 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2684 nestedIf.getCondition());
2685 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2688 SmallVector<Value> results;
2689 llvm::append_range(results, newIf.getResults());
2692 for (
auto idx : elseYieldsToUpgradeToSelect)
2694 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2695 thenYield[idx], elseYield[idx]);
2697 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2700 if (!elseYield.empty()) {
2703 YieldOp::create(rewriter, loc, elseYield);
/// Populates `results` with the canonicalization patterns for scf.if:
/// CombineIfs, CombineNestedIfs, ConditionPropagation,
/// ConvertTrivialIfToSelect, RemoveEmptyElseBranch and
/// ReplaceIfYieldWithConditionOrValue, all constructed with `context`.
2712void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2713 MLIRContext *context) {
2714 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2715 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2716 ReplaceIfYieldWithConditionOrValue>(context);
// NOTE(review): the two fragments below are the argument tails of further
// calls keyed on this op's name; the callee names are not visible in this
// excerpt.
2718 results, IfOp::getOperationName());
2720 IfOp::getOperationName());
2723Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
2724YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2725Block *IfOp::elseBlock() {
2726 Region &r = getElseRegion();
2731YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2737void ParallelOp::build(
2742 result.addOperands(lowerBounds);
2743 result.addOperands(upperBounds);
2744 result.addOperands(steps);
2745 result.addOperands(initVals);
2747 ParallelOp::getOperandSegmentSizeAttr(),
2749 static_cast<int32_t>(upperBounds.size()),
2750 static_cast<int32_t>(steps.size()),
2751 static_cast<int32_t>(initVals.size())}));
2754 OpBuilder::InsertionGuard guard(builder);
2755 unsigned numIVs = steps.size();
2756 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
2757 SmallVector<Location, 8> argLocs(numIVs,
result.location);
2758 Region *bodyRegion =
result.addRegion();
2761 if (bodyBuilderFn) {
2763 bodyBuilderFn(builder,
result.location,
2768 if (initVals.empty())
2769 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
2772void ParallelOp::build(
2779 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2782 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2786 wrapper = wrappedBuilderFn;
2792LogicalResult ParallelOp::verify() {
2797 if (stepValues.empty())
2799 "needs at least one tuple element for lowerBound, upperBound and step");
2802 for (Value stepValue : stepValues)
2805 return emitOpError(
"constant step operand must be positive");
2809 Block *body = getBody();
2811 return emitOpError() <<
"expects the same number of induction variables: "
2813 <<
" as bound and step values: " << stepValues.size();
2815 if (!arg.getType().isIndex())
2817 "expects arguments for the induction variable to be of index type");
2821 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2826 auto resultsSize = getResults().size();
2827 auto reductionsSize = reduceOp.getReductions().size();
2828 auto initValsSize = getInitVals().size();
2829 if (resultsSize != reductionsSize)
2830 return emitOpError() <<
"expects number of results: " << resultsSize
2831 <<
" to be the same as number of reductions: "
2833 if (resultsSize != initValsSize)
2834 return emitOpError() <<
"expects number of results: " << resultsSize
2835 <<
" to be the same as number of initial values: "
2837 if (reduceOp.getNumOperands() != initValsSize)
2842 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2843 auto resultType = getOperation()->getResult(i).getType();
2844 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2845 if (resultType != reductionOperandType)
2846 return reduceOp.emitOpError()
2847 <<
"expects type of " << i
2848 <<
"-th reduction operand: " << reductionOperandType
2849 <<
" to be the same as the " << i
2850 <<
"-th result type: " << resultType;
2855ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2858 SmallVector<OpAsmParser::Argument, 4> ivs;
2863 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2870 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2878 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2886 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2897 Region *body =
result.addRegion();
2898 for (
auto &iv : ivs)
2905 ParallelOp::getOperandSegmentSizeAttr(),
2907 static_cast<int32_t>(upper.size()),
2908 static_cast<int32_t>(steps.size()),
2909 static_cast<int32_t>(initVals.size())}));
2918 ParallelOp::ensureTerminator(*body, builder,
result.location);
2922void ParallelOp::print(OpAsmPrinter &p) {
2923 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2924 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2925 if (!getInitVals().empty())
2926 p <<
" init (" << getInitVals() <<
")";
2931 (*this)->getAttrs(),
2932 ParallelOp::getOperandSegmentSizeAttr());
2935SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
/// Returns the arguments of the loop body block, copied into a vector of
/// Values, as the loop's induction variables. The returned optional is always
/// engaged on this path.
2937std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2938 return SmallVector<Value>{getBody()->getArguments()};
2941std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
2945std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
2949std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
2954 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2956 return ParallelOp();
2957 assert(ivArg.getOwner() &&
"unlinked block argument");
2958 auto *containingOp = ivArg.getOwner()->getParentOp();
2959 return dyn_cast<ParallelOp>(containingOp);
2964struct ParallelOpSingleOrZeroIterationDimsFolder
2968 LogicalResult matchAndRewrite(ParallelOp op,
2975 for (
auto [lb,
ub, step, iv] :
2976 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2977 op.getInductionVars())) {
2978 auto numIterations =
2980 if (numIterations.has_value()) {
2982 if (*numIterations == 0) {
2983 rewriter.
replaceOp(op, op.getInitVals());
2988 if (*numIterations == 1) {
2993 newLowerBounds.push_back(lb);
2994 newUpperBounds.push_back(ub);
2995 newSteps.push_back(step);
2998 if (newLowerBounds.size() == op.getLowerBound().size())
3001 if (newLowerBounds.empty()) {
3004 SmallVector<Value> results;
3005 results.reserve(op.getInitVals().size());
3006 for (
auto &bodyOp : op.getBody()->without_terminator())
3007 rewriter.
clone(bodyOp, mapping);
3008 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3009 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3010 Block &reduceBlock = reduceOp.getReductions()[i].front();
3011 auto initValIndex = results.size();
3012 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3016 rewriter.
clone(reduceBodyOp, mapping);
3019 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3020 results.push_back(
result);
3028 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3029 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3035 newOp.getRegion().begin(), mapping);
3036 rewriter.
replaceOp(op, newOp.getResults());
3041struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3042 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3044 LogicalResult matchAndRewrite(ParallelOp op,
3045 PatternRewriter &rewriter)
const override {
3046 Block &outerBody = *op.getBody();
3050 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3055 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3056 llvm::is_contained(innerOp.getUpperBound(), val) ||
3057 llvm::is_contained(innerOp.getStep(), val))
3061 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3064 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3066 Block &innerBody = *innerOp.getBody();
3067 assert(iterVals.size() ==
3075 builder.
clone(op, mapping);
3078 auto concatValues = [](
const auto &first,
const auto &second) {
3079 SmallVector<Value> ret;
3080 ret.reserve(first.size() + second.size());
3081 ret.assign(first.begin(), first.end());
3082 ret.append(second.begin(), second.end());
3086 auto newLowerBounds =
3087 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3088 auto newUpperBounds =
3089 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3090 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3101void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3102 MLIRContext *context) {
3104 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3113void ParallelOp::getSuccessorRegions(
3114 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
3118 regions.push_back(RegionSuccessor(&getRegion()));
3126void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
3128void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3130 result.addOperands(operands);
3131 for (Value v : operands) {
3132 OpBuilder::InsertionGuard guard(builder);
3133 Region *bodyRegion =
result.addRegion();
3135 ArrayRef<Type>{v.getType(), v.getType()},
/// Region verifier for scf.reduce. Checks visible here:
///  * the number of reduction regions equals the number of operands;
///  * per region `i`, diagnostics exist for an empty body, for block
///    arguments whose type differs from operand `i`'s type, and for a body
///    not terminated by 'scf.reduce.return'.
/// NOTE(review): several condition lines are not visible in this excerpt, so
/// the exact predicates guarding the per-region diagnostics should be
/// confirmed against the full source.
3140LogicalResult ReduceOp::verifyRegions() {
// One reduction region is expected per reduction operand.
3141 if (getReductions().size() != getOperands().size())
3142 return emitOpError() <<
"expects number of reduction regions: "
3143 << getReductions().size()
3144 <<
" to be the same as number of reduction operands: "
3145 << getOperands().size();
// Region `i` is validated against the type of operand `i`.
3148 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3149 auto type = getOperands()[i].getType();
3150 Block &block = getReductions()[i].front();
// Diagnostic for an empty reduction body (its guard is not visible here).
3152 return emitOpError() << i <<
"-th reduction has an empty body";
// Predicate: true if any block argument's type differs from the operand type.
3154 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3155 return arg.getType() != type;
3157 return emitOpError() <<
"expected two block arguments with type " << type
3158 <<
" in the " << i <<
"-th reduction region";
// Diagnostic for a body with the wrong terminator (its guard is not visible
// here).
3162 return emitOpError(
"reduction bodies must be terminated with an "
3163 "'scf.reduce.return' op");
3170ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3172 return MutableOperandRange(getOperation(), 0, 0);
3179LogicalResult ReduceReturnOp::verify() {
3182 Block *reductionBody = getOperation()->getBlock();
3184 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3186 if (expectedResultType != getResult().
getType())
3187 return emitOpError() <<
"must have type " << expectedResultType
3188 <<
" (the type of the reduction inputs)";
3196void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3197 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3198 ValueRange inits, BodyBuilderFn beforeBuilder,
3199 BodyBuilderFn afterBuilder) {
3203 OpBuilder::InsertionGuard guard(odsBuilder);
3206 SmallVector<Location, 4> beforeArgLocs;
3207 beforeArgLocs.reserve(inits.size());
3208 for (Value operand : inits) {
3209 beforeArgLocs.push_back(operand.getLoc());
3212 Region *beforeRegion = odsState.
addRegion();
3214 inits.getTypes(), beforeArgLocs);
3219 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3221 Region *afterRegion = odsState.
addRegion();
3223 resultTypes, afterArgLocs);
/// Returns the terminator of the 'before' body as a ConditionOp. `cast`
/// asserts if the terminator is of any other kind.
3229ConditionOp WhileOp::getConditionOp() {
3230 return cast<ConditionOp>(getBeforeBody()->getTerminator());
/// Returns the terminator of the 'after' body as a YieldOp. `cast` asserts if
/// the terminator is of any other kind.
3233YieldOp WhileOp::getYieldOp() {
3234 return cast<YieldOp>(getAfterBody()->getTerminator());
/// Returns the mutable operands (`getResultsMutable()`) of the scf.yield that
/// terminates the 'after' body. The returned optional is always engaged on
/// this path.
3237std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3238 return getYieldOp().getResultsMutable();
3242 return getBeforeBody()->getArguments();
3246 return getAfterBody()->getArguments();
3250 return getBeforeArguments();
3253OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3255 "WhileOp is expected to branch only to the first region");
3259void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3260 SmallVectorImpl<RegionSuccessor> ®ions) {
3263 regions.emplace_back(&getBefore());
3267 assert(llvm::is_contained(
3268 {&getAfter(), &getBefore()},
3270 "there are only two regions in a WhileOp");
3274 regions.emplace_back(&getBefore());
3279 regions.emplace_back(&getAfter());
3282ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
3284 return getOperation()->getResults();
3285 if (successor == &getBefore())
3286 return getBefore().getArguments();
3287 if (successor == &getAfter())
3288 return getAfter().getArguments();
3289 llvm_unreachable(
"invalid region successor");
/// Returns the regions that make up the loop: the 'before' region followed by
/// the 'after' region.
3292SmallVector<Region *> WhileOp::getLoopRegions() {
3293 return {&getBefore(), &getAfter()};
3303ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3304 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3305 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3306 Region *before =
result.addRegion();
3307 Region *after =
result.addRegion();
3309 OptionalParseResult listResult =
3314 FunctionType functionType;
3319 result.addTypes(functionType.getResults());
3321 if (functionType.getNumInputs() != operands.size()) {
3323 <<
"expected as many input types as operands " <<
"(expected "
3324 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3334 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3335 regionArgs[i].type = functionType.getInput(i);
3337 return failure(parser.
parseRegion(*before, regionArgs) ||
3343void scf::WhileOp::print(OpAsmPrinter &p) {
3357template <
typename OpTy>
3360 if (left.size() != right.size())
3361 return op.emitOpError(
"expects the same number of ") << message;
3363 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3364 if (left[i] != right[i]) {
3367 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3368 <<
" and " << right[i];
3376LogicalResult scf::WhileOp::verify() {
3379 "expects the 'before' region to terminate with 'scf.condition'");
3380 if (!beforeTerminator)
3385 "expects the 'after' region to terminate with 'scf.yield'");
3386 return success(afterTerminator !=
nullptr);
3423struct WhileMoveIfDown :
public OpRewritePattern<scf::WhileOp> {
3424 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3426 LogicalResult matchAndRewrite(scf::WhileOp op,
3427 PatternRewriter &rewriter)
const override {
3428 auto conditionOp = op.getConditionOp();
3436 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3442 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3443 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3446 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3447 *ifOp->user_begin() == conditionOp)) &&
3448 "ifOp has unexpected uses");
3450 Location loc = op.getLoc();
3454 for (
auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3455 auto it = llvm::find(ifOp->getResults(), arg);
3456 if (it != ifOp->getResults().end()) {
3457 size_t ifOpIdx = it.getIndex();
3458 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3459 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3469 if (&op.getBefore() == operand->get().getParentRegion())
3470 additionalUsedValuesSet.insert(operand->get());
3474 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3475 auto additionalValueTypes = llvm::map_to_vector(
3476 additionalUsedValues, [](Value val) {
return val.
getType(); });
3477 size_t additionalValueSize = additionalUsedValues.size();
3478 SmallVector<Type> newResultTypes(op.getResultTypes());
3479 newResultTypes.append(additionalValueTypes);
3482 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3485 newWhileOp.getBefore().takeBody(op.getBefore());
3486 newWhileOp.getAfter().takeBody(op.getAfter());
3487 newWhileOp.getAfter().addArguments(
3488 additionalValueTypes,
3489 SmallVector<Location>(additionalValueSize, loc));
3493 conditionOp.getArgsMutable().append(additionalUsedValues);
3499 additionalUsedValues,
3500 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3501 [&](OpOperand &use) {
3502 return ifOp.getThenRegion().isAncestor(
3503 use.getOwner()->getParentRegion());
3507 rewriter.
eraseOp(ifOp.thenYield());
3509 newWhileOp.getAfterBody()->begin());
3512 newWhileOp->getResults().drop_back(additionalValueSize));
3536struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3537 using OpRewritePattern<WhileOp>::OpRewritePattern;
3539 LogicalResult matchAndRewrite(WhileOp op,
3540 PatternRewriter &rewriter)
const override {
3541 auto term = op.getConditionOp();
3545 Value constantTrue =
nullptr;
3547 bool replaced =
false;
3548 for (
auto yieldedAndBlockArgs :
3549 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3550 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3551 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3553 constantTrue = arith::ConstantOp::create(
3554 rewriter, op.getLoc(), term.getCondition().getType(),
3589struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
3590 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3592 LogicalResult matchAndRewrite(scf::WhileOp op,
3593 PatternRewriter &rewriter)
const override {
3594 using namespace scf;
3595 auto cond = op.getConditionOp();
3596 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3599 bool changed =
false;
3600 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3601 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3602 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3605 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3606 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3610 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3613 if (cmp2.getPredicate() == cmp.getPredicate())
3614 samePredicate =
true;
3615 else if (cmp2.getPredicate() ==
3616 arith::invertPredicate(cmp.getPredicate()))
3617 samePredicate =
false;
3633static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
3635 if (args1.size() != args2.size())
3636 return std::nullopt;
3638 SmallVector<unsigned> ret(args1.size());
3639 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
3640 auto it = llvm::find(args2, arg1);
3641 if (it == args2.end())
3642 return std::nullopt;
3644 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
3651 llvm::SmallDenseSet<Value> set;
3652 for (Value arg : args) {
3653 if (!set.insert(arg).second)
3663struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
3666 LogicalResult matchAndRewrite(WhileOp loop,
3667 PatternRewriter &rewriter)
const override {
3668 auto *oldBefore = loop.getBeforeBody();
3669 ConditionOp oldTerm = loop.getConditionOp();
3670 ValueRange beforeArgs = oldBefore->getArguments();
3672 if (beforeArgs == termArgs)
3675 if (hasDuplicates(termArgs))
3678 auto mapping = getArgsMapping(beforeArgs, termArgs);
3683 OpBuilder::InsertionGuard g(rewriter);
3689 auto *oldAfter = loop.getAfterBody();
3691 SmallVector<Type> newResultTypes(beforeArgs.size());
3692 for (
auto &&[i, j] : llvm::enumerate(*mapping))
3693 newResultTypes[j] = loop.getResult(i).getType();
3695 auto newLoop = WhileOp::create(
3696 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3698 auto *newBefore = newLoop.getBeforeBody();
3699 auto *newAfter = newLoop.getAfterBody();
3701 SmallVector<Value> newResults(beforeArgs.size());
3702 SmallVector<Value> newAfterArgs(beforeArgs.size());
3703 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
3704 newResults[i] = newLoop.getResult(j);
3705 newAfterArgs[i] = newAfter->getArgument(j);
3709 newBefore->getArguments());
3719void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3720 MLIRContext *context) {
3721 results.
add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
3722 WhileMoveIfDown>(context);
3724 results, WhileOp::getOperationName());
3726 WhileOp::getOperationName());
3740 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
3743 caseValues.push_back(value);
3752 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
3754 p <<
"case " << value <<
' ';
3759LogicalResult scf::IndexSwitchOp::verify() {
3760 if (getCases().size() != getCaseRegions().size()) {
3762 << getCaseRegions().size() <<
" case regions but "
3763 << getCases().size() <<
" case values";
3767 for (int64_t value : getCases())
3768 if (!valueSet.insert(value).second)
3769 return emitOpError(
"has duplicate case value: ") << value;
3770 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
3771 auto yield = dyn_cast<YieldOp>(region.
front().
back());
3773 return emitOpError(
"expected region to end with scf.yield, but got ")
3776 if (yield.getNumOperands() != getNumResults()) {
3777 return (
emitOpError(
"expected each region to return ")
3778 << getNumResults() <<
" values, but " << name <<
" returns "
3779 << yield.getNumOperands())
3780 .attachNote(yield.getLoc())
3781 <<
"see yield operation here";
3783 for (
auto [idx,
result, operand] :
3784 llvm::enumerate(getResultTypes(), yield.getOperands())) {
3786 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
3787 if (
result == operand.getType())
3790 << idx <<
" of each region to be " <<
result)
3791 .attachNote(yield.getLoc())
3792 << name <<
" returns " << operand.getType() <<
" here";
3799 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
3806unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
3808Block &scf::IndexSwitchOp::getDefaultBlock() {
3809 return getDefaultRegion().front();
3812Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
3813 assert(idx < getNumCases() &&
"case index out-of-bounds");
3814 return getCaseRegions()[idx].front();
3817void IndexSwitchOp::getSuccessorRegions(
3818 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3825 llvm::append_range(successors, getRegions());
3828ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
3833void IndexSwitchOp::getEntrySuccessorRegions(
3834 ArrayRef<Attribute> operands,
3835 SmallVectorImpl<RegionSuccessor> &successors) {
3836 FoldAdaptor adaptor(operands, *
this);
3839 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3841 llvm::append_range(successors, getRegions());
3847 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3848 if (caseValue == arg.getInt()) {
3849 successors.emplace_back(&caseRegion);
3853 successors.emplace_back(&getDefaultRegion());
3856void IndexSwitchOp::getRegionInvocationBounds(
3857 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3858 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3859 if (!operandValue) {
3861 bounds.append(getNumRegions(), InvocationBounds(0, 1));
3865 unsigned liveIndex = getNumRegions() - 1;
3866 const auto *it = llvm::find(getCases(), operandValue.getInt());
3867 if (it != getCases().end())
3868 liveIndex = std::distance(getCases().begin(), it);
3869 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
3870 bounds.emplace_back(0, i == liveIndex);
3873void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3874 MLIRContext *context) {
3876 results, IndexSwitchOp::getOperationName());
3878 results, IndexSwitchOp::getOperationName());
3885#define GET_OP_CLASSES
3886#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
static Value defaultReplBuilderFn(OpBuilder &builder, Location loc, Value value)
Default implementation of the non-successor-input replacement builder function.
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.