#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"

#define DEBUG_TYPE "linalg-data-layout-propagation"
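
/// Returns true if the `genericOp` has gather semantics, i.e. its body reads
/// data-dependent locations through tensor.extract or linalg.index.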
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
  // Tiled dims on the iteration domain, in the order used by the pack op.
  SmallVector<int64_t> tiledDimsPos;
  // Tile size for each tiled dimension of the iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // Map from a tiled domain dimension to its inner (point) loop dimension.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims, expressed on the iteration domain.
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
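
/// Collects the packing information for `opOperand` of `genericOp` from the
/// adjacent pack/unpack op: which domain dims are tiled, their tile sizes,
/// and the point-loop dim appended for each tile.
///
/// Illustrative (hypothetical) example: with an operand indexing map
/// (d0, d1) -> (d0, d1) and a pack with inner_dims_pos = [0] and
/// inner_tiles = [8], domain dim d0 is tiled by 8 and its point loop becomes
/// d2 (origNumDims + 0).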
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }
  // Bail out if a tiled dim appears in any map through a non-dim expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : genericOp.getIndexingMapsArray()) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();
  // Validate the outer dims permutation, mapping it back onto the domain.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }
    // Non-dim expressions are only allowed when they are not permuted.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outer dimsDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }
  return packInfo;
}
/// Maps a permutation given on the iteration domain (`perm`) to the
/// `outer_dims_perm` of the packed operand, using the operand's indexing
/// exprs to translate domain dims into operand dims.
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}
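
/// Returns the packed view of `opOperand` together with its indexing map on
/// the extended (original + point loops) domain; if the operand has no
/// dimensions related to the pack, it is returned unchanged.
///
/// Sketch (hypothetical shapes): packing d1 with tile 32 on an operand with
/// map (d0, d1) -> (d0, d1) yields map (d0, d1, d2) -> (d0, d1, d2) over
///   %packed = tensor.pack %operand inner_dims_pos = [1] inner_tiles = [32]
///       into %empty : tensor<?x?xf32> -> tensor<?x?x32xf32>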
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Collect the packed dimensions present on the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Fold the outer dims permutation into the indexing map.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Normalize the exprs following the outer dims permutation.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
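
/// Rebuilds `genericOp` in packed form: every DPS input gets its packed view,
/// `dest` becomes the init operand, the extra point loops are appended as
/// parallel iterators, and the original body region is cloned over.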
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}
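
/// Bubbles up a tensor.pack through its producer generic op, rewriting
/// pack(generic) into generic(pack). Illustrative sketch (details elided):
///
///   %0 = linalg.generic ins(%src) outs(%init) ...
///   %1 = tensor.pack %0 inner_dims_pos = [0] inner_tiles = [8] into %dest
///
/// becomes
///
///   %src_p = tensor.pack %src ...
///   %1 = linalg.generic ins(%src_p) outs(%dest) ...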
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();
  // User-controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();
  // TODO: Enable propagation in the presence of gather semantics.
  if (hasGatherSemantics(genericOp))
    return failure();
  // TODO: Relax the single-result restriction.
  if (genericOp.getNumResults() != 1)
    return failure();
  // Bail out on multiple uses, as bubbling up would introduce recomputation.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // If the pack destination is a tensor.empty, recreate it in front of the
  // generic; otherwise the destination must dominate the generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }
  // TODO: Support padding semantics.
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init of the generic is a tensor.empty, forward the pack dest.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
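
/// Propagates tensor.pack up through tensor.pad: the pad's source is packed
/// instead, the outer low/high padding is permuted by `outer_dims_perm` when
/// present, and zero padding is appended for the point loops. Only valid when
/// no tiled dimension is padded.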
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();
    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();
    if (packOp.getPaddingValue())
      return failure();
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();
    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // If there is an `outer_dims_perm`, permute the padded dimensions too.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions are unpadded, so the point loops get zero padding.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
        padOp.getNofold());

    // If the pad has other users, unpack the new pad to replace them.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      Value unpackedPad = rewriter.create<tensor::UnPackOp>(
          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    // Replace the pack with the new pad.
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
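
/// Projects each position in `dimsPos` through the reassociation onto the
/// inner-most non-unit dim of its group in `targetShape`.
///
/// Hypothetical example: dimsPos = [0], reassocIndices = [[0, 1, 2]], and
/// targetShape = [1, 4, 1] give [1], since dim 2 is a unit dim and dim 1 is
/// the inner-most non-unit dim of the group.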
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // If all dims are unit, this returns the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}
/// Checks if all dims in `dimsPos` are divisible by their tile sizes.
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}
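
/// Permutes the reassociation indices and renumbers them in sequence order,
/// returning the next unused position. For example, permutation [1, 0] turns
/// reassociations [[0, 1], [2]] into [[0], [1, 2]] and returns 3.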
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}
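
/// Bubbles up tensor.pack through tensor.collapse_shape when the packed dims
/// can be projected onto dims before collapsing. Illustrative sketch
/// (hypothetical shapes):
///
///   %c = tensor.collapse_shape %in [[0, 1], [2]]
///       : tensor<?x16x4xf32> into tensor<?x4xf32>
///   %p = tensor.pack %c inner_dims_pos = [0] inner_tiles = [8] into %empty
///
/// becomes a pack of %in on the projected dim (d1, of size 16) followed by a
/// collapse_shape of the packed tensor.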
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile positions to the dims before collapsing, preferring
  // the inner-most non-unit dims so the tile sizes can still divide them.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes))
    return failure();

  // Expand the outer dims permutation with the associated source dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // Apply the permutation on the reassociations of the outer dims, then add
  // direct mappings for the inner tile dims.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);
  return success();
}
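
/// Projects `dimsPos` onto the positions of their reassociation groups. For
/// example, dimsPos = [0, 2] with reassocIndices = [[0, 1], [2, 3]] gives
/// [0, 1].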
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;
  // Map each dimension to the position of its reassociation group.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
  return projectedPos;
}
/// Bubbles up tensor.pack through tensor.expand_shape, provided that in each
/// reassociation group at most the inner-most expanded dim is packed.
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 tensor::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }
  // Validate the relation between the expanded and the packed dimensions.
  SmallVector<ReassociationIndices> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());
  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);
    // An unpacked expanded dimension does not affect the reordering.
    if (packedDims.empty())
      continue;
    // Reordering is invalid if more than one expanded dimension is packed or
    // if the packed dimension is not the inner-most one.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion onto the new packed type.
  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);
  return success();
}
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supports the case where the pack op is the only user.
    if (!srcOp || srcOp->getNumResults() != 1 ||
        !srcOp->getResult(0).hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        }))
      return failure();
    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
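
/// Pushes tensor.unpack down through tensor.expand_shape: the expansion is
/// performed on the packed layout and the unpack is re-created on the
/// expanded tensor, with inner tile dims projected onto the inner-most
/// non-unit expanded dims.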
static LogicalResult pushDownUnPackOpThroughExpandShape(
    tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User-controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile positions to the dims after expanding, preferring the
  // inner-most non-unit dims.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes))
    return failure();

  // Expand the outer dims permutation with the associated expanded dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // Apply the permutation on the reassociations of the outer dims, then add
  // direct mappings for the inner tile dims.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);
  return success();
}
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only supports the case of a single user.
    if (!result.hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        }))
      return failure();

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
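
/// Returns the operand of `genericOp` that is produced by a tensor.unpack,
/// failing if there is none or more than one such operand.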
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}
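
/// Pushes a tensor.unpack down through a generic op, rewriting
/// generic(unpack) into unpack(generic). Illustrative sketch (details
/// elided):
///
///   %u = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8] into %d
///   %0 = linalg.generic ins(%u) outs(%init) ...
///
/// becomes
///
///   %0 = linalg.generic ins(%src) outs(%init_p) ...
///   %r = tensor.unpack %0 ...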
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();
  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unpack right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
                                    destPack.getSource(), innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}
/// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
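
/// Propagates tensor.unpack down through tensor.pad: the packed source is
/// padded (with zero padding appended for the point loops) and the result is
/// unpacked. Only valid when no tiled dimension is padded.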
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();
    // User-controlled propagation function.
    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If there is an `outer_dims_perm`, permute the padded dimensions too.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed pad.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}