20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
25 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
26 #include "mlir/Dialect/Linalg/Passes.h.inc"
32 #define DEBUG_TYPE "linalg-data-layout-propagation"
36 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
37 for (
Operation &op : genericOp.getBody()->getOperations())
38 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
46 int64_t getNumTiledLoops()
const {
return tileToPointMapping.size(); };
58 template <
typename OpTy>
60 getPackingInfoFromOperand(
OpOperand *opOperand, linalg::GenericOp genericOp,
61 OpTy packOrUnPackOp) {
62 static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
63 "applies to only pack or unpack operations");
65 { llvm::dbgs() <<
"--- Construct PackInfo From an operand ---\n"; });
67 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
70 genericOp.getIteratorTypesArray();
73 int64_t origNumDims = indexingMap.getNumDims();
76 for (
auto [index, innerDimPos, tileSize] :
77 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
78 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
79 auto expr = exprs[innerDimPos];
80 if (!isa<AffineDimExpr>(expr))
82 int64_t domainDimPos =
83 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
86 packInfo.tiledDimsPos.push_back(domainDimPos);
87 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
88 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
90 llvm::dbgs() <<
"map innerDimPos=" << innerDimPos
91 <<
" to iteration dimension (d" << domainDimPos <<
", d"
92 << packInfo.tileToPointMapping[domainDimPos]
93 <<
"), which has size=("
94 << packInfo.domainDimAndTileMapping[domainDimPos] <<
")\n";
100 auto areAllAffineDimExpr = [&](
int dim) {
102 if (llvm::any_of(map.getResults(), [dim](
AffineExpr expr) {
103 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
110 for (int64_t i : packInfo.tiledDimsPos)
111 if (!areAllAffineDimExpr(i))
128 for (
auto [index, dim] :
llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
129 auto permutedExpr = indexingMap.getResult(dim);
130 if (
auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
131 permutedOuterDims.push_back(dimExpr.getPosition());
138 if (
static_cast<int64_t
>(index) != dim)
141 if (!permutedOuterDims.empty()) {
142 int64_t outerDimIndex = 0;
144 permutedOuterDims.end());
145 for (
int i = 0, e = indexingMap.getNumDims(); i < e; i++)
146 packInfo.outerDimsOnDomainPerm.push_back(
147 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
150 llvm::dbgs() <<
"map outer dimsDimsPerm to ";
151 for (
auto dim : packInfo.outerDimsOnDomainPerm)
152 llvm::dbgs() << dim <<
" ";
153 llvm::dbgs() <<
"\n";
171 assert(!perm.empty() &&
"expect perm not to be empty");
172 assert(!exprs.empty() &&
"expect exprs not to be empty");
173 if (exprs.size() == 1)
181 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
182 currentPositionTileLoops[dimExpr.getPosition()] = pos;
184 currentPositionTileLoops[pos] = pos;
186 for (int64_t loopIdx : perm) {
187 if (currentPositionTileLoops.count(loopIdx))
188 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
190 return outerDimsPerm;
224 static std::tuple<Value, AffineMap>
226 GenericOp genericOp,
OpOperand *opOperand) {
227 int64_t numOrigLoops = genericOp.getNumLoops();
228 int64_t numInnerLoops = packInfo.getNumTiledLoops();
229 int64_t numLoops = numOrigLoops + numInnerLoops;
230 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
235 if (genericOp.isScalar(opOperand) || exprs.empty())
236 return std::make_tuple(opOperand->
get(),
242 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
243 int64_t dimPos = dimExpr.getPosition();
244 domainDimToOperandDim[dimPos] = index;
250 for (
auto dimPos : packInfo.tiledDimsPos) {
251 if (!domainDimToOperandDim.count(dimPos))
253 int64_t index = domainDimToOperandDim[dimPos];
254 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
255 innerDimsPos.push_back(index);
261 if (!packInfo.outerDimsOnDomainPerm.empty()) {
262 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
267 for (
auto i : llvm::seq<unsigned>(0, origIndexingMap.
getNumResults())) {
268 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
269 int64_t dimPos = dimExpr.getPosition();
273 assert(isa<AffineConstantExpr>(exprs[i]) &&
274 "Attempted to permute non-constant and non-affine dim expression");
278 if (!outerDimsPerm.empty()) {
280 for (
const auto &en :
enumerate(outerDimsPerm))
281 auxVec[en.index()] = exprs[en.value()];
288 if (innerDimsPos.empty() && outerDimsPerm.empty())
289 return std::make_tuple(opOperand->
get(), indexingMap);
291 auto empty = tensor::PackOp::createDestinationTensor(
292 b, loc, opOperand->
get(), innerTileSizes, innerDimsPos, outerDimsPerm);
293 auto packedOperand = b.
create<tensor::PackOp>(
294 loc, opOperand->
get(), empty, innerDimsPos, innerTileSizes,
295 std::nullopt, outerDimsPerm);
296 return std::make_tuple(packedOperand, indexingMap);
300 static GenericOp packGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
302 const PackInfo &packInfo) {
306 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
307 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
308 rewriter, loc, packInfo, genericOp, inputOperand);
309 inputOperands.push_back(packedOperand);
310 indexingMaps.push_back(packedIndexingMap);
313 int64_t numInnerLoops = packInfo.getNumTiledLoops();
315 genericOp.getIteratorTypesArray();
316 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
318 indexingMaps.push_back(packedOutIndexingMap);
320 auto newGenericOp = rewriter.
create<linalg::GenericOp>(
321 loc, dest.
getType(), inputOperands, dest, indexingMaps, iterTypes,
324 newGenericOp.getRegion().begin());
372 bubbleUpPackOpThroughGenericOp(
RewriterBase &rewriter, tensor::PackOp packOp,
374 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
379 if (!controlFn(genericOp))
385 if (hasGatherSemantics(genericOp))
390 if (genericOp.getNumResults() != 1)
397 if (!genericOp->getResult(0).hasOneUse())
410 Value packOpDest = packOp.getDest();
413 if (
auto emptyOp = packOpDest.
getDefiningOp<tensor::EmptyOp>()) {
414 packOpDest = rewriter.
create<tensor::EmptyOp>(
415 genericOp->getLoc(), emptyOp.getMixedSizes(),
416 emptyOp.getType().getElementType());
427 if (packOp.getPaddingValue())
430 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
431 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
436 auto [packedOutOperand, packedOutIndexingMap] =
437 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
438 genericOp, opOperand);
442 Value dest = packedOutOperand;
443 if (
auto initTensor = genericOp.getDpsInitOperand(0)
445 .getDefiningOp<tensor::EmptyOp>()) {
448 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
453 struct BubbleUpPackOpThroughGenericOpPattern
456 BubbleUpPackOpThroughGenericOpPattern(
MLIRContext *context,
463 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
466 rewriter.
replaceOp(packOp, genericOp->getResults());
477 class BubbleUpPackThroughPadOp final :
public OpRewritePattern<tensor::PackOp> {
484 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
489 if (!controlFn(padOp))
492 if (!padOp.getResult().hasOneUse())
496 if (packOp.getPaddingValue())
503 Value paddingVal = padOp.getConstantPaddingValue();
507 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
514 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
515 llvm::SmallBitVector innerDims(paddedDims.size());
516 for (int64_t dim : innerDimsPos)
518 if (paddedDims.anyCommon(innerDims))
525 auto empty = tensor::PackOp::createDestinationTensor(
526 rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
528 Value packedSource = rewriter.
create<tensor::PackOp>(
529 loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
530 std::nullopt, outerDimsPerm);
535 if (!outerDimsPerm.empty()) {
536 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
537 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
541 size_t pointLoopsSize = innerDimsPos.size();
542 lowPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
543 highPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
545 auto newPadOp = rewriter.
create<tensor::PadOp>(
546 loc,
Type(), packedSource, lowPad, highPad, paddingVal,
548 rewriter.
replaceOp(packOp, newPadOp.getResult());
570 for (
auto pos : dimsPos) {
572 int64_t projectedPos = reassocIndices[pos].back();
573 for (
auto i : llvm::reverse(reassocIndices[pos])) {
574 int64_t dim = targetShape[i];
575 if (dim > 1 || ShapedType::isDynamic(dim)) {
580 projectedDimsPos.push_back(projectedPos);
582 return projectedDimsPos;
589 for (
auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
590 int64_t dim = shape[pos];
591 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
603 static int64_t applyPermutationAndReindexReassoc(
606 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
609 for (
auto &index : indices) {
637 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
638 tensor::PackOp packOp,
646 collapseOp.getReassociationIndices();
655 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
657 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
665 for (
auto outerPos : outerDimsPerm) {
666 newOuterDimsPerm.insert(newOuterDimsPerm.end(),
667 reassocIndices[outerPos].begin(),
668 reassocIndices[outerPos].end());
671 auto emptyOp = tensor::PackOp::createDestinationTensor(
672 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
673 projectedInnerDimsPos, newOuterDimsPerm);
674 auto newPackOp = rewriter.
create<tensor::PackOp>(
675 packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
676 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
683 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
685 for (
size_t i = 0; i < innerDimsPos.size(); ++i) {
686 newReassocIndices.push_back({nextPos});
690 auto newCollapseOp = rewriter.
create<tensor::CollapseShapeOp>(
691 collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
692 rewriter.
replaceOp(packOp, newCollapseOp);
697 class BubbleUpPackOpThroughReshapeOp final
705 Operation *srcOp = packOp.getSource().getDefiningOp();
712 if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
713 return ShapedType::isDynamic(size);
719 if (!controlFn(srcOp))
723 .Case([&](tensor::CollapseShapeOp op) {
724 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
753 pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
754 tensor::ExpandShapeOp expandOp,
762 expandOp.getReassociationIndices();
771 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
773 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
781 for (
auto outerPos : outerDimsPerm) {
782 newOuterDimsPerm.insert(newOuterDimsPerm.end(),
783 reassocIndices[outerPos].begin(),
784 reassocIndices[outerPos].end());
792 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
794 for (
size_t i = 0; i < innerDimsPos.size(); ++i) {
795 newReassocIndices.push_back({nextPos});
799 RankedTensorType newExpandType =
800 tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
801 projectedInnerDimsPos, newOuterDimsPerm);
802 auto newExpandOp = rewriter.
create<tensor::ExpandShapeOp>(
803 expandOp.getLoc(), newExpandType, unPackOp.getSource(),
806 auto emptyOp = tensor::UnPackOp::createDestinationTensor(
807 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
808 projectedInnerDimsPos, newOuterDimsPerm);
809 auto newUnPackOp = rewriter.
create<tensor::UnPackOp>(
810 unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
811 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
812 rewriter.
replaceOp(expandOp, newUnPackOp);
817 class PushDownUnPackOpThroughReshapeOp final
820 PushDownUnPackOpThroughReshapeOp(
MLIRContext *context,
827 Value result = unPackOp.getResult();
833 if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
834 return ShapedType::isDynamic(size);
841 if (!controlFn(consumerOp))
845 .Case([&](tensor::ExpandShapeOp op) {
846 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
860 for (
OpOperand &operand : genericOp->getOpOperands()) {
866 unPackedOperand = &operand;
868 if (!unPackedOperand)
870 return unPackedOperand;
908 pushDownUnPackOpThroughGenericOp(
RewriterBase &rewriter, GenericOp genericOp) {
909 if (genericOp.getNumResults() != 1)
912 if (hasGatherSemantics(genericOp))
916 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
917 if (
failed(maybeUnPackedOperand))
919 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
922 tensor::UnPackOp producerUnPackOp =
924 assert(producerUnPackOp &&
"expect a valid UnPackOp");
926 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
931 auto [packedOutOperand, packedOutIndexingMap] =
932 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
933 genericOp, genericOp.getDpsInitOperand(0));
934 auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
938 Value dest = packedOutOperand;
939 if (
auto initTensor = genericOp.getDpsInitOperand(0)
941 .getDefiningOp<tensor::EmptyOp>()) {
943 dest = destPack.getDest();
947 GenericOp newGenericOp =
948 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
950 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
954 return std::make_tuple(newGenericOp, newResult);
956 auto mixedTiles = destPack.getMixedTiles();
957 auto innerDimsPos = destPack.getInnerDimsPos();
958 auto outerDimsPerm = destPack.getOuterDimsPerm();
963 auto loc = genericOp.getLoc();
964 Value unPackDest = producerUnPackOp.getDest();
965 auto genericOutType =
966 cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
967 if (producerUnPackOp.getDestType() != genericOutType ||
968 !genericOutType.hasStaticShape()) {
969 unPackDest = tensor::UnPackOp::createDestinationTensor(
970 rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
976 .
create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
977 mixedTiles, outerDimsPerm)
980 return std::make_tuple(newGenericOp, unPackOpRes);
984 struct PushDownUnPackOpThroughGenericOp :
public OpRewritePattern<GenericOp> {
986 PushDownUnPackOpThroughGenericOp(
MLIRContext *context,
992 if (!controlFn(genericOp))
995 auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
996 if (
failed(genericAndRepl))
998 rewriter.
replaceOp(genericOp, std::get<1>(*genericAndRepl));
1009 struct PushDownUnPackThroughPadOp :
public OpRewritePattern<tensor::PadOp> {
1015 tensor::UnPackOp unpackOp =
1016 padOp.getSource().getDefiningOp<tensor::UnPackOp>();
1020 if (!controlFn(padOp))
1025 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1027 llvm::SmallBitVector innerDims(paddedDims.size());
1028 for (int64_t dim : innerDimsPos)
1029 innerDims.flip(dim);
1030 if (paddedDims.anyCommon(innerDims))
1033 Value paddingVal = padOp.getConstantPaddingValue();
1041 if (!outerDimsPerm.empty()) {
1042 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1043 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1046 size_t pointLoopsSize = innerDimsPos.size();
1047 lowPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
1048 highPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
1050 auto newPadOp = rewriter.
create<tensor::PadOp>(
1051 loc,
Type(), unpackOp.getSource(), lowPad, highPad,
1052 paddingVal, padOp.getNofold());
1055 Value outputUnPack = rewriter.
create<tensor::EmptyOp>(
1056 loc, padOp.getResultType().getShape(),
1057 padOp.getResultType().getElementType());
1059 Value replacement = rewriter.
create<tensor::UnPackOp>(
1060 loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1061 unpackOp.getMixedTiles(), outerDimsPerm);
1076 .
insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1077 BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1078 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1079 patterns.
getContext(), controlPackUnPackPropagation);
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B (i.e., A dominates B and A is not B).
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(Operation *op)> ControlPropagationFn
Function type which is used to control propagation of tensor.pack/unpack ops.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to bubble up or down data layout ops across other operations.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...