20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
28 #include "mlir/Dialect/Linalg/Passes.h.inc"
34 #define DEBUG_TYPE "linalg-data-layout-propagation"
38 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
39 for (
Operation &op : genericOp.getBody()->getOperations())
40 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
48 int64_t getNumTiledLoops()
const {
return tileToPointMapping.size(); };
60 template <
typename OpTy>
61 static FailureOr<PackInfo>
62 getPackingInfoFromOperand(
OpOperand *opOperand, linalg::GenericOp genericOp,
63 OpTy packOrUnPackOp) {
64 static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
65 "applies to only pack or unpack operations");
67 { llvm::dbgs() <<
"--- Construct PackInfo From an operand ---\n"; });
69 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
72 genericOp.getIteratorTypesArray();
75 int64_t origNumDims = indexingMap.getNumDims();
78 for (
auto [index, innerDimPos, tileSize] :
79 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
80 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
81 auto expr = exprs[innerDimPos];
82 if (!isa<AffineDimExpr>(expr))
84 int64_t domainDimPos =
85 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
88 packInfo.tiledDimsPos.push_back(domainDimPos);
89 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
90 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
92 llvm::dbgs() <<
"map innerDimPos=" << innerDimPos
93 <<
" to iteration dimension (d" << domainDimPos <<
", d"
94 << packInfo.tileToPointMapping[domainDimPos]
95 <<
"), which has size=("
96 << packInfo.domainDimAndTileMapping[domainDimPos] <<
")\n";
102 auto areAllAffineDimExpr = [&](
int dim) {
104 if (llvm::any_of(map.getResults(), [dim](
AffineExpr expr) {
105 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
112 for (int64_t i : packInfo.tiledDimsPos)
113 if (!areAllAffineDimExpr(i))
130 for (
auto [index, dim] :
llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
131 auto permutedExpr = indexingMap.getResult(dim);
132 if (
auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
133 permutedOuterDims.push_back(dimExpr.getPosition());
140 if (
static_cast<int64_t
>(index) != dim)
143 if (!permutedOuterDims.empty()) {
144 int64_t outerDimIndex = 0;
146 permutedOuterDims.end());
147 for (
int i = 0, e = indexingMap.getNumDims(); i < e; i++)
148 packInfo.outerDimsOnDomainPerm.push_back(
149 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
152 llvm::dbgs() <<
"map outer dimsDimsPerm to ";
153 for (
auto dim : packInfo.outerDimsOnDomainPerm)
154 llvm::dbgs() << dim <<
" ";
155 llvm::dbgs() <<
"\n";
173 assert(!perm.empty() &&
"expect perm not to be empty");
174 assert(!exprs.empty() &&
"expect exprs not to be empty");
175 if (exprs.size() == 1)
183 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
184 currentPositionTileLoops[dimExpr.getPosition()] = pos;
186 currentPositionTileLoops[pos] = pos;
188 for (int64_t loopIdx : perm) {
189 if (currentPositionTileLoops.count(loopIdx))
190 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
192 return outerDimsPerm;
226 static std::tuple<Value, AffineMap>
228 GenericOp genericOp,
OpOperand *opOperand) {
229 int64_t numOrigLoops = genericOp.getNumLoops();
230 int64_t numInnerLoops = packInfo.getNumTiledLoops();
231 int64_t numLoops = numOrigLoops + numInnerLoops;
232 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
237 if (genericOp.isScalar(opOperand) || exprs.empty())
238 return std::make_tuple(opOperand->
get(),
244 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
245 int64_t dimPos = dimExpr.getPosition();
246 domainDimToOperandDim[dimPos] = index;
252 for (
auto dimPos : packInfo.tiledDimsPos) {
253 if (!domainDimToOperandDim.count(dimPos))
255 int64_t index = domainDimToOperandDim[dimPos];
256 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
257 innerDimsPos.push_back(index);
263 if (!packInfo.outerDimsOnDomainPerm.empty()) {
264 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
269 for (
auto i : llvm::seq<unsigned>(0, origIndexingMap.
getNumResults())) {
270 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
271 int64_t dimPos = dimExpr.getPosition();
275 assert(isa<AffineConstantExpr>(exprs[i]) &&
276 "Attempted to permute non-constant and non-affine dim expression");
280 if (!outerDimsPerm.empty()) {
282 for (
const auto &en :
enumerate(outerDimsPerm))
283 auxVec[en.index()] = exprs[en.value()];
290 if (innerDimsPos.empty() && outerDimsPerm.empty())
291 return std::make_tuple(opOperand->
get(), indexingMap);
293 auto empty = tensor::PackOp::createDestinationTensor(
294 b, loc, opOperand->
get(), innerTileSizes, innerDimsPos, outerDimsPerm);
295 auto packedOperand = b.
create<tensor::PackOp>(
296 loc, opOperand->
get(), empty, innerDimsPos, innerTileSizes,
297 std::nullopt, outerDimsPerm);
298 return std::make_tuple(packedOperand, indexingMap);
302 static GenericOp packGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
304 const PackInfo &packInfo) {
308 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
309 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
310 rewriter, loc, packInfo, genericOp, inputOperand);
311 inputOperands.push_back(packedOperand);
312 indexingMaps.push_back(packedIndexingMap);
315 int64_t numInnerLoops = packInfo.getNumTiledLoops();
317 genericOp.getIteratorTypesArray();
318 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
320 indexingMaps.push_back(packedOutIndexingMap);
322 auto newGenericOp = rewriter.
create<linalg::GenericOp>(
323 loc, dest.
getType(), inputOperands, dest, indexingMaps, iterTypes,
326 newGenericOp.getRegion().begin());
373 static FailureOr<GenericOp>
374 bubbleUpPackOpThroughGenericOp(
RewriterBase &rewriter, tensor::PackOp packOp,
376 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
381 if (!controlFn(&packOp.getSourceMutable()))
387 if (hasGatherSemantics(genericOp))
392 if (genericOp.getNumResults() != 1)
399 if (!genericOp->getResult(0).hasOneUse())
412 Value packOpDest = packOp.getDest();
415 if (
auto emptyOp = packOpDest.
getDefiningOp<tensor::EmptyOp>()) {
416 packOpDest = rewriter.
create<tensor::EmptyOp>(
417 genericOp->getLoc(), emptyOp.getMixedSizes(),
418 emptyOp.getType().getElementType());
429 if (packOp.getPaddingValue())
432 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
433 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
434 if (failed(packInfo))
438 auto [packedOutOperand, packedOutIndexingMap] =
439 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
440 genericOp, opOperand);
444 Value dest = packedOutOperand;
445 if (
auto initTensor = genericOp.getDpsInitOperand(0)
447 .getDefiningOp<tensor::EmptyOp>()) {
450 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
455 struct BubbleUpPackOpThroughGenericOpPattern
458 BubbleUpPackOpThroughGenericOpPattern(
MLIRContext *context,
462 LogicalResult matchAndRewrite(tensor::PackOp packOp,
465 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
466 if (failed(genericOp))
468 rewriter.
replaceOp(packOp, genericOp->getResults());
479 class BubbleUpPackThroughPadOp final :
public OpRewritePattern<tensor::PackOp> {
484 LogicalResult matchAndRewrite(tensor::PackOp packOp,
486 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
491 if (!controlFn(&packOp.getSourceMutable()))
495 if (packOp.getPaddingValue())
502 Value paddingVal = padOp.getConstantPaddingValue();
506 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
512 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
513 llvm::SmallBitVector innerDims(paddedDims.size());
514 for (int64_t dim : innerDimsPos)
516 if (paddedDims.anyCommon(innerDims))
525 auto empty = tensor::PackOp::createDestinationTensor(
526 rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
528 auto sourcePack = rewriter.
create<tensor::PackOp>(
529 loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
530 std::nullopt, outerDimsPerm);
535 if (!outerDimsPerm.empty()) {
536 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
541 size_t pointLoopsSize = innerDimsPos.size();
542 lowPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
543 highPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
545 auto newPadOp = rewriter.
create<tensor::PadOp>(
546 loc,
Type(), sourcePack, lowPad, highPad, paddingVal,
551 if (!padOp->hasOneUse()) {
552 auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
553 rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
554 Value unpackedPad = rewriter.
create<tensor::UnPackOp>(
555 loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
560 rewriter.
replaceOp(packOp, newPadOp.getResult());
583 for (
auto pos : dimsPos) {
585 int64_t projectedPos = reassocIndices[pos].back();
586 for (
auto i : llvm::reverse(reassocIndices[pos])) {
587 int64_t dim = targetShape[i];
588 if (dim > 1 || ShapedType::isDynamic(dim)) {
593 projectedDimsPos.push_back(projectedPos);
595 return projectedDimsPos;
602 for (
auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
603 int64_t dim = shape[pos];
604 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
616 static int64_t applyPermutationAndReindexReassoc(
619 if (!permutation.empty())
620 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
623 for (
auto &index : indices) {
651 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
652 tensor::PackOp packOp,
660 collapseOp.getReassociationIndices();
669 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
671 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
679 for (
auto outerPos : outerDimsPerm) {
680 newOuterDimsPerm.insert(newOuterDimsPerm.end(),
681 reassocIndices[outerPos].begin(),
682 reassocIndices[outerPos].end());
685 auto emptyOp = tensor::PackOp::createDestinationTensor(
686 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
687 projectedInnerDimsPos, newOuterDimsPerm);
688 auto newPackOp = rewriter.
create<tensor::PackOp>(
689 packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
690 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
697 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
699 for (
size_t i = 0; i < innerDimsPos.size(); ++i) {
700 newReassocIndices.push_back({nextPos});
704 auto newCollapseOp = rewriter.
create<tensor::CollapseShapeOp>(
705 collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
706 rewriter.
replaceOp(packOp, newCollapseOp);
723 for (
auto pos : dimsPos) {
728 if (llvm::any_of(indices,
729 [&](int64_t expandDim) {
return expandDim == pos; })) {
730 projectedPos.push_back(idx);
735 assert(projectedPos.size() == dimsPos.size() &&
"Invalid dim pos projection");
758 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
759 tensor::PackOp packOp,
766 "non-identity outer dims perm NYI");
771 expandOp.getReassociationIndices();
774 packInnerDims.end());
781 llvm::set_intersection(packDimsPos, expandDimPos);
785 if (packedDims.empty())
790 if (packedDims.size() != 1)
792 packOp,
"only one of the expanded dimensions can be packed");
795 if (packedDims.front() != indices.back())
797 packOp,
"can only pack the inner-most expanded dimension");
802 projectDimsPosIntoReassocPos(packInnerDims, reassoc);
812 RankedTensorType newPackType = tensor::PackOp::inferPackedType(
813 expandOp.getSrcType(), packOp.getStaticInnerTiles(),
819 packOp,
"could not reassociate dims after bubbling up");
821 Value destTensor = tensor::PackOp::createDestinationTensor(
822 rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
825 packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
826 packOp.getMixedTiles(), packOp.getPaddingValue(),
829 Value newExpandOp = rewriter.
create<tensor::ExpandShapeOp>(
830 packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
836 class BubbleUpPackOpThroughReshapeOp final
842 LogicalResult matchAndRewrite(tensor::PackOp packOp,
844 Operation *srcOp = packOp.getSource().getDefiningOp();
851 if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
852 return ShapedType::isDynamic(size);
858 if (!controlFn(&packOp.getSourceMutable()))
862 .Case([&](tensor::CollapseShapeOp op) {
863 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
865 .Case([&](tensor::ExpandShapeOp op) {
866 return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
868 .Default([](
Operation *) {
return failure(); });
894 static LogicalResult pushDownUnPackOpThroughExpandShape(
895 tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
898 if (!controlFn(&expandOp.getSrcMutable()))
905 auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
910 expandOp.getReassociationIndices();
919 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
921 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
929 for (
auto outerPos : outerDimsPerm) {
930 newOuterDimsPerm.insert(newOuterDimsPerm.end(),
931 reassocIndices[outerPos].begin(),
932 reassocIndices[outerPos].end());
940 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
942 for (
size_t i = 0; i < innerDimsPos.size(); ++i) {
943 newReassocIndices.push_back({nextPos});
947 RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
948 expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
949 auto newExpandOp = rewriter.
create<tensor::ExpandShapeOp>(
950 expandOp.getLoc(), newExpandType, unPackOp.getSource(),
953 auto emptyOp = tensor::UnPackOp::createDestinationTensor(
954 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
955 projectedInnerDimsPos, newOuterDimsPerm);
956 auto newUnPackOp = rewriter.
create<tensor::UnPackOp>(
957 unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
958 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
959 rewriter.
replaceOp(expandOp, newUnPackOp);
964 class PushDownUnPackOpThroughReshapeOp final
967 PushDownUnPackOpThroughReshapeOp(
MLIRContext *context,
972 LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
974 Value result = unPackOp.getResult();
980 if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
981 return ShapedType::isDynamic(size);
988 .Case([&](tensor::ExpandShapeOp op) {
989 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
992 .Default([](
Operation *) {
return failure(); });
1002 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1004 for (
OpOperand &operand : genericOp->getOpOperands()) {
1008 if (unPackedOperand)
1010 unPackedOperand = &operand;
1012 if (!unPackedOperand)
1014 return unPackedOperand;
1051 static FailureOr<std::tuple<GenericOp, Value>>
1052 pushDownUnPackOpThroughGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
1054 if (genericOp.getNumResults() != 1)
1057 if (hasGatherSemantics(genericOp))
1061 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1062 if (failed(maybeUnPackedOperand))
1064 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1067 tensor::UnPackOp producerUnPackOp =
1069 assert(producerUnPackOp &&
"expect a valid UnPackOp");
1071 if (!controlFn(unPackedOperand))
1075 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1076 if (failed(packInfo))
1080 auto [packedOutOperand, packedOutIndexingMap] =
1081 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1082 genericOp, genericOp.getDpsInitOperand(0));
1083 auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
1087 Value dest = packedOutOperand;
1088 if (
auto initTensor = genericOp.getDpsInitOperand(0)
1090 .getDefiningOp<tensor::EmptyOp>()) {
1092 dest = destPack.getDest();
1096 GenericOp newGenericOp =
1097 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
1099 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1103 return std::make_tuple(newGenericOp, newResult);
1105 auto mixedTiles = destPack.getMixedTiles();
1106 auto innerDimsPos = destPack.getInnerDimsPos();
1107 auto outerDimsPerm = destPack.getOuterDimsPerm();
1112 auto loc = genericOp.getLoc();
1113 Value unPackDest = producerUnPackOp.getDest();
1114 auto genericOutType =
1115 cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
1116 if (producerUnPackOp.getDestType() != genericOutType ||
1117 !genericOutType.hasStaticShape()) {
1118 unPackDest = tensor::UnPackOp::createDestinationTensor(
1119 rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
1125 .
create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
1126 mixedTiles, outerDimsPerm)
1129 return std::make_tuple(newGenericOp, unPackOpRes);
1133 struct PushDownUnPackOpThroughGenericOp :
public OpRewritePattern<GenericOp> {
1135 PushDownUnPackOpThroughGenericOp(
MLIRContext *context,
1139 LogicalResult matchAndRewrite(GenericOp genericOp,
1141 auto genericAndRepl =
1142 pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1143 if (failed(genericAndRepl))
1145 rewriter.
replaceOp(genericOp, std::get<1>(*genericAndRepl));
1156 struct PushDownUnPackThroughPadOp :
public OpRewritePattern<tensor::PadOp> {
1160 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1162 tensor::UnPackOp unpackOp =
1163 padOp.getSource().getDefiningOp<tensor::UnPackOp>();
1167 if (!controlFn(&padOp.getSourceMutable()))
1172 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1174 llvm::SmallBitVector innerDims(paddedDims.size());
1175 for (int64_t dim : innerDimsPos)
1176 innerDims.flip(dim);
1177 if (paddedDims.anyCommon(innerDims))
1180 Value paddingVal = padOp.getConstantPaddingValue();
1188 if (!outerDimsPerm.empty()) {
1189 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1190 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1193 size_t pointLoopsSize = innerDimsPos.size();
1194 lowPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
1195 highPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
1197 auto newPadOp = rewriter.
create<tensor::PadOp>(
1198 loc,
Type(), unpackOp.getSource(), lowPad, highPad,
1199 paddingVal, padOp.getNofold());
1202 Value outputUnPack = rewriter.
create<tensor::EmptyOp>(
1203 loc, padOp.getResultType().getShape(),
1204 padOp.getResultType().getElementType());
1206 Value replacement = rewriter.
create<tensor::UnPackOp>(
1207 loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1208 unpackOp.getMixedTiles(), outerDimsPerm);
1223 .
insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1224 BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1225 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1226 patterns.
getContext(), controlPackUnPackPropagation);
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to bubble up or down data layout ops across other operations.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of tensor.pack/unpack ops.
Include the generated interface declarations.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type when possible.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.