30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/DebugLog.h"
40#include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
47struct SCFInlinerInterface :
public DialectInlinerInterface {
48 using DialectInlinerInterface::DialectInlinerInterface;
52 IRMapping &valueMapping)
const final {
57 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
62 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
63 auto retValOp = dyn_cast<scf::YieldOp>(op);
67 for (
auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
68 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
78void SCFDialect::initialize() {
81#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
83 addInterfaces<SCFInlinerInterface>();
84 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
85 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
86 InParallelOp, ReduceReturnOp>();
87 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
88 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
89 ForallOp, InParallelOp, WhileOp, YieldOp>();
90 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
95 scf::YieldOp::create(builder, loc);
100template <
typename TerminatorTy>
102 StringRef errorMessage) {
103 Operation *terminatorOperation =
nullptr;
105 terminatorOperation = ®ion.
front().
back();
106 if (
auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
110 if (terminatorOperation)
111 diag.attachNote(terminatorOperation->
getLoc()) <<
"terminator here";
118 auto addOp =
ub.getDefiningOp<arith::AddIOp>();
121 if ((isSigned && !addOp.hasNoSignedWrap()) ||
122 (!isSigned && !addOp.hasNoUnsignedWrap()))
125 if (addOp.getLhs() != lb ||
146ParseResult ExecuteRegionOp::parse(
OpAsmParser &parser,
174LogicalResult ExecuteRegionOp::verify() {
175 if (getRegion().empty())
176 return emitOpError(
"region needs to have at least one block");
177 if (getRegion().front().getNumArguments() > 0)
178 return emitOpError(
"region cannot have any arguments");
224 if (op.getNoInline())
226 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
229 Block *prevBlock = op->getBlock();
233 cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
235 for (
Block &blk : op.getRegion()) {
236 if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
238 cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
239 yieldOp.getResults());
247 for (
auto res : op.getResults())
248 blockArgs.push_back(postBlock->
addArgument(res.getType(), res.getLoc()));
259 results, ExecuteRegionOp::getOperationName());
262 results, ExecuteRegionOp::getOperationName(),
264 return failure(cast<ExecuteRegionOp>(op).getNoInline());
268void ExecuteRegionOp::getSuccessorRegions(
293 "condition op can only exit the loop or branch to the after"
296 return getArgsMutable();
299void ConditionOp::getSuccessorRegions(
301 FoldAdaptor adaptor(operands, *
this);
303 WhileOp whileOp = getParentOp();
307 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
308 if (!boolAttr || boolAttr.getValue())
309 regions.emplace_back(&whileOp.getAfter());
310 if (!boolAttr || !boolAttr.getValue())
320 BodyBuilderFn bodyBuilder,
bool unsignedCmp) {
324 result.addAttribute(getUnsignedCmpAttrName(
result.name),
327 result.addOperands(initArgs);
328 for (
Value v : initArgs)
329 result.addTypes(v.getType());
334 for (
Value v : initArgs)
340 if (initArgs.empty() && !bodyBuilder) {
341 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
342 }
else if (bodyBuilder) {
350LogicalResult ForOp::verify() {
355 if (getBody()->getNumArguments() < getNumInductionVars())
356 return emitOpError(
"expected body to have at least ")
357 << getNumInductionVars()
358 <<
" argument(s) for the induction variable, but got "
359 << getBody()->getNumArguments();
362 if (getInitArgs().size() != getNumResults())
364 "mismatch in number of loop-carried values and defined values");
369LogicalResult ForOp::verifyRegions() {
371 if (getBody()->getNumArguments() < getNumInductionVars())
372 return emitOpError(
"expected body to have at least ")
373 << getNumInductionVars() <<
" argument(s) for the induction "
374 <<
"variable, but got " << getBody()->getNumArguments();
380 "expected induction variable to be same type as bounds and step");
382 if (getNumRegionIterArgs() != getNumResults())
384 "mismatch in number of basic block args and defined values");
386 auto initArgs = getInitArgs();
387 auto iterArgs = getRegionIterArgs();
388 auto opResults = getResults();
390 for (
auto e : llvm::zip(initArgs, iterArgs, opResults)) {
392 return emitOpError() <<
"types mismatch between " << i
393 <<
"th iter operand and defined value";
395 return emitOpError() <<
"types mismatch between " << i
396 <<
"th iter region arg and defined value";
403std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
407std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
411std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
415std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
419bool ForOp::isValidInductionVarType(
Type type) {
424 if (bounds.size() != 1)
426 if (
auto val = dyn_cast<Value>(bounds[0])) {
434 if (bounds.size() != 1)
436 if (
auto val = dyn_cast<Value>(bounds[0])) {
444 if (steps.size() != 1)
446 if (
auto val = dyn_cast<Value>(steps[0])) {
453std::optional<ResultRange> ForOp::getLoopResults() {
return getResults(); }
457LogicalResult ForOp::promoteIfSingleIteration(
RewriterBase &rewriter) {
458 std::optional<APInt> tripCount = getStaticTripCount();
459 LDBG() <<
"promoteIfSingleIteration tripCount is " << tripCount
462 if (!tripCount.has_value() || tripCount->getZExtValue() > 1)
465 if (*tripCount == 0) {
472 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
479 llvm::append_range(bbArgReplacements, getInitArgs());
483 getOperation()->getIterator(), bbArgReplacements);
499 StringRef prefix =
"") {
500 assert(blocksArgs.size() == initializers.size() &&
501 "expected same length of arguments and initializers");
502 if (initializers.empty())
506 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
507 p << std::get<0>(it) <<
" = " << std::get<1>(it);
513 if (getUnsignedCmp())
516 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
520 if (!getInitArgs().empty())
521 p <<
" -> (" << getInitArgs().getTypes() <<
')';
524 p <<
" : " << t <<
' ';
527 !getInitArgs().empty());
529 getUnsignedCmpAttrName().strref());
540 result.addAttribute(getUnsignedCmpAttrName(
result.name),
554 regionArgs.push_back(inductionVariable);
564 if (regionArgs.size() !=
result.types.size() + 1)
567 "mismatch in number of loop-carried values and defined values");
576 regionArgs.front().type = type;
577 for (
auto [iterArg, type] :
578 llvm::zip_equal(llvm::drop_begin(regionArgs),
result.types))
585 ForOp::ensureTerminator(*body, builder,
result.location);
594 for (
auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
595 operands,
result.types)) {
596 Type type = std::get<2>(argOperandType);
597 std::get<0>(argOperandType).type = type;
614 return getBody()->getArguments().drop_front(getNumInductionVars());
618 return getInitArgsMutable();
621FailureOr<LoopLikeOpInterface>
622ForOp::replaceWithAdditionalYields(
RewriterBase &rewriter,
624 bool replaceInitOperandUsesInLoop,
629 auto inits = llvm::to_vector(getInitArgs());
630 inits.append(newInitOperands.begin(), newInitOperands.end());
631 scf::ForOp newLoop = scf::ForOp::create(
637 auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
639 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
644 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
645 assert(newInitOperands.size() == newYieldedValues.size() &&
646 "expected as many new yield values as new iter operands");
648 yieldOp.getResultsMutable().append(newYieldedValues);
654 newLoop.getBody()->getArguments().take_front(
655 getBody()->getNumArguments()));
657 if (replaceInitOperandUsesInLoop) {
660 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
671 newLoop->getResults().take_front(getNumResults()));
672 return cast<LoopLikeOpInterface>(newLoop.getOperation());
676 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
679 assert(ivArg.getOwner() &&
"unlinked block argument");
680 auto *containingOp = ivArg.getOwner()->getParentOp();
681 return dyn_cast_or_null<ForOp>(containingOp);
685 return getInitArgs();
690 if (std::optional<APInt> tripCount = getStaticTripCount()) {
693 if (*tripCount == 0) {
702 }
else if (*tripCount == 1) {
726LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
727 for (
auto [lb, ub, step] :
728 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
731 if (!tripCount.has_value() || *tripCount != 1)
740 return getBody()->getArguments().drop_front(getRank());
743MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
744 return getOutputsMutable();
750 scf::InParallelOp terminator = forallOp.getTerminator();
755 bbArgReplacements.append(forallOp.getOutputs().begin(),
756 forallOp.getOutputs().end());
760 forallOp->getIterator(), bbArgReplacements);
765 results.reserve(forallOp.getResults().size());
766 for (
auto &yieldingOp : terminator.getYieldingOps()) {
767 auto parallelInsertSliceOp =
768 dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
769 if (!parallelInsertSliceOp)
772 Value dst = parallelInsertSliceOp.getDest();
773 Value src = parallelInsertSliceOp.getSource();
774 if (llvm::isa<TensorType>(src.
getType())) {
775 results.push_back(tensor::InsertSliceOp::create(
776 rewriter, forallOp.getLoc(), dst.
getType(), src, dst,
777 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
778 parallelInsertSliceOp.getStrides(),
779 parallelInsertSliceOp.getStaticOffsets(),
780 parallelInsertSliceOp.getStaticSizes(),
781 parallelInsertSliceOp.getStaticStrides()));
783 llvm_unreachable(
"unsupported terminator");
798 assert(lbs.size() == ubs.size() &&
799 "expected the same number of lower and upper bounds");
800 assert(lbs.size() == steps.size() &&
801 "expected the same number of lower bounds and steps");
806 bodyBuilder ? bodyBuilder(builder, loc,
ValueRange(), iterArgs)
808 assert(results.size() == iterArgs.size() &&
809 "loop nest body must return as many values as loop has iteration "
811 return LoopNest{{}, std::move(results)};
819 loops.reserve(lbs.size());
820 ivs.reserve(lbs.size());
823 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
824 auto loop = scf::ForOp::create(
825 builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
831 currentIterArgs = args;
832 currentLoc = nestedLoc;
838 loops.push_back(loop);
842 for (
unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
844 scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
851 ? bodyBuilder(builder, currentLoc, ivs,
852 loops.back().getRegionIterArgs())
854 assert(results.size() == iterArgs.size() &&
855 "loop nest body must return as many values as loop has iteration "
858 scf::YieldOp::create(builder, loc, results);
862 llvm::append_range(nestResults, loops.front().getResults());
863 return LoopNest{std::move(loops), std::move(nestResults)};
876 bodyBuilder(nestedBuilder, nestedLoc, ivs);
885 assert(operand.
getOwner() == forOp);
890 "expected an iter OpOperand");
892 "Expected a different type");
894 for (
OpOperand &opOperand : forOp.getInitArgsMutable()) {
899 newIterOperands.push_back(opOperand.get());
903 scf::ForOp newForOp = scf::ForOp::create(
904 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
905 forOp.getStep(), newIterOperands,
nullptr,
906 forOp.getUnsignedCmp());
907 newForOp->setAttrs(forOp->getAttrs());
908 Block &newBlock = newForOp.getRegion().
front();
916 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
918 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
919 newBlockTransferArgs[newRegionIterArg.
getArgNumber()] = castIn;
923 rewriter.
mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
926 auto clonedYieldOp = cast<scf::YieldOp>(newBlock.
getTerminator());
929 newRegionIterArg.
getArgNumber() - forOp.getNumInductionVars();
930 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
931 clonedYieldOp.getOperand(yieldIdx));
933 newYieldOperands[yieldIdx] = castOut;
934 scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
935 rewriter.
eraseOp(clonedYieldOp);
940 newResults[yieldIdx] =
941 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
976 LogicalResult matchAndRewrite(ForOp op,
978 for (
auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
979 OpOperand &iterOpOperand = std::get<0>(it);
982 incomingCast.getSource().getType() == incomingCast.getType())
987 incomingCast.getDest().getType(),
988 incomingCast.getSource().getType()))
990 if (!std::get<1>(it).hasOneUse())
996 rewriter, op, iterOpOperand, incomingCast.getSource(),
998 return tensor::CastOp::create(b, loc, type, source);
1007void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1008 MLIRContext *context) {
1009 results.
add<ForOpTensorCastFolder>(context);
1011 results, ForOp::getOperationName());
1013 results, ForOp::getOperationName(),
1014 [](OpBuilder &builder, Location loc, Value value) {
1018 auto blockArg = cast<BlockArgument>(value);
1019 assert(blockArg.getArgNumber() == 0 &&
"expected induction variable");
1020 auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
1021 return forOp.getLowerBound();
1025std::optional<APInt> ForOp::getConstantStep() {
1028 return step.getValue();
1032std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1033 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1039 if (
auto constantStep = getConstantStep())
1040 if (*constantStep == 1)
1048std::optional<APInt> ForOp::getStaticTripCount() {
1057LogicalResult ForallOp::verify() {
1058 unsigned numLoops = getRank();
1060 if (getNumResults() != getOutputs().size())
1062 << getNumResults() <<
" results, but has only "
1063 << getOutputs().size() <<
" outputs";
1066 auto *body = getBody();
1068 return emitOpError(
"region expects ") << numLoops <<
" arguments";
1069 for (int64_t i = 0; i < numLoops; ++i)
1072 << i <<
"-th block argument to be an index";
1073 for (
unsigned i = 0; i < getOutputs().size(); ++i)
1076 << i <<
"-th output and corresponding block argument";
1077 if (getMapping().has_value() && !getMapping()->empty()) {
1078 if (getDeviceMappingAttrs().size() != numLoops)
1079 return emitOpError() <<
"mapping attribute size must match op rank";
1080 if (
failed(getDeviceMaskingAttr()))
1082 <<
" supports at most one device masking attribute";
1086 Operation *op = getOperation();
1088 getStaticLowerBound(),
1089 getDynamicLowerBound())))
1092 getStaticUpperBound(),
1093 getDynamicUpperBound())))
1096 getStaticStep(), getDynamicStep())))
1102void ForallOp::print(OpAsmPrinter &p) {
1103 Operation *op = getOperation();
1104 p <<
" (" << getInductionVars();
1105 if (isNormalized()) {
1126 if (!getRegionOutArgs().empty())
1127 p <<
"-> (" << getResultTypes() <<
") ";
1128 p.printRegion(getRegion(),
1130 getNumResults() > 0);
1131 p.printOptionalAttrDict(op->
getAttrs(), {getOperandSegmentSizesAttrName(),
1132 getStaticLowerBoundAttrName(),
1133 getStaticUpperBoundAttrName(),
1134 getStaticStepAttrName()});
1137ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &
result) {
1139 auto indexType =
b.getIndexType();
1144 SmallVector<OpAsmParser::Argument, 4> ivs;
1149 SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1159 unsigned numLoops = ivs.size();
1160 staticLbs =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1161 staticSteps =
b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1190 SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1191 SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1194 if (outOperands.size() !=
result.types.size())
1196 "mismatch between out operands and types");
1205 SmallVector<OpAsmParser::Argument, 4> regionArgs;
1206 std::unique_ptr<Region> region = std::make_unique<Region>();
1207 for (
auto &iv : ivs) {
1208 iv.type =
b.getIndexType();
1209 regionArgs.push_back(iv);
1211 for (
const auto &it : llvm::enumerate(regionOutArgs)) {
1212 auto &out = it.value();
1213 out.type =
result.types[it.index()];
1214 regionArgs.push_back(out);
1220 ForallOp::ensureTerminator(*region,
b,
result.location);
1221 result.addRegion(std::move(region));
1227 result.addAttribute(
"staticLowerBound", staticLbs);
1228 result.addAttribute(
"staticUpperBound", staticUbs);
1229 result.addAttribute(
"staticStep", staticSteps);
1230 result.addAttribute(
"operandSegmentSizes",
1232 {static_cast<int32_t>(dynamicLbs.size()),
1233 static_cast<int32_t>(dynamicUbs.size()),
1234 static_cast<int32_t>(dynamicSteps.size()),
1235 static_cast<int32_t>(outOperands.size())}));
1240void ForallOp::build(
1241 mlir::OpBuilder &
b, mlir::OperationState &
result,
1242 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1243 ArrayRef<OpFoldResult> steps,
ValueRange outputs,
1244 std::optional<ArrayAttr> mapping,
1246 SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1247 SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1252 result.addOperands(dynamicLbs);
1253 result.addOperands(dynamicUbs);
1254 result.addOperands(dynamicSteps);
1255 result.addOperands(outputs);
1258 result.addAttribute(getStaticLowerBoundAttrName(
result.name),
1259 b.getDenseI64ArrayAttr(staticLbs));
1260 result.addAttribute(getStaticUpperBoundAttrName(
result.name),
1261 b.getDenseI64ArrayAttr(staticUbs));
1262 result.addAttribute(getStaticStepAttrName(
result.name),
1263 b.getDenseI64ArrayAttr(staticSteps));
1265 "operandSegmentSizes",
1266 b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1267 static_cast<int32_t>(dynamicUbs.size()),
1268 static_cast<int32_t>(dynamicSteps.size()),
1269 static_cast<int32_t>(outputs.size())}));
1270 if (mapping.has_value()) {
1271 result.addAttribute(ForallOp::getMappingAttrName(
result.name),
1275 Region *bodyRegion =
result.addRegion();
1276 OpBuilder::InsertionGuard g(
b);
1277 b.createBlock(bodyRegion);
1282 SmallVector<Type>(lbs.size(),
b.getIndexType()),
1283 SmallVector<Location>(staticLbs.size(),
result.location));
1286 SmallVector<Location>(outputs.size(),
result.location));
1288 b.setInsertionPointToStart(&bodyBlock);
1289 if (!bodyBuilderFn) {
1290 ForallOp::ensureTerminator(*bodyRegion,
b,
result.location);
1297void ForallOp::build(
1298 mlir::OpBuilder &
b, mlir::OperationState &
result,
1299 ArrayRef<OpFoldResult> ubs,
ValueRange outputs,
1300 std::optional<ArrayAttr> mapping,
1302 unsigned numLoops = ubs.size();
1303 SmallVector<OpFoldResult> lbs(numLoops,
b.getIndexAttr(0));
1304 SmallVector<OpFoldResult> steps(numLoops,
b.getIndexAttr(1));
1305 build(
b,
result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1309bool ForallOp::isNormalized() {
1310 auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1311 return llvm::all_of(results, [&](OpFoldResult ofr) {
1313 return intValue.has_value() && intValue == val;
1316 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1319InParallelOp ForallOp::getTerminator() {
1320 return cast<InParallelOp>(getBody()->getTerminator());
1323SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1324 SmallVector<Operation *> storeOps;
1325 for (Operation *user : bbArg.
getUsers()) {
1326 if (
auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1327 storeOps.push_back(parallelOp);
1333SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1334 SmallVector<DeviceMappingAttrInterface> res;
1337 for (
auto attr : getMapping()->getValue()) {
1338 auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1345FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1346 DeviceMaskingAttrInterface res;
1349 for (
auto attr : getMapping()->getValue()) {
1350 auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1359bool ForallOp::usesLinearMapping() {
1360 SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1363 return ifaces.front().isLinearMapping();
1366std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1367 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1371std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1373 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(),
b);
1377std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1379 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(),
b);
1383std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1389 auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1392 assert(tidxArg.getOwner() &&
"unlinked block argument");
1393 auto *containingOp = tidxArg.getOwner()->getParentOp();
1394 return dyn_cast<ForallOp>(containingOp);
1402 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1404 auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1408 forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1411 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1416class ForallOpControlOperandsFolder :
public OpRewritePattern<ForallOp> {
1418 using OpRewritePattern<ForallOp>::OpRewritePattern;
1420 LogicalResult matchAndRewrite(ForallOp op,
1421 PatternRewriter &rewriter)
const override {
1422 SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1423 SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1424 SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1431 SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1432 SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1435 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1436 op.setStaticLowerBound(staticLowerBound);
1440 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1441 op.setStaticUpperBound(staticUpperBound);
1444 op.getDynamicStepMutable().assign(dynamicStep);
1445 op.setStaticStep(staticStep);
1447 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1449 {static_cast<int32_t>(dynamicLowerBound.size()),
1450 static_cast<int32_t>(dynamicUpperBound.size()),
1451 static_cast<int32_t>(dynamicStep.size()),
1452 static_cast<int32_t>(op.getNumResults())}));
1531struct ForallOpIterArgsFolder :
public OpRewritePattern<ForallOp> {
1532 using OpRewritePattern<ForallOp>::OpRewritePattern;
1534 LogicalResult matchAndRewrite(ForallOp forallOp,
1535 PatternRewriter &rewriter)
const final {
1546 SmallVector<Value> resultsToDelete;
1547 SmallVector<Value> outsToDelete;
1548 SmallVector<BlockArgument> blockArgsToDelete;
1549 SmallVector<Value> newOuts;
1550 BitVector resultIndicesToDelete(forallOp.getNumResults(),
false);
1551 BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
1553 for (OpResult
result : forallOp.getResults()) {
1554 OpOperand *opOperand = forallOp.getTiedOpOperand(
result);
1555 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1556 if (
result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1557 resultsToDelete.push_back(
result);
1558 outsToDelete.push_back(opOperand->
get());
1559 blockArgsToDelete.push_back(blockArg);
1560 resultIndicesToDelete[
result.getResultNumber()] =
true;
1563 newOuts.push_back(opOperand->
get());
1569 if (resultsToDelete.empty())
1574 for (
auto blockArg : blockArgsToDelete) {
1575 SmallVector<Operation *> combiningOps =
1576 forallOp.getCombiningOps(blockArg);
1577 for (Operation *combiningOp : combiningOps)
1578 rewriter.
eraseOp(combiningOp);
1580 for (
auto [blockArg,
result, out] :
1581 llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
1587 forallOp.getBody()->eraseArguments(blockIndicesToDelete);
1592 auto newForallOp = cast<scf::ForallOp>(
1594 newForallOp.getOutputsMutable().assign(newOuts);
1600struct ForallOpSingleOrZeroIterationDimsFolder
1601 :
public OpRewritePattern<ForallOp> {
1602 using OpRewritePattern<ForallOp>::OpRewritePattern;
1604 LogicalResult matchAndRewrite(ForallOp op,
1605 PatternRewriter &rewriter)
const override {
1607 if (op.getMapping().has_value() && !op.getMapping()->empty())
1609 Location loc = op.getLoc();
1612 SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1615 for (
auto [lb, ub, step, iv] :
1616 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1617 op.getMixedStep(), op.getInductionVars())) {
1618 auto numIterations =
1620 if (numIterations.has_value()) {
1622 if (*numIterations == 0) {
1623 rewriter.
replaceOp(op, op.getOutputs());
1628 if (*numIterations == 1) {
1633 newMixedLowerBounds.push_back(lb);
1634 newMixedUpperBounds.push_back(ub);
1635 newMixedSteps.push_back(step);
1639 if (newMixedLowerBounds.empty()) {
1645 if (newMixedLowerBounds.size() ==
static_cast<unsigned>(op.getRank())) {
1647 op,
"no dimensions have 0 or 1 iterations");
1652 newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1653 newMixedUpperBounds, newMixedSteps,
1654 op.getOutputs(), std::nullopt,
nullptr);
1655 newOp.getBodyRegion().getBlocks().clear();
1659 SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1660 newOp.getStaticLowerBoundAttrName(),
1661 newOp.getStaticUpperBoundAttrName(),
1662 newOp.getStaticStepAttrName()};
1663 for (
const auto &namedAttr : op->getAttrs()) {
1664 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1667 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1671 newOp.getRegion().begin(), mapping);
1672 rewriter.
replaceOp(op, newOp.getResults());
1678struct ForallOpReplaceConstantInductionVar :
public OpRewritePattern<ForallOp> {
1679 using OpRewritePattern<ForallOp>::OpRewritePattern;
1681 LogicalResult matchAndRewrite(ForallOp op,
1682 PatternRewriter &rewriter)
const override {
1683 Location loc = op.getLoc();
1684 bool changed =
false;
1685 for (
auto [lb, ub, step, iv] :
1686 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1687 op.getMixedStep(), op.getInductionVars())) {
1690 auto numIterations =
1692 if (!numIterations.has_value() || numIterations.value() != 1) {
1703struct FoldTensorCastOfOutputIntoForallOp
1704 :
public OpRewritePattern<scf::ForallOp> {
1705 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1712 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1713 PatternRewriter &rewriter)
const final {
1714 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1715 llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1716 for (
auto en : llvm::enumerate(newOutputTensors)) {
1717 auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1724 castOp.getSource().getType())) {
1728 tensorCastProducers[en.index()] =
1729 TypeCast{castOp.getSource().getType(), castOp.getType()};
1730 newOutputTensors[en.index()] = castOp.getSource();
1733 if (tensorCastProducers.empty())
1737 Location loc = forallOp.getLoc();
1738 auto newForallOp = ForallOp::create(
1739 rewriter, loc, forallOp.getMixedLowerBound(),
1740 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1741 newOutputTensors, forallOp.getMapping(),
1742 [&](OpBuilder nestedBuilder, Location nestedLoc,
ValueRange bbArgs) {
1743 auto castBlockArgs =
1744 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1745 for (auto [index, cast] : tensorCastProducers) {
1746 Value &oldTypeBBArg = castBlockArgs[index];
1747 oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1748 cast.dstType, oldTypeBBArg);
1752 SmallVector<Value> ivsBlockArgs =
1753 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1754 ivsBlockArgs.append(castBlockArgs);
1756 bbArgs.front().getParentBlock(), ivsBlockArgs);
1767 llvm::SmallDenseSet<Value> newIterArgSet(
1768 newForallOp.getRegionIterArgs().begin(),
1769 newForallOp.getRegionIterArgs().end());
1770 auto terminator = newForallOp.getTerminator();
1771 for (
auto &yieldingOp : terminator.getYieldingOps()) {
1772 auto parallelCombiningOp =
1773 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1774 if (!parallelCombiningOp)
1776 for (OpOperand &dest : parallelCombiningOp.getUpdatedDestinations()) {
1777 auto castOp = dest.get().getDefiningOp<tensor::CastOp>();
1778 if (castOp && newIterArgSet.contains(castOp.getSource()))
1779 dest.set(castOp.getSource());
1785 SmallVector<Value> castResults = newForallOp.getResults();
1786 for (
auto &item : tensorCastProducers) {
1787 Value &oldTypeResult = castResults[item.first];
1788 oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1791 rewriter.
replaceOp(forallOp, castResults);
1798void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1799 MLIRContext *context) {
1800 results.
add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1801 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1802 ForallOpSingleOrZeroIterationDimsFolder,
1803 ForallOpReplaceConstantInductionVar>(context);
1806void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1807 SmallVectorImpl<RegionSuccessor> ®ions) {
1814 regions.push_back(RegionSuccessor(&getRegion()));
1831void InParallelOp::build(OpBuilder &
b, OperationState &
result) {
1832 OpBuilder::InsertionGuard g(
b);
1833 Region *bodyRegion =
result.addRegion();
1834 b.createBlock(bodyRegion);
1837LogicalResult InParallelOp::verify() {
1838 scf::ForallOp forallOp =
1839 dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1841 return this->
emitOpError(
"expected forall op parent");
1843 for (Operation &op : getRegion().front().getOperations()) {
1844 auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1845 if (!parallelCombiningOp) {
1846 return this->
emitOpError(
"expected only ParallelCombiningOpInterface")
1851 MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1852 ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1853 for (OpOperand &dest : dests) {
1854 if (!llvm::is_contained(regionOutArgs, dest.get()))
1855 return op.emitOpError(
"may only insert into an output block argument");
1862void InParallelOp::print(OpAsmPrinter &p) {
1870ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
1873 SmallVector<OpAsmParser::Argument, 8> regionOperands;
1874 std::unique_ptr<Region> region = std::make_unique<Region>();
1878 if (region->empty())
1879 OpBuilder(builder.
getContext()).createBlock(region.get());
1880 result.addRegion(std::move(region));
1888OpResult InParallelOp::getParentResult(int64_t idx) {
1889 return getOperation()->getParentOp()->getResult(idx);
1892SmallVector<BlockArgument> InParallelOp::getDests() {
1893 SmallVector<BlockArgument> updatedDests;
1894 for (Operation &yieldingOp : getYieldingOps()) {
1895 auto parallelCombiningOp =
1896 dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
1897 if (!parallelCombiningOp)
1899 for (OpOperand &updatedOperand :
1900 parallelCombiningOp.getUpdatedDestinations())
1901 updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
1903 return updatedDests;
1906llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1907 return getRegion().front().getOperations();
1915 assert(a &&
"expected non-empty operation");
1916 assert(
b &&
"expected non-empty operation");
1921 if (ifOp->isProperAncestor(
b))
1924 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1925 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*
b));
1927 ifOp = ifOp->getParentOfType<IfOp>();
1935IfOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1936 IfOp::Adaptor adaptor,
1938 if (adaptor.getRegions().empty())
1940 Region *r = &adaptor.getThenRegion();
1946 auto yieldOp = llvm::dyn_cast<YieldOp>(
b.back());
1949 TypeRange types = yieldOp.getOperandTypes();
1950 llvm::append_range(inferredReturnTypes, types);
1956 return build(builder,
result, resultTypes, cond,
false,
1960void IfOp::build(OpBuilder &builder, OperationState &
result,
1961 TypeRange resultTypes, Value cond,
bool addThenBlock,
1962 bool addElseBlock) {
1963 assert((!addElseBlock || addThenBlock) &&
1964 "must not create else block w/o then block");
1965 result.addTypes(resultTypes);
1966 result.addOperands(cond);
1969 OpBuilder::InsertionGuard guard(builder);
1970 Region *thenRegion =
result.addRegion();
1973 Region *elseRegion =
result.addRegion();
1978void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
1979 bool withElseRegion) {
1983void IfOp::build(OpBuilder &builder, OperationState &
result,
1984 TypeRange resultTypes, Value cond,
bool withElseRegion) {
1985 result.addTypes(resultTypes);
1986 result.addOperands(cond);
1989 OpBuilder::InsertionGuard guard(builder);
1990 Region *thenRegion =
result.addRegion();
1992 if (resultTypes.empty())
1993 IfOp::ensureTerminator(*thenRegion, builder,
result.location);
1996 Region *elseRegion =
result.addRegion();
1997 if (withElseRegion) {
1999 if (resultTypes.empty())
2000 IfOp::ensureTerminator(*elseRegion, builder,
result.location);
2004void IfOp::build(OpBuilder &builder, OperationState &
result, Value cond,
2006 function_ref<
void(OpBuilder &, Location)> elseBuilder) {
2007 assert(thenBuilder &&
"the builder callback for 'then' must be present");
2008 result.addOperands(cond);
2011 OpBuilder::InsertionGuard guard(builder);
2012 Region *thenRegion =
result.addRegion();
2014 thenBuilder(builder,
result.location);
2017 Region *elseRegion =
result.addRegion();
2020 elseBuilder(builder,
result.location);
2024 SmallVector<Type> inferredReturnTypes;
2026 auto attrDict = DictionaryAttr::get(ctx,
result.attributes);
2027 if (succeeded(inferReturnTypes(ctx, std::nullopt,
result.operands, attrDict,
2029 inferredReturnTypes))) {
2030 result.addTypes(inferredReturnTypes);
2034LogicalResult IfOp::verify() {
2035 if (getNumResults() != 0 && getElseRegion().empty())
2036 return emitOpError(
"must have an else block if defining values");
2040ParseResult IfOp::parse(OpAsmParser &parser, OperationState &
result) {
2042 result.regions.reserve(2);
2043 Region *thenRegion =
result.addRegion();
2044 Region *elseRegion =
result.addRegion();
2047 OpAsmParser::UnresolvedOperand cond;
2073void IfOp::print(OpAsmPrinter &p) {
2074 bool printBlockTerminators =
false;
2076 p <<
" " << getCondition();
2077 if (!getResults().empty()) {
2078 p <<
" -> (" << getResultTypes() <<
")";
2080 printBlockTerminators =
true;
2085 printBlockTerminators);
2088 auto &elseRegion = getElseRegion();
2089 if (!elseRegion.
empty()) {
2093 printBlockTerminators);
2099void IfOp::getSuccessorRegions(RegionBranchPoint point,
2100 SmallVectorImpl<RegionSuccessor> ®ions) {
2108 regions.push_back(RegionSuccessor(&getThenRegion()));
2111 Region *elseRegion = &this->getElseRegion();
2112 if (elseRegion->
empty())
2115 regions.push_back(RegionSuccessor(elseRegion));
2118ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
2123void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2124 SmallVectorImpl<RegionSuccessor> ®ions) {
2125 FoldAdaptor adaptor(operands, *
this);
2126 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2127 if (!boolAttr || boolAttr.getValue())
2128 regions.emplace_back(&getThenRegion());
2131 if (!boolAttr || !boolAttr.getValue()) {
2132 if (!getElseRegion().empty())
2133 regions.emplace_back(&getElseRegion());
2139LogicalResult IfOp::fold(FoldAdaptor adaptor,
2140 SmallVectorImpl<OpFoldResult> &results) {
2142 if (getElseRegion().empty())
2145 arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2152 getConditionMutable().assign(xorStmt.getLhs());
2153 Block *thenBlock = &getThenRegion().front();
2156 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2157 getElseRegion().getBlocks());
2158 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2159 getThenRegion().getBlocks(), thenBlock);
2163void IfOp::getRegionInvocationBounds(
2164 ArrayRef<Attribute> operands,
2165 SmallVectorImpl<InvocationBounds> &invocationBounds) {
2166 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2169 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2170 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2173 invocationBounds.assign(2, {0, 1});
2180struct ConvertTrivialIfToSelect :
public OpRewritePattern<IfOp> {
2181 using OpRewritePattern<IfOp>::OpRewritePattern;
2183 LogicalResult matchAndRewrite(IfOp op,
2184 PatternRewriter &rewriter)
const override {
2185 if (op->getNumResults() == 0)
2188 auto cond = op.getCondition();
2189 auto thenYieldArgs = op.thenYield().getOperands();
2190 auto elseYieldArgs = op.elseYield().getOperands();
2192 SmallVector<Type> nonHoistable;
2193 for (
auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2194 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2195 &op.getElseRegion() == falseVal.getParentRegion())
2196 nonHoistable.push_back(trueVal.getType());
2200 if (nonHoistable.size() == op->getNumResults())
2203 IfOp
replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2207 replacement.getThenRegion().takeBody(op.getThenRegion());
2208 replacement.getElseRegion().takeBody(op.getElseRegion());
2210 SmallVector<Value> results(op->getNumResults());
2211 assert(thenYieldArgs.size() == results.size());
2212 assert(elseYieldArgs.size() == results.size());
2214 SmallVector<Value> trueYields;
2215 SmallVector<Value> falseYields;
2217 for (
const auto &it :
2218 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2219 Value trueVal = std::get<0>(it.value());
2220 Value falseVal = std::get<1>(it.value());
2223 results[it.index()] =
replacement.getResult(trueYields.size());
2224 trueYields.push_back(trueVal);
2225 falseYields.push_back(falseVal);
2226 }
else if (trueVal == falseVal)
2227 results[it.index()] = trueVal;
2229 results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2230 cond, trueVal, falseVal);
2257struct ConditionPropagation :
public OpRewritePattern<IfOp> {
2258 using OpRewritePattern<IfOp>::OpRewritePattern;
2261 enum class Parent { Then, Else,
None };
2266 static Parent getParentType(Region *toCheck, IfOp op,
2268 Region *endRegion) {
2269 SmallVector<Region *> seen;
2270 while (toCheck != endRegion) {
2271 auto found = cache.find(toCheck);
2272 if (found != cache.end())
2273 return found->second;
2274 seen.push_back(toCheck);
2275 if (&op.getThenRegion() == toCheck) {
2276 for (Region *region : seen)
2277 cache[region] = Parent::Then;
2278 return Parent::Then;
2280 if (&op.getElseRegion() == toCheck) {
2281 for (Region *region : seen)
2282 cache[region] = Parent::Else;
2283 return Parent::Else;
2288 for (Region *region : seen)
2289 cache[region] = Parent::None;
2290 return Parent::None;
2293 LogicalResult matchAndRewrite(IfOp op,
2294 PatternRewriter &rewriter)
const override {
2300 bool changed =
false;
2305 Value constantTrue =
nullptr;
2306 Value constantFalse =
nullptr;
2309 for (OpOperand &use :
2310 llvm::make_early_inc_range(op.getCondition().getUses())) {
2313 case Parent::Then: {
2317 constantTrue = arith::ConstantOp::create(
2321 [&]() { use.set(constantTrue); });
2324 case Parent::Else: {
2328 constantFalse = arith::ConstantOp::create(
2332 [&]() { use.set(constantFalse); });
2380struct ReplaceIfYieldWithConditionOrValue :
public OpRewritePattern<IfOp> {
2381 using OpRewritePattern<IfOp>::OpRewritePattern;
2383 LogicalResult matchAndRewrite(IfOp op,
2384 PatternRewriter &rewriter)
const override {
2386 if (op.getNumResults() == 0)
2390 cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2392 cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2395 op.getOperation()->getIterator());
2396 bool changed =
false;
2398 for (
auto [trueResult, falseResult, opResult] :
2399 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2401 if (trueResult == falseResult) {
2402 if (!opResult.use_empty()) {
2403 opResult.replaceAllUsesWith(trueResult);
2409 BoolAttr trueYield, falseYield;
2414 bool trueVal = trueYield.
getValue();
2415 bool falseVal = falseYield.
getValue();
2416 if (!trueVal && falseVal) {
2417 if (!opResult.use_empty()) {
2418 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2419 Value notCond = arith::XOrIOp::create(
2420 rewriter, op.getLoc(), op.getCondition(),
2426 opResult.replaceAllUsesWith(notCond);
2430 if (trueVal && !falseVal) {
2431 if (!opResult.use_empty()) {
2432 opResult.replaceAllUsesWith(op.getCondition());
2462struct CombineIfs :
public OpRewritePattern<IfOp> {
2463 using OpRewritePattern<IfOp>::OpRewritePattern;
2465 LogicalResult matchAndRewrite(IfOp nextIf,
2466 PatternRewriter &rewriter)
const override {
2467 Block *parent = nextIf->getBlock();
2468 if (nextIf == &parent->
front())
2471 auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2479 Block *nextThen =
nullptr;
2480 Block *nextElse =
nullptr;
2481 if (nextIf.getCondition() == prevIf.getCondition()) {
2482 nextThen = nextIf.thenBlock();
2483 if (!nextIf.getElseRegion().empty())
2484 nextElse = nextIf.elseBlock();
2486 if (arith::XOrIOp notv =
2487 nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2488 if (notv.getLhs() == prevIf.getCondition() &&
2490 nextElse = nextIf.thenBlock();
2491 if (!nextIf.getElseRegion().empty())
2492 nextThen = nextIf.elseBlock();
2495 if (arith::XOrIOp notv =
2496 prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2497 if (notv.getLhs() == nextIf.getCondition() &&
2499 nextElse = nextIf.thenBlock();
2500 if (!nextIf.getElseRegion().empty())
2501 nextThen = nextIf.elseBlock();
2505 if (!nextThen && !nextElse)
2508 SmallVector<Value> prevElseYielded;
2509 if (!prevIf.getElseRegion().empty())
2510 prevElseYielded = prevIf.elseYield().getOperands();
2513 for (
auto it : llvm::zip(prevIf.getResults(),
2514 prevIf.thenYield().getOperands(), prevElseYielded))
2515 for (OpOperand &use :
2516 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2520 use.
set(std::get<1>(it));
2525 use.
set(std::get<2>(it));
2530 SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2531 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2533 IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2534 prevIf.getCondition(),
false);
2535 rewriter.
eraseBlock(&combinedIf.getThenRegion().back());
2538 combinedIf.getThenRegion(),
2539 combinedIf.getThenRegion().begin());
2542 YieldOp thenYield = combinedIf.thenYield();
2543 YieldOp thenYield2 = cast<YieldOp>(nextThen->
getTerminator());
2544 rewriter.
mergeBlocks(nextThen, combinedIf.thenBlock());
2547 SmallVector<Value> mergedYields(thenYield.getOperands());
2548 llvm::append_range(mergedYields, thenYield2.getOperands());
2549 YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2555 combinedIf.getElseRegion(),
2556 combinedIf.getElseRegion().begin());
2559 if (combinedIf.getElseRegion().empty()) {
2561 combinedIf.getElseRegion(),
2562 combinedIf.getElseRegion().
begin());
2564 YieldOp elseYield = combinedIf.elseYield();
2565 YieldOp elseYield2 = cast<YieldOp>(nextElse->
getTerminator());
2566 rewriter.
mergeBlocks(nextElse, combinedIf.elseBlock());
2570 SmallVector<Value> mergedElseYields(elseYield.getOperands());
2571 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2573 YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2579 SmallVector<Value> prevValues;
2580 SmallVector<Value> nextValues;
2581 for (
const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2582 if (pair.index() < prevIf.getNumResults())
2583 prevValues.push_back(pair.value());
2585 nextValues.push_back(pair.value());
2594struct RemoveEmptyElseBranch :
public OpRewritePattern<IfOp> {
2595 using OpRewritePattern<IfOp>::OpRewritePattern;
2597 LogicalResult matchAndRewrite(IfOp ifOp,
2598 PatternRewriter &rewriter)
const override {
2600 if (ifOp.getNumResults())
2602 Block *elseBlock = ifOp.elseBlock();
2603 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2607 newIfOp.getThenRegion().begin());
2629struct CombineNestedIfs :
public OpRewritePattern<IfOp> {
2630 using OpRewritePattern<IfOp>::OpRewritePattern;
2632 LogicalResult matchAndRewrite(IfOp op,
2633 PatternRewriter &rewriter)
const override {
2634 auto nestedOps = op.thenBlock()->without_terminator();
2636 if (!llvm::hasSingleElement(nestedOps))
2640 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2643 auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2647 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2650 SmallVector<Value> thenYield(op.thenYield().getOperands());
2651 SmallVector<Value> elseYield;
2653 llvm::append_range(elseYield, op.elseYield().getOperands());
2657 SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2666 for (
const auto &tup : llvm::enumerate(thenYield)) {
2667 if (tup.value().getDefiningOp() == nestedIf) {
2668 auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2669 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2670 elseYield[tup.index()]) {
2675 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2688 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2691 elseYieldsToUpgradeToSelect.push_back(tup.index());
2694 Location loc = op.getLoc();
2695 Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2696 nestedIf.getCondition());
2697 auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2700 SmallVector<Value> results;
2701 llvm::append_range(results, newIf.getResults());
2704 for (
auto idx : elseYieldsToUpgradeToSelect)
2706 arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2707 thenYield[idx], elseYield[idx]);
2709 rewriter.
mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2712 if (!elseYield.empty()) {
2715 YieldOp::create(rewriter, loc, elseYield);
2724void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2725 MLIRContext *context) {
2726 results.
add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2727 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2728 ReplaceIfYieldWithConditionOrValue>(context);
2730 results, IfOp::getOperationName());
2732 IfOp::getOperationName());
2735Block *IfOp::thenBlock() {
return &getThenRegion().back(); }
2736YieldOp IfOp::thenYield() {
return cast<YieldOp>(&thenBlock()->back()); }
2737Block *IfOp::elseBlock() {
2738 Region &r = getElseRegion();
2743YieldOp IfOp::elseYield() {
return cast<YieldOp>(&elseBlock()->back()); }
2749void ParallelOp::build(
2754 result.addOperands(lowerBounds);
2755 result.addOperands(upperBounds);
2756 result.addOperands(steps);
2757 result.addOperands(initVals);
2759 ParallelOp::getOperandSegmentSizeAttr(),
2761 static_cast<int32_t>(upperBounds.size()),
2762 static_cast<int32_t>(steps.size()),
2763 static_cast<int32_t>(initVals.size())}));
2766 OpBuilder::InsertionGuard guard(builder);
2767 unsigned numIVs = steps.size();
2768 SmallVector<Type, 8> argTypes(numIVs, builder.
getIndexType());
2769 SmallVector<Location, 8> argLocs(numIVs,
result.location);
2770 Region *bodyRegion =
result.addRegion();
2773 if (bodyBuilderFn) {
2775 bodyBuilderFn(builder,
result.location,
2780 if (initVals.empty())
2781 ParallelOp::ensureTerminator(*bodyRegion, builder,
result.location);
2784void ParallelOp::build(
2791 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2794 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2798 wrapper = wrappedBuilderFn;
2804LogicalResult ParallelOp::verify() {
2809 if (stepValues.empty())
2811 "needs at least one tuple element for lowerBound, upperBound and step");
2814 for (Value stepValue : stepValues)
2817 return emitOpError(
"constant step operand must be positive");
2821 Block *body = getBody();
2823 return emitOpError() <<
"expects the same number of induction variables: "
2825 <<
" as bound and step values: " << stepValues.size();
2827 if (!arg.getType().isIndex())
2829 "expects arguments for the induction variable to be of index type");
2833 *
this, getRegion(),
"expects body to terminate with 'scf.reduce'");
2838 auto resultsSize = getResults().size();
2839 auto reductionsSize = reduceOp.getReductions().size();
2840 auto initValsSize = getInitVals().size();
2841 if (resultsSize != reductionsSize)
2842 return emitOpError() <<
"expects number of results: " << resultsSize
2843 <<
" to be the same as number of reductions: "
2845 if (resultsSize != initValsSize)
2846 return emitOpError() <<
"expects number of results: " << resultsSize
2847 <<
" to be the same as number of initial values: "
2849 if (reduceOp.getNumOperands() != initValsSize)
2854 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2855 auto resultType = getOperation()->getResult(i).getType();
2856 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2857 if (resultType != reductionOperandType)
2858 return reduceOp.emitOpError()
2859 <<
"expects type of " << i
2860 <<
"-th reduction operand: " << reductionOperandType
2861 <<
" to be the same as the " << i
2862 <<
"-th result type: " << resultType;
2867ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &
result) {
2870 SmallVector<OpAsmParser::Argument, 4> ivs;
2875 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2882 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2890 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2898 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2909 Region *body =
result.addRegion();
2910 for (
auto &iv : ivs)
2917 ParallelOp::getOperandSegmentSizeAttr(),
2919 static_cast<int32_t>(upper.size()),
2920 static_cast<int32_t>(steps.size()),
2921 static_cast<int32_t>(initVals.size())}));
2930 ParallelOp::ensureTerminator(*body, builder,
result.location);
2934void ParallelOp::print(OpAsmPrinter &p) {
2935 p <<
" (" << getBody()->getArguments() <<
") = (" <<
getLowerBound()
2936 <<
") to (" <<
getUpperBound() <<
") step (" << getStep() <<
")";
2937 if (!getInitVals().empty())
2938 p <<
" init (" << getInitVals() <<
")";
2943 (*this)->getAttrs(),
2944 ParallelOp::getOperandSegmentSizeAttr());
2947SmallVector<Region *> ParallelOp::getLoopRegions() {
return {&getRegion()}; }
2949std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
2950 return SmallVector<Value>{getBody()->getArguments()};
2953std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
2957std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
2961std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
2966 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2968 return ParallelOp();
2969 assert(ivArg.getOwner() &&
"unlinked block argument");
2970 auto *containingOp = ivArg.getOwner()->getParentOp();
2971 return dyn_cast<ParallelOp>(containingOp);
2976struct ParallelOpSingleOrZeroIterationDimsFolder
2980 LogicalResult matchAndRewrite(ParallelOp op,
2987 for (
auto [lb,
ub, step, iv] :
2988 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2989 op.getInductionVars())) {
2990 auto numIterations =
2992 if (numIterations.has_value()) {
2994 if (*numIterations == 0) {
2995 rewriter.
replaceOp(op, op.getInitVals());
3000 if (*numIterations == 1) {
3005 newLowerBounds.push_back(lb);
3006 newUpperBounds.push_back(ub);
3007 newSteps.push_back(step);
3010 if (newLowerBounds.size() == op.getLowerBound().size())
3013 if (newLowerBounds.empty()) {
3016 SmallVector<Value> results;
3017 results.reserve(op.getInitVals().size());
3018 for (
auto &bodyOp : op.getBody()->without_terminator())
3019 rewriter.
clone(bodyOp, mapping);
3020 auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3021 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3022 Block &reduceBlock = reduceOp.getReductions()[i].front();
3023 auto initValIndex = results.size();
3024 mapping.
map(reduceBlock.
getArgument(0), op.getInitVals()[initValIndex]);
3028 rewriter.
clone(reduceBodyOp, mapping);
3031 cast<ReduceReturnOp>(reduceBlock.
getTerminator()).getResult());
3032 results.push_back(
result);
3040 ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3041 newUpperBounds, newSteps, op.getInitVals(),
nullptr);
3047 newOp.getRegion().begin(), mapping);
3048 rewriter.
replaceOp(op, newOp.getResults());
3053struct MergeNestedParallelLoops :
public OpRewritePattern<ParallelOp> {
3054 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3056 LogicalResult matchAndRewrite(ParallelOp op,
3057 PatternRewriter &rewriter)
const override {
3058 Block &outerBody = *op.getBody();
3062 auto innerOp = dyn_cast<ParallelOp>(outerBody.
front());
3067 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3068 llvm::is_contained(innerOp.getUpperBound(), val) ||
3069 llvm::is_contained(innerOp.getStep(), val))
3073 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3076 auto bodyBuilder = [&](OpBuilder &builder, Location ,
3078 Block &innerBody = *innerOp.getBody();
3079 assert(iterVals.size() ==
3087 builder.
clone(op, mapping);
3090 auto concatValues = [](
const auto &first,
const auto &second) {
3091 SmallVector<Value> ret;
3092 ret.reserve(first.size() + second.size());
3093 ret.assign(first.begin(), first.end());
3094 ret.append(second.begin(), second.end());
3098 auto newLowerBounds =
3099 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3100 auto newUpperBounds =
3101 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3102 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3113void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3114 MLIRContext *context) {
3116 .
add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3125void ParallelOp::getSuccessorRegions(
3126 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
3130 regions.push_back(RegionSuccessor(&getRegion()));
3138void ReduceOp::build(OpBuilder &builder, OperationState &
result) {}
3140void ReduceOp::build(OpBuilder &builder, OperationState &
result,
3142 result.addOperands(operands);
3143 for (Value v : operands) {
3144 OpBuilder::InsertionGuard guard(builder);
3145 Region *bodyRegion =
result.addRegion();
3147 ArrayRef<Type>{v.getType(), v.getType()},
3152LogicalResult ReduceOp::verifyRegions() {
3153 if (getReductions().size() != getOperands().size())
3154 return emitOpError() <<
"expects number of reduction regions: "
3155 << getReductions().size()
3156 <<
" to be the same as number of reduction operands: "
3157 << getOperands().size();
3160 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3161 auto type = getOperands()[i].getType();
3162 Block &block = getReductions()[i].front();
3164 return emitOpError() << i <<
"-th reduction has an empty body";
3166 llvm::any_of(block.
getArguments(), [&](
const BlockArgument &arg) {
3167 return arg.getType() != type;
3169 return emitOpError() <<
"expected two block arguments with type " << type
3170 <<
" in the " << i <<
"-th reduction region";
3174 return emitOpError(
"reduction bodies must be terminated with an "
3175 "'scf.reduce.return' op");
3182ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
3184 return MutableOperandRange(getOperation(), 0, 0);
3191LogicalResult ReduceReturnOp::verify() {
3194 Block *reductionBody = getOperation()->getBlock();
3196 assert(isa<ReduceOp>(reductionBody->
getParentOp()) &&
"expected scf.reduce");
3198 if (expectedResultType != getResult().
getType())
3199 return emitOpError() <<
"must have type " << expectedResultType
3200 <<
" (the type of the reduction inputs)";
3208void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3209 ::mlir::OperationState &odsState,
TypeRange resultTypes,
3210 ValueRange inits, BodyBuilderFn beforeBuilder,
3211 BodyBuilderFn afterBuilder) {
3215 OpBuilder::InsertionGuard guard(odsBuilder);
3218 SmallVector<Location, 4> beforeArgLocs;
3219 beforeArgLocs.reserve(inits.size());
3220 for (Value operand : inits) {
3221 beforeArgLocs.push_back(operand.getLoc());
3224 Region *beforeRegion = odsState.
addRegion();
3226 inits.getTypes(), beforeArgLocs);
3231 SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.
location);
3233 Region *afterRegion = odsState.
addRegion();
3235 resultTypes, afterArgLocs);
3241ConditionOp WhileOp::getConditionOp() {
3242 return cast<ConditionOp>(getBeforeBody()->getTerminator());
3245YieldOp WhileOp::getYieldOp() {
3246 return cast<YieldOp>(getAfterBody()->getTerminator());
3249std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3250 return getYieldOp().getResultsMutable();
3254 return getBeforeBody()->getArguments();
3258 return getAfterBody()->getArguments();
3262 return getBeforeArguments();
3265OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
3267 "WhileOp is expected to branch only to the first region");
3271void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3272 SmallVectorImpl<RegionSuccessor> ®ions) {
3275 regions.emplace_back(&getBefore());
3279 assert(llvm::is_contained(
3280 {&getAfter(), &getBefore()},
3282 "there are only two regions in a WhileOp");
3286 regions.emplace_back(&getBefore());
3291 regions.emplace_back(&getAfter());
3294ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
3296 return getOperation()->getResults();
3297 if (successor == &getBefore())
3298 return getBefore().getArguments();
3299 if (successor == &getAfter())
3300 return getAfter().getArguments();
3301 llvm_unreachable(
"invalid region successor");
3304SmallVector<Region *> WhileOp::getLoopRegions() {
3305 return {&getBefore(), &getAfter()};
3315ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &
result) {
3316 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3317 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3318 Region *before =
result.addRegion();
3319 Region *after =
result.addRegion();
3321 OptionalParseResult listResult =
3326 FunctionType functionType;
3331 result.addTypes(functionType.getResults());
3333 if (functionType.getNumInputs() != operands.size()) {
3335 <<
"expected as many input types as operands " <<
"(expected "
3336 << operands.size() <<
" got " << functionType.getNumInputs() <<
")";
3346 for (
size_t i = 0, e = regionArgs.size(); i != e; ++i)
3347 regionArgs[i].type = functionType.getInput(i);
3349 return failure(parser.
parseRegion(*before, regionArgs) ||
3355void scf::WhileOp::print(OpAsmPrinter &p) {
3369template <
typename OpTy>
3372 if (left.size() != right.size())
3373 return op.emitOpError(
"expects the same number of ") << message;
3375 for (
unsigned i = 0, e = left.size(); i < e; ++i) {
3376 if (left[i] != right[i]) {
3379 diag.attachNote() <<
"for argument " << i <<
", found " << left[i]
3380 <<
" and " << right[i];
3388LogicalResult scf::WhileOp::verify() {
3391 "expects the 'before' region to terminate with 'scf.condition'");
3392 if (!beforeTerminator)
3397 "expects the 'after' region to terminate with 'scf.yield'");
3398 return success(afterTerminator !=
nullptr);
3435struct WhileMoveIfDown :
public OpRewritePattern<scf::WhileOp> {
3436 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3438 LogicalResult matchAndRewrite(scf::WhileOp op,
3439 PatternRewriter &rewriter)
const override {
3440 auto conditionOp = op.getConditionOp();
3448 auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3454 if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3455 (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3458 assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3459 *ifOp->user_begin() == conditionOp)) &&
3460 "ifOp has unexpected uses");
3462 Location loc = op.getLoc();
3466 for (
auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3467 auto it = llvm::find(ifOp->getResults(), arg);
3468 if (it != ifOp->getResults().end()) {
3469 size_t ifOpIdx = it.getIndex();
3470 Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3471 Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3481 if (&op.getBefore() == operand->get().getParentRegion())
3482 additionalUsedValuesSet.insert(operand->get());
3486 auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3487 auto additionalValueTypes = llvm::map_to_vector(
3488 additionalUsedValues, [](Value val) {
return val.
getType(); });
3489 size_t additionalValueSize = additionalUsedValues.size();
3490 SmallVector<Type> newResultTypes(op.getResultTypes());
3491 newResultTypes.append(additionalValueTypes);
3494 scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3497 newWhileOp.getBefore().takeBody(op.getBefore());
3498 newWhileOp.getAfter().takeBody(op.getAfter());
3499 newWhileOp.getAfter().addArguments(
3500 additionalValueTypes,
3501 SmallVector<Location>(additionalValueSize, loc));
3505 conditionOp.getArgsMutable().append(additionalUsedValues);
3511 additionalUsedValues,
3512 newWhileOp.getAfterArguments().take_back(additionalValueSize),
3513 [&](OpOperand &use) {
3514 return ifOp.getThenRegion().isAncestor(
3515 use.getOwner()->getParentRegion());
3519 rewriter.
eraseOp(ifOp.thenYield());
3521 newWhileOp.getAfterBody()->begin());
3524 newWhileOp->getResults().drop_back(additionalValueSize));
3548struct WhileConditionTruth :
public OpRewritePattern<WhileOp> {
3549 using OpRewritePattern<WhileOp>::OpRewritePattern;
3551 LogicalResult matchAndRewrite(WhileOp op,
3552 PatternRewriter &rewriter)
const override {
3553 auto term = op.getConditionOp();
3557 Value constantTrue =
nullptr;
3559 bool replaced =
false;
3560 for (
auto yieldedAndBlockArgs :
3561 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3562 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3563 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3565 constantTrue = arith::ConstantOp::create(
3566 rewriter, op.getLoc(), term.getCondition().getType(),
3601struct WhileCmpCond :
public OpRewritePattern<scf::WhileOp> {
3602 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3604 LogicalResult matchAndRewrite(scf::WhileOp op,
3605 PatternRewriter &rewriter)
const override {
3606 using namespace scf;
3607 auto cond = op.getConditionOp();
3608 auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3611 bool changed =
false;
3612 for (
auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3613 for (
size_t opIdx = 0; opIdx < 2; opIdx++) {
3614 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3617 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3618 auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3622 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3625 if (cmp2.getPredicate() == cmp.getPredicate())
3626 samePredicate =
true;
3627 else if (cmp2.getPredicate() ==
3628 arith::invertPredicate(cmp.getPredicate()))
3629 samePredicate =
false;
3645static std::optional<SmallVector<unsigned>> getArgsMapping(
ValueRange args1,
3647 if (args1.size() != args2.size())
3648 return std::nullopt;
3650 SmallVector<unsigned> ret(args1.size());
3651 for (
auto &&[i, arg1] : llvm::enumerate(args1)) {
3652 auto it = llvm::find(args2, arg1);
3653 if (it == args2.end())
3654 return std::nullopt;
3656 ret[std::distance(args2.begin(), it)] =
static_cast<unsigned>(i);
3663 llvm::SmallDenseSet<Value> set;
3664 for (Value arg : args) {
3665 if (!set.insert(arg).second)
3675struct WhileOpAlignBeforeArgs :
public OpRewritePattern<WhileOp> {
3678 LogicalResult matchAndRewrite(WhileOp loop,
3679 PatternRewriter &rewriter)
const override {
3680 auto *oldBefore = loop.getBeforeBody();
3681 ConditionOp oldTerm = loop.getConditionOp();
3682 ValueRange beforeArgs = oldBefore->getArguments();
3684 if (beforeArgs == termArgs)
3687 if (hasDuplicates(termArgs))
3690 auto mapping = getArgsMapping(beforeArgs, termArgs);
3695 OpBuilder::InsertionGuard g(rewriter);
3701 auto *oldAfter = loop.getAfterBody();
3703 SmallVector<Type> newResultTypes(beforeArgs.size());
3704 for (
auto &&[i, j] : llvm::enumerate(*mapping))
3705 newResultTypes[j] = loop.getResult(i).getType();
3707 auto newLoop = WhileOp::create(
3708 rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
3710 auto *newBefore = newLoop.getBeforeBody();
3711 auto *newAfter = newLoop.getAfterBody();
3713 SmallVector<Value> newResults(beforeArgs.size());
3714 SmallVector<Value> newAfterArgs(beforeArgs.size());
3715 for (
auto &&[i, j] : llvm::enumerate(*mapping)) {
3716 newResults[i] = newLoop.getResult(j);
3717 newAfterArgs[i] = newAfter->getArgument(j);
3721 newBefore->getArguments());
3731void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3732 MLIRContext *context) {
3733 results.
add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
3734 WhileMoveIfDown>(context);
3736 results, WhileOp::getOperationName());
3738 WhileOp::getOperationName());
3752 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
3755 caseValues.push_back(value);
3764 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
3766 p <<
"case " << value <<
' ';
3771LogicalResult scf::IndexSwitchOp::verify() {
3772 if (getCases().size() != getCaseRegions().size()) {
3774 << getCaseRegions().size() <<
" case regions but "
3775 << getCases().size() <<
" case values";
3779 for (int64_t value : getCases())
3780 if (!valueSet.insert(value).second)
3781 return emitOpError(
"has duplicate case value: ") << value;
3782 auto verifyRegion = [&](Region ®ion,
const Twine &name) -> LogicalResult {
3783 auto yield = dyn_cast<YieldOp>(region.
front().
back());
3785 return emitOpError(
"expected region to end with scf.yield, but got ")
3788 if (yield.getNumOperands() != getNumResults()) {
3789 return (
emitOpError(
"expected each region to return ")
3790 << getNumResults() <<
" values, but " << name <<
" returns "
3791 << yield.getNumOperands())
3792 .attachNote(yield.getLoc())
3793 <<
"see yield operation here";
3795 for (
auto [idx,
result, operand] :
3796 llvm::enumerate(getResultTypes(), yield.getOperands())) {
3798 return yield.emitOpError() <<
"operand " << idx <<
" is null\n";
3799 if (
result == operand.getType())
3802 << idx <<
" of each region to be " <<
result)
3803 .attachNote(yield.getLoc())
3804 << name <<
" returns " << operand.getType() <<
" here";
3811 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
3818unsigned scf::IndexSwitchOp::getNumCases() {
return getCases().size(); }
3820Block &scf::IndexSwitchOp::getDefaultBlock() {
3821 return getDefaultRegion().front();
3824Block &scf::IndexSwitchOp::getCaseBlock(
unsigned idx) {
3825 assert(idx < getNumCases() &&
"case index out-of-bounds");
3826 return getCaseRegions()[idx].front();
3829void IndexSwitchOp::getSuccessorRegions(
3830 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
3837 llvm::append_range(successors, getRegions());
3840ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
3845void IndexSwitchOp::getEntrySuccessorRegions(
3846 ArrayRef<Attribute> operands,
3847 SmallVectorImpl<RegionSuccessor> &successors) {
3848 FoldAdaptor adaptor(operands, *
this);
3851 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3853 llvm::append_range(successors, getRegions());
3859 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3860 if (caseValue == arg.getInt()) {
3861 successors.emplace_back(&caseRegion);
3865 successors.emplace_back(&getDefaultRegion());
3868void IndexSwitchOp::getRegionInvocationBounds(
3869 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3870 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
3871 if (!operandValue) {
3873 bounds.append(getNumRegions(), InvocationBounds(0, 1));
3877 unsigned liveIndex = getNumRegions() - 1;
3878 const auto *it = llvm::find(getCases(), operandValue.getInt());
3879 if (it != getCases().end())
3880 liveIndex = std::distance(getCases().begin(), it);
3881 for (
unsigned i = 0, e = getNumRegions(); i < e; ++i)
3882 bounds.emplace_back(0, i == liveIndex);
3885void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3886 MLIRContext *context) {
3888 results, IndexSwitchOp::getOperationName());
3890 results, IndexSwitchOp::getOperationName());
3897#define GET_OP_CLASSES
3898#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
bool getValue() const
Return the boolean value of this attribute.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
static Value defaultReplBuilderFn(OpBuilder &builder, Location loc, Value value)
Default implementation of the non-successor-input replacement builder function.
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
llvm::function_ref< Value(OpBuilder &, Location loc, Type, Value)> ValueTypeCastFnTy
Perform a replacement of one iter OpOperand of an scf.for to the replacement value with a different t...
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of an thread index variable.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
void populateRegionBranchOpInterfaceInliningPattern(RewritePatternSet &patterns, StringRef opName, NonSuccessorInputReplacementBuilderFn replBuilderFn=detail::defaultReplBuilderFn, PatternMatcherFn matcherFn=detail::defaultMatcherFn, PatternBenefit benefit=1)
Populate a pattern that inlines the body of region branch ops when there is a single acyclic path thr...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
void populateRegionBranchOpInterfaceCanonicalizationPatterns(RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit=1)
Populate canonicalization patterns that simplify successor operands/inputs of region branch operation...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.