#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"

#define DEBUG_TYPE "linalg-data-layout-propagation"
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}
int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();
  // ...
  int64_t origNumDims = indexingMap.getNumDims();
  // ...
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, innerTileSizes)) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    // ...
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // A tiled loop may only appear in the indexing maps as a plain
  // AffineDimExpr.
  // ...
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();
  // Handle the permutation of the outer tiled dims, if any.
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }
    // ...
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }
  // ...
  return packInfo;
}
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}
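// A small worked example (hypothetical inputs): with exprs = (d1, d0) and
// perm = [0, 1], the first loop records currentPositionTileLoops =
// {1 -> 0, 0 -> 1}; walking perm then yields outerDimsPerm = [1, 0], i.e.
// for each domain loop, the operand result position that carries it.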
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  // ...
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Map each iteration-space dim to the operand dim that carries it.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
    }
  }
  // Collect the tile sizes and positions that apply to this operand.
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    // ...
  }
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
  }
  // Permute the dim expressions according to the domain permutation.
  for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
      int64_t dimPos = dimExpr.getPosition();
      // ...
      continue;
    }
    assert(isa<AffineConstantExpr>(exprs[i]) &&
           "Attempted to permute non-constant and non-affine dim expression");
  }
  // Apply the outer dims permutation to the result expressions.
  if (!outerDimsPerm.empty()) {
    SmallVector<AffineExpr> auxVec = exprs;
    for (const auto &en : enumerate(outerDimsPerm))
      auxVec[en.index()] = exprs[en.value()];
    exprs = auxVec;
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand is not affected by the pack; return it unchanged.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = linalg::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<linalg::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User-controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // Gather-like bodies (tensor.extract / linalg.index) are not supported.
  if (hasGatherSemantics(genericOp))
    return failure();

  // Only single-result generics are handled for now.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Require the sole result to have a single use (the pack).
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // Packs with padding semantics are not propagated.
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the pack destination at the generic so it dominates the new op.
  Value packOpDest = packOp.getDest();
  // ...
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  }
  // ...
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the
  // pack op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
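// Illustrative before/after sketch of this rewrite (maps, types, and tile
// sizes are hypothetical, not taken from this listing):
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %res = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel", "parallel"]}
//       ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {...}
//   %packed = linalg.pack %res inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
//
// becomes a generic over the packed layout, with two extra parallel point
// loops (d2, d3) and the pack moved onto the input:
//
//   #mapP = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//   %p = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//       into %e : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
//   %new = linalg.generic {indexing_maps = [#mapP, #mapP],
//                          iterator_types = ["parallel", "parallel",
//                                            "parallel", "parallel"]}
//       ins(%p : tensor<?x?x8x2xf32>) outs(%dest : tensor<?x?x8x2xf32>) {...}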
class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    // Packs with their own padding semantics are not handled.
    if (packOp.getPaddingValue())
      return failure();

    // Only constant padding values can be moved across the pack.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    // ...
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp.getLoc();
    // ...
    auto empty = linalg::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = rewriter.create<linalg::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // If there is an outer_dims_perm, adjust the padded dims accordingly.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The pad op padding at the inner (point-loop) dims is zero.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
        padOp.getNofold());

    // If the pad has other users, unpack the new pad to replace them.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      Value unpackedPad = rewriter.create<linalg::UnPackOp>(
          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    // Replace the pack with the new pad.
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
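// Illustrative sketch of the pad/pack swap (hypothetical shapes; the padded
// dim 0 is disjoint from the tiled dim 1, as the pattern requires):
//
//   %padded = tensor.pad %src low[0, 0] high[15, 0] {...}
//       : tensor<?x64xf32> to tensor<?x64xf32>
//   %packed = linalg.pack %padded inner_dims_pos = [1] inner_tiles = [32]
//       into %dest : tensor<?x64xf32> -> tensor<?x2x32xf32>
//
// becomes a pack of the un-padded source, followed by a pad on the packed
// layout with zero padding appended for the point-loop dims:
//
//   %p = linalg.pack %src inner_dims_pos = [1] inner_tiles = [32]
//       into %e : tensor<?x64xf32> -> tensor<?x2x32xf32>
//   %r = tensor.pad %p low[0, 0, 0] high[15, 0, 0] {...}
//       : tensor<?x2x32xf32> to tensor<?x2x32xf32>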
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // If all dims in the group are unit, fall back to the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}
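// Hypothetical worked example: with reassocIndices = [[0, 1], [2]],
// targetShape = [16, 1, 4], and dimsPos = [0], the collapsed dim 0 expands
// to source dims [0, 1]; scanning them from the back skips the unit dim 1
// and projects onto dim 0, so the result is [0].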
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   linalg::PackOp packOp,
                                   PatternRewriter &rewriter) {
  // ...
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project the pack's inner dims to the source dims before collapsing,
  // preferring the inner-most non-unit dims so the tile sizes divide evenly.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes))
    return failure();
  // Expand each outer-permuted dim into its associated source dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  auto emptyOp = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<linalg::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  // First apply the permutation to the reassociations of the outer dims,
  // then add a direct mapping for each inner tile dim.
  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}
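// Illustrative sketch of the collapse/pack swap (hypothetical types):
//
//   %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
//       : tensor<?x16x4xf32> into tensor<?x4xf32>
//   %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
//       inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
//       : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
//
// becomes (inner dims projected to the pre-collapse dims [1, 2]):
//
//   %pack = linalg.pack %in outer_dims_perm = [0, 1, 2]
//       inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
//       : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
//   %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
//       : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>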
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;
  // Map each dimension to the position of the reassociation group that
  // contains it.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
  return projectedPos;
}
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 linalg::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimension permutations are not supported yet.
  // ...
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  SmallVector<ReassociationIndices> reassoc =
      expandOp.getReassociationIndices();
  // For each reassociation group, determine which expanded dims are packed.
  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // ...
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);
    // A group with no packed dims does not constrain the reordering.
    if (packedDims.empty())
      continue;
    // Reordering would change the element layout if more than one expanded
    // dim of a group were packed.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dim may be packed, for the same reason.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project the pack's inner dims to positions before the shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // ...
  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  // ...
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<linalg::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context,
                                 ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // ...
    // Only static inner tile sizes are supported.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        }))
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
static LogicalResult pushDownUnPackOpThroughExpandShape(
    linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User-controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  // ...
  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  // ...
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project the unpack's inner dims to the dims after expanding, preferring
  // the inner-most non-unit dims so the tile sizes divide evenly.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes))
    return failure();
  // Expand each outer-permuted dim into its associated destination dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  // First apply the permutation to the reassociations of the outer dims,
  // then add a direct mapping for each inner tile dim.
  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}
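// Illustrative sketch of the unpack/expand swap (hypothetical types):
//
//   %u = linalg.unpack %in inner_dims_pos = [0] inner_tiles = [8]
//       into %e : tensor<?x8xf32> -> tensor<?xf32>
//   %expanded = tensor.expand_shape %u [[0, 1]] output_shape [?, 32]
//       : tensor<?xf32> into tensor<?x32xf32>
//
// becomes (the unpacked dim 0 projects to expanded dim 1, and 32 % 8 == 0):
//
//   %expanded = tensor.expand_shape %in [[0, 1], [2]] output_shape [?, 4, 8]
//       : tensor<?x8xf32> into tensor<?x4x8xf32>
//   %u = linalg.unpack %expanded inner_dims_pos = [1] inner_tiles = [8]
//       into %e : tensor<?x4x8xf32> -> tensor<?x32xf32>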
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<linalg::UnPackOp>(context),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Only unpacks with a single user are supported.
    if (!result.hasOneUse())
      return failure();
    Operation *consumerOp = *result.user_begin();
    // Only static inner tile sizes are supported.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        }))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
    if (!unPackOp)
      continue;
    // Fail if more than one operand is produced by an unpack.
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  linalg::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the pack destination instead.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no unpack is needed.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unpack right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
                                    destPack.getSource(), innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}
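// Illustrative sketch of this push-down (hypothetical types, and assuming
// the pack created for the unpacked operand folds away against the producer
// unpack):
//
//   %u = linalg.unpack %in inner_dims_pos = [1] inner_tiles = [32]
//       into %e : tensor<?x2x32xf32> -> tensor<?x64xf32>
//   %r = linalg.generic ... ins(%u : tensor<?x64xf32>)
//       outs(%init : tensor<?x64xf32>) {...}
//
// becomes a generic on the packed layout followed by an unpack of its
// result:
//
//   %g = linalg.generic ... ins(%in : tensor<?x2x32xf32>)
//       outs(%packedInit : tensor<?x2x32xf32>) {...}
//   %r = linalg.unpack %g inner_dims_pos = [1] inner_tiles = [32]
//       into %e : tensor<?x2x32xf32> -> tensor<?x64xf32>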
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
public:
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    linalg::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<linalg::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    // ...
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If there is an outer_dims_perm, adjust the padded dims accordingly.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the unpack right after the packed pad.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<linalg::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}
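// Minimal usage sketch (not part of the original listing): a caller holding
// a `funcOp` could drive these patterns with the greedy rewrite driver. The
// always-true control callback and the `funcOp` variable are hypothetical.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](OpOperand *opOperand) { return true; });
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     funcOp->emitError("data layout propagation did not converge");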