#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"

#define DEBUG_TYPE "linalg-data-layout-propagation"
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}
// Helper on the PackInfo struct: the number of tiled (point) loops equals the
// number of recorded domain-dim-to-point-loop mappings.
int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
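// The enclosing PackInfo struct is not fully visible in this listing; a
// minimal sketch, inferred from how its members are used throughout the rest
// of this file, would look like:
//
//   struct PackInfo {
//     int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
//     // Tiled dim positions on the iteration domain, in pack-op order.
//     SmallVector<int64_t> tiledDimsPos;
//     // Tile size for each tiled iteration-domain dimension.
//     llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
//     // Maps a domain dim to its corresponding inner (point) loop dim.
//     llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
//     // Permutation of the outer dims, expressed on the iteration domain.
//     SmallVector<int64_t> outerDimsOnDomainPerm;
//   };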
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
  PackInfo packInfo;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }
  // Bail out if a tiled dimension is accessed through any non-trivial affine
  // expression in any indexing map.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();
  // Map the outer_dims_perm of the pack/unpack op back onto the iteration
  // domain.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }
    // TODO: Support the case where an outer dimension is not a plain affine
    // dim expression, e.g., when it is broadcasted.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }
  return packInfo;
}
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}
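// Worked example (illustrative): with exprs = [d1, d2],
// currentPositionTileLoops becomes {1 -> 0, 2 -> 1}. For perm = [2, 0, 1],
// loop index 2 maps to result position 1 and loop index 1 maps to position 0,
// while loop index 0 does not appear in exprs and is dropped, so the function
// returns [1, 0].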
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));
  // Map each domain dim to the operand dim it appears in.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
    }
  }
  SmallVector<OpFoldResult> innerTileSizes;
  SmallVector<int64_t> innerDimsPos;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    // Append the point loop expression for this tiled dim.
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }
  // Calculate the outer dims permutation for the packed operand, if any.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
    // Fold the transposition into the indexing map by remapping each dim
    // expression through the inverted domain permutation.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Undo the transposition on the exprs and propagate it to the pack via
    // outerDimsPerm instead.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : llvm::enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);
  auto empty = linalg::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<linalg::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
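// Illustrative sketch (hand-written IR, not from the original listing): given
// an operand of type tensor<?x?xf32> read through (d0, d1) -> (d0, d1) and a
// pack with inner_dims_pos = [0, 1], inner_tiles = [8, 2], this helper
// materializes the packed view and extends the indexing map with the two new
// point loops d2 and d3:
//
//   %dest = tensor.empty(...) : tensor<?x?x8x2xf32>
//   %packed = linalg.pack %operand inner_dims_pos = [0, 1]
//       inner_tiles = [8, 2] into %dest
//       : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
//
// returning the packed value together with the map
// (d0, d1, d2, d3) -> (d0, d1, d2, d3).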
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo,
                               bool isFoldableUnpackPack) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<Value> inputOperandsFromUnpackedSource;
  SmallVector<AffineMap> indexingMaps;
  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
    return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
           packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
  };
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
    } else {
      inputOperandsFromUnpackedSource.push_back(packedOperand);
    }
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }
  // If the pack/unpack pairs cancel out, use the unpacked sources directly.
  if (isFoldableUnpackPack) {
    inputOperands = inputOperandsFromUnpackedSource;
    if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
      auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
      if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
        dest = destUnPack.getSource();
      }
    }
  }
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();
  // Bail out on user rejection, gather semantics, multi-result generics,
  // multi-use results, and padded packs.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();
  if (hasGatherSemantics(genericOp))
    return failure();
  if (genericOp.getNumResults() != 1)
    return failure();
  if (!genericOp->getResult(0).hasOneUse())
    return failure();
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    // Rebuild the empty destination at the generic so it dominates it.
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }
  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the
  // pack op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  // pack(unpack) isn't naively foldable here because the unpack can come from
  // an arbitrary domain, so keep both ops.
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo, /*isFoldableUnpackPack=*/false);
}
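// End-to-end illustration (hand-written IR in the spirit of this file's doc
// comments; names are made up): bubbling the pack up through an elementwise
// generic turns
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %res = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel", "parallel"]}
//       ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {...}
//   %pack = linalg.pack %res inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
//
// into a generic that operates directly on packed operands:
//
//   #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//   %packedArg = linalg.pack %arg0 inner_dims_pos = [0, 1]
//       inner_tiles = [8, 2] into %emptyArg
//       : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
//   %res = linalg.generic {indexing_maps = [#map4, #map4],
//                          iterator_types = ["parallel", "parallel",
//                                            "parallel", "parallel"]}
//       ins(%packedArg : tensor<?x?x8x2xf32>)
//       outs(%dest : tensor<?x?x8x2xf32>) {...}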
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();
    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();
    // TODO: Enable propagation when the pack op itself pads.
    if (packOp.getPaddingValue())
      return failure();
    // Fail for non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();
    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();
    Location loc = padOp->getLoc();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
    auto empty = linalg::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = rewriter.create<linalg::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have outer_dims_perm, adjust the padded dimensions accordingly.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions are not padded, so append zero padding for the
    // point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
        padOp.getNofold());
    // If the pad has other uses, insert an unpack between the new pad and
    // those uses.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      Value unpackedPad = rewriter.create<linalg::UnPackOp>(
          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    // Replace the pack with the new pad.
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
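// Illustrative rewrite (hand-written IR): a pack whose source is a pad with a
// constant padding value, where no tiled dim is padded, becomes a pad of the
// packed source:
//
//   %padded = tensor.pad %src low[0, 0] high[15, 0] { ... }
//       : tensor<17x64xf32> to tensor<32x64xf32>
//   %pack = linalg.pack %padded inner_dims_pos = [1] inner_tiles = [32]
//       into %dest : tensor<32x64xf32> -> tensor<32x2x32xf32>
//
// becomes (the pad now runs on the packed layout, with zero padding appended
// for the point loop):
//
//   %packedSrc = linalg.pack %src inner_dims_pos = [1] inner_tiles = [32]
//       into %empty : tensor<17x64xf32> -> tensor<17x2x32xf32>
//   %res = tensor.pad %packedSrc low[0, 0, 0] high[15, 0, 0] { ... }
//       : tensor<17x2x32xf32> to tensor<32x2x32xf32>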
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}
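// Worked example (illustrative): with dimsPos = [0, 1], reassociation
// [[0, 1], [2, 3]] and targetShape = [1, 16, 16, 1], dim pos 0 projects to
// source dim 1 (the inner-most non-unit dim of its group) and dim pos 1
// projects to source dim 2 (dim 3 is skipped because it is a unit dim), so
// the function returns [1, 2].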
// Returns true if all dims in dimsPos are divisible by the corresponding
// tile sizes.
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}
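// Worked example (illustrative): for dimsPos = [1, 2], shape = [4, 16, 16, 1]
// and tileSizes = [8, 4], both 16 % 8 == 0 and 16 % 4 == 0 hold, so the
// function returns true; a dynamic dim or tileSizes = [5, 4] would make it
// return false.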
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  // Renumber the (permuted) reassociation groups with consecutive positions.
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices)
    for (auto &index : indices)
      index = nextPos++;
  return nextPos;
}
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   linalg::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project the inner dims positions onto the source dims before collapsing.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes))
    return failure();

  // Expand the outer dims permutation with the associated source dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
  auto emptyOp = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<linalg::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims,
  // then append direct mappings for the inner tile dims.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);
  return success();
}
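// Illustrative rewrite (hand-written IR): bubbling a pack up through a
// collapse_shape projects the packed dims onto the inner-most non-unit source
// dims of their reassociation groups:
//
//   %collapsed = tensor.collapse_shape %in [[0, 1], 2]
//       : tensor<?x16x4xf32> into tensor<?x4xf32>
//   %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
//       inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
//       : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
//
// becomes
//
//   %pack = linalg.pack %in outer_dims_perm = [0, 1, 2]
//       inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
//       : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
//   %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
//       : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>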
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;
  // Map each dim position to the index of the reassociation group that
  // contains it.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
  return projectedPos;
}
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 linalg::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  SmallVector<ReassociationIndices> reassoc =
      expandOp.getReassociationIndices();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  for (const ReassociationIndices &indices : reassoc) {
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);
    // The expanded dimension is not packed; nothing to validate.
    if (packedDims.empty())
      continue;
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Otherwise the element order would change after reordering the ops.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/ArrayRef<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, expandOp.getType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/ArrayRef<int64_t>{});
  Value packedVal = rewriter.create<linalg::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/ArrayRef<int64_t>{});

  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);
  return success();
}
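// Illustrative rewrite (hand-written IR): bubbling a pack up through an
// expand_shape, where only the inner-most expanded dim is packed:
//
//   %expanded = tensor.expand_shape %in [[0], [1, 2]]
//       : tensor<32x64xf32> into tensor<32x4x16xf32>
//   %pack = linalg.pack %expanded inner_dims_pos = [2] inner_tiles = [8]
//       into %empty : tensor<32x4x16xf32> -> tensor<32x4x2x8xf32>
//
// becomes
//
//   %pack = linalg.pack %in inner_dims_pos = [1] inner_tiles = [8]
//       into %empty : tensor<32x64xf32> -> tensor<32x8x8xf32>
//   %expanded = tensor.expand_shape %pack [[0], [1, 2], [3]]
//       : tensor<32x8x8xf32> into tensor<32x4x2x8xf32>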
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context,
                                 ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supports a single-result source with one use.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();
    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
static LogicalResult pushDownUnPackOpThroughExpandShape(
    linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project the inner tile positions to the dim positions after expanding:
  // if dim z expands into [x, y], unpacking on z can be projected onto y.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes))
    return failure();

  // Expand the outer dims permutation with the associated expanded dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mappings for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);
  return success();
}
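// Illustrative rewrite (hand-written IR): pushing an unpack down through an
// expand_shape re-targets the unpack to the inner-most expanded dims:
//
//   %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
//       inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
//       : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
//   %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
//       : tensor<?x256xf32> into tensor<?x256x256xf32>
//
// becomes
//
//   %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
//       : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
//   %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
//       inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
//       : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>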
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only supports unpack ops with a single user.
    if (!result.hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
// Returns the generic operand that is produced by a linalg.unpack, or failure
// if there is not exactly one such operand.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
    if (!unPackOp)
      continue;
    // There should be exactly one unpacked operand.
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  linalg::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the pack op destination instead.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the generic; pack(unpack) pairs are foldable in this direction.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
                    /*isFoldableUnpackPack=*/true);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, there is no need to unpack the result.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unpack right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
                                    destPack.getSource(), innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}
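// Illustrative rewrite (hand-written IR): an elementwise generic that
// consumes an unpacked operand is rewritten to work on the packed layout,
// with the unpack moved after it:
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1]
//       inner_tiles = [8, 2] into %empty
//       : tensor<?x?x8x2xf32> -> tensor<?x?xf32>
//   %res = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel", "parallel"]}
//       ins(%unpack : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {...}
//
// becomes
//
//   #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//   %res = linalg.generic {indexing_maps = [#map4, #map4],
//                          iterator_types = ["parallel", "parallel",
//                                            "parallel", "parallel"]}
//       ins(%arg0 : tensor<?x?x8x2xf32>)
//       outs(%packedInit : tensor<?x?x8x2xf32>) {...}
//   %unpack2 = linalg.unpack %res inner_dims_pos = [0, 1]
//       inner_tiles = [8, 2] into %init
//       : tensor<?x?x8x2xf32> -> tensor<?x?xf32>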
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    linalg::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<linalg::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have outer_dims_perm, adjust the padded dimensions accordingly.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the unpack right after the packed pad.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<linalg::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
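// Illustrative rewrite (hand-written IR): when no tiled dim is padded, a pad
// of an unpacked tensor becomes an unpack of a padded packed tensor:
//
//   %unpack = linalg.unpack %in inner_dims_pos = [1] inner_tiles = [32]
//       into %empty : tensor<4x2x32xf32> -> tensor<4x64xf32>
//   %padded = tensor.pad %unpack low[1, 0] high[3, 0] { ... }
//       : tensor<4x64xf32> to tensor<8x64xf32>
//
// becomes
//
//   %padded = tensor.pad %in low[1, 0, 0] high[3, 0, 0] { ... }
//       : tensor<4x2x32xf32> to tensor<8x2x32xf32>
//   %unpack = linalg.unpack %padded inner_dims_pos = [1] inner_tiles = [32]
//       into %newEmpty : tensor<8x2x32xf32> -> tensor<8x64xf32>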
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}
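// Usage sketch (illustrative, not part of this file): the patterns above are
// typically driven to a fixed point with the greedy pattern rewrite driver
// (declared in "mlir/Transforms/GreedyPatternRewriteDriver.h"; the driver
// entry point is named applyPatternsAndFoldGreedily in older MLIR releases).
// The helper name and the always-allow control function are assumptions made
// for this example.
static LogicalResult propagateDataLayouts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Allow propagation across every pack/unpack operand.
  linalg::populateDataLayoutPropagationPatterns(
      patterns, [](OpOperand *opOperand) { return true; });
  return applyPatternsGreedily(root, std::move(patterns));
}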