20 #include "llvm/Support/Debug.h"
24 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
31 #define DEBUG_TYPE "linalg-data-layout-propagation"
35 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
36 for (
Operation &op : genericOp.getBody()->getOperations())
37 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
45 int64_t getNumTiledLoops()
const {
return tileToPointMapping.size(); };
57 template <
typename OpTy>
59 getPackingInfoFromOperand(
OpOperand *opOperand, linalg::GenericOp genericOp,
60 OpTy packOrUnPackOp) {
61 static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
62 "applies to only pack or unpack operations");
64 { llvm::dbgs() <<
"--- Construct PackInfo From an operand ---\n"; });
66 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
69 genericOp.getIteratorTypesArray();
72 int64_t origNumDims = indexingMap.getNumDims();
75 for (
auto [index, innerDimPos, tileSize] :
76 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
77 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
78 auto expr = exprs[innerDimPos];
79 if (!isa<AffineDimExpr>(expr))
81 int64_t domainDimPos =
82 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
85 packInfo.tiledDimsPos.push_back(domainDimPos);
86 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
87 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
89 llvm::dbgs() <<
"map innerDimPos=" << innerDimPos
90 <<
" to iteration dimension (d" << domainDimPos <<
", d"
91 << packInfo.tileToPointMapping[domainDimPos]
92 <<
"), which has size=("
93 << packInfo.domainDimAndTileMapping[domainDimPos] <<
")\n";
99 auto areAllAffineDimExpr = [&](
int dim) {
101 if (llvm::any_of(map.getResults(), [dim](
AffineExpr expr) {
102 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
109 for (int64_t i : packInfo.tiledDimsPos)
110 if (!areAllAffineDimExpr(i))
127 for (
auto [index, dim] :
llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
128 auto permutedExpr = indexingMap.getResult(dim);
129 if (
auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
130 permutedOuterDims.push_back(dimExpr.getPosition());
137 if (
static_cast<int64_t
>(index) != dim)
140 if (!permutedOuterDims.empty()) {
141 int64_t outerDimIndex = 0;
143 permutedOuterDims.end());
144 for (
int i = 0, e = indexingMap.getNumDims(); i < e; i++)
145 packInfo.outerDimsOnDomainPerm.push_back(
146 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
149 llvm::dbgs() <<
"map outer dimsDimsPerm to ";
150 for (
auto dim : packInfo.outerDimsOnDomainPerm)
151 llvm::dbgs() << dim <<
" ";
152 llvm::dbgs() <<
"\n";
170 assert(!perm.empty() &&
"expect perm not to be empty");
171 assert(!exprs.empty() &&
"expect exprs not to be empty");
172 if (exprs.size() == 1)
180 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
181 currentPositionTileLoops[dimExpr.getPosition()] = pos;
183 currentPositionTileLoops[pos] = pos;
185 for (int64_t loopIdx : perm) {
186 if (currentPositionTileLoops.count(loopIdx))
187 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
189 return outerDimsPerm;
223 static std::tuple<Value, AffineMap>
225 GenericOp genericOp,
OpOperand *opOperand) {
226 int64_t numOrigLoops = genericOp.getNumLoops();
227 int64_t numInnerLoops = packInfo.getNumTiledLoops();
228 int64_t numLoops = numOrigLoops + numInnerLoops;
229 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
234 if (genericOp.isScalar(opOperand) || exprs.empty())
235 return std::make_tuple(opOperand->
get(),
241 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
242 int64_t dimPos = dimExpr.getPosition();
243 domainDimToOperandDim[dimPos] = index;
249 for (
auto dimPos : packInfo.tiledDimsPos) {
250 if (!domainDimToOperandDim.count(dimPos))
252 int64_t index = domainDimToOperandDim[dimPos];
253 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
254 innerDimsPos.push_back(index);
260 if (!packInfo.outerDimsOnDomainPerm.empty()) {
261 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
266 for (
auto i : llvm::seq<unsigned>(0, origIndexingMap.
getNumResults())) {
267 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
268 int64_t dimPos = dimExpr.getPosition();
272 assert(isa<AffineConstantExpr>(exprs[i]) &&
273 "Attempted to permute non-constant and non-affine dim expression");
277 if (!outerDimsPerm.empty()) {
279 for (
const auto &en :
enumerate(outerDimsPerm))
280 auxVec[en.index()] = exprs[en.value()];
287 if (innerDimsPos.empty() && outerDimsPerm.empty())
288 return std::make_tuple(opOperand->
get(), indexingMap);
290 auto empty = tensor::PackOp::createDestinationTensor(
291 b, loc, opOperand->
get(), innerTileSizes, innerDimsPos, outerDimsPerm);
292 auto packedOperand = b.
create<tensor::PackOp>(
293 loc, opOperand->
get(), empty, innerDimsPos, innerTileSizes,
294 std::nullopt, outerDimsPerm);
295 return std::make_tuple(packedOperand, indexingMap);
299 static GenericOp packGenericOp(
RewriterBase &rewriter, GenericOp genericOp,
301 const PackInfo &packInfo) {
305 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
306 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
307 rewriter, loc, packInfo, genericOp, inputOperand);
308 inputOperands.push_back(packedOperand);
309 indexingMaps.push_back(packedIndexingMap);
312 int64_t numInnerLoops = packInfo.getNumTiledLoops();
314 genericOp.getIteratorTypesArray();
315 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
317 indexingMaps.push_back(packedOutIndexingMap);
319 auto newGenericOp = rewriter.
create<linalg::GenericOp>(
320 loc, dest.
getType(), inputOperands, dest, indexingMaps, iterTypes,
323 newGenericOp.getRegion().begin());
371 bubbleUpPackOpThroughGenericOp(
RewriterBase &rewriter, tensor::PackOp packOp,
373 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
378 if (!controlFn(genericOp))
384 if (hasGatherSemantics(genericOp))
389 if (genericOp.getNumResults() != 1)
396 if (!genericOp->getResult(0).hasOneUse())
409 Value packOpDest = packOp.getDest();
412 if (
auto emptyOp = packOpDest.
getDefiningOp<tensor::EmptyOp>()) {
413 packOpDest = rewriter.
create<tensor::EmptyOp>(
414 genericOp->getLoc(), emptyOp.getMixedSizes(),
415 emptyOp.getType().getElementType());
426 if (packOp.getPaddingValue())
429 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
430 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
435 auto [packedOutOperand, packedOutIndexingMap] =
436 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
437 genericOp, opOperand);
441 Value dest = packedOutOperand;
442 if (
auto initTensor = genericOp.getDpsInitOperand(0)
444 .getDefiningOp<tensor::EmptyOp>()) {
447 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
452 struct BubbleUpPackOpThroughGenericOpPattern
455 BubbleUpPackOpThroughGenericOpPattern(
MLIRContext *context,
462 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
465 rewriter.
replaceOp(packOp, genericOp->getResults());
478 for (
OpOperand &operand : genericOp->getOpOperands()) {
484 unPackedOperand = &operand;
486 if (!unPackedOperand)
488 return unPackedOperand;
526 pushDownUnPackOpThroughGenericOp(
RewriterBase &rewriter, GenericOp genericOp) {
527 if (genericOp.getNumResults() != 1)
530 if (hasGatherSemantics(genericOp))
534 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
535 if (
failed(maybeUnPackedOperand))
537 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
540 tensor::UnPackOp producerUnPackOp =
542 assert(producerUnPackOp &&
"expect a valid UnPackOp");
544 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
549 auto [packedOutOperand, packedOutIndexingMap] =
550 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
551 genericOp, genericOp.getDpsInitOperand(0));
552 auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
556 Value dest = packedOutOperand;
557 if (
auto initTensor = genericOp.getDpsInitOperand(0)
559 .getDefiningOp<tensor::EmptyOp>()) {
561 dest = destPack.getDest();
565 GenericOp newGenericOp =
566 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
568 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
572 return std::make_tuple(newGenericOp, newResult);
574 auto mixedTiles = destPack.getMixedTiles();
575 auto innerDimsPos = destPack.getInnerDimsPos();
576 auto outerDimsPerm = destPack.getOuterDimsPerm();
581 auto loc = genericOp.getLoc();
582 Value unPackDest = producerUnPackOp.getDest();
583 auto genericOutType =
584 cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
585 if (producerUnPackOp.getDestType() != genericOutType ||
586 !genericOutType.hasStaticShape()) {
587 unPackDest = tensor::UnPackOp::createDestinationTensor(
588 rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
594 .
create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
595 mixedTiles, outerDimsPerm)
598 return std::make_tuple(newGenericOp, unPackOpRes);
602 struct PushDownUnPackOpThroughGenericOp :
public OpRewritePattern<GenericOp> {
604 PushDownUnPackOpThroughGenericOp(
MLIRContext *context,
610 if (!controlFn(genericOp))
613 auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
614 if (
failed(genericAndRepl))
616 rewriter.
replaceOp(genericOp, std::get<1>(*genericAndRepl));
627 struct PushDownUnPackThroughPadOp :
public OpRewritePattern<tensor::PadOp> {
633 tensor::UnPackOp unpackOp =
634 padOp.getSource().getDefiningOp<tensor::UnPackOp>();
638 if (!controlFn(padOp))
643 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
645 llvm::SmallBitVector innerDims(paddedDims.size());
646 for (int64_t dim : innerDimsPos)
648 if (paddedDims.anyCommon(innerDims))
651 Value paddingVal = padOp.getConstantPaddingValue();
659 if (!outerDimsPerm.empty()) {
660 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
661 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
664 size_t pointLoopsSize = innerDimsPos.size();
665 lowPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
666 highPad.append(pointLoopsSize, rewriter.
getIndexAttr(0));
668 auto newPadOp = rewriter.
create<tensor::PadOp>(
669 loc,
Type(), unpackOp.getSource(), lowPad, highPad,
670 paddingVal, padOp.getNofold());
673 Value outputUnPack = rewriter.
create<tensor::EmptyOp>(
674 loc, padOp.getResultType().getShape(),
675 padOp.getResultType().getElementType());
677 Value replacement = rewriter.
create<tensor::UnPackOp>(
678 loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
679 unpackOp.getMixedTiles(), outerDimsPerm);
693 patterns.
insert<BubbleUpPackOpThroughGenericOpPattern,
694 PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
695 patterns.
getContext(), controlPackUnPackPropagation);
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(Operation *op)> ControlPropagationFn
Function type which is used to control propagation of tensor.pack/unpack ops.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to bubble up or down data layout ops across other operations.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...