31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
36 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
37 #include "mlir/Dialect/Linalg/Passes.h.inc"
40 #define DEBUG_TYPE "linalg-drop-unit-dims"
86 if (!genericOp.hasPureTensorSemantics())
88 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
91 auto outputOperands = genericOp.getDpsInitsMutable();
94 if (genericOp.getMatchingBlockArgument(&op).use_empty())
96 candidates.insert(&op);
99 if (candidates.empty())
103 int64_t origNumInput = genericOp.getNumDpsInputs();
107 newIndexingMaps.append(indexingMaps.begin(),
108 std::next(indexingMaps.begin(), origNumInput));
110 newInputOperands.push_back(op->get());
111 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
113 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
118 llvm::to_vector(genericOp.getDpsInits());
122 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
123 auto empty = rewriter.
create<tensor::EmptyOp>(
126 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
127 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
130 auto newOp = rewriter.
create<GenericOp>(
131 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
132 newIndexingMaps, genericOp.getIteratorTypesArray(),
136 Region ®ion = newOp.getRegion();
139 for (
auto bbarg : genericOp.getRegionInputArgs())
143 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
148 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
149 if (candidates.count(&op))
155 for (
auto &op : genericOp.getBody()->getOperations()) {
156 rewriter.
clone(op, mapper);
158 rewriter.
replaceOp(genericOp, newOp.getResults());
233 const llvm::SmallDenseSet<unsigned> &unitDims,
235 for (IndexOp indexOp :
236 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
239 if (unitDims.count(indexOp.getDim()) != 0) {
243 unsigned droppedDims = llvm::count_if(
244 unitDims, [&](
unsigned dim) {
return dim < indexOp.getDim(); });
245 if (droppedDims != 0)
247 indexOp.getDim() - droppedDims);
260 auto origResultType = cast<RankedTensorType>(origDest.
getType());
261 if (rankReductionStrategy ==
263 unsigned rank = origResultType.getRank();
269 loc, result, origDest, offsets, sizes, strides);
272 assert(rankReductionStrategy ==
274 "unknown rank reduction strategy");
275 return rewriter.
create<tensor::ExpandShapeOp>(loc, origResultType, result,
286 if (
auto memrefType = dyn_cast<MemRefType>(operand.
getType())) {
287 if (rankReductionStrategy ==
290 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
292 assert(
succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
293 return *rankReducingExtract;
297 rankReductionStrategy ==
299 "unknown rank reduction strategy");
300 MemRefLayoutAttrInterface layout;
301 auto targetType =
MemRefType::get(targetShape, memrefType.getElementType(),
302 layout, memrefType.getMemorySpace());
303 return rewriter.
create<memref::CollapseShapeOp>(loc, targetType, operand,
306 if (
auto tensorType = dyn_cast<RankedTensorType>(operand.
getType())) {
307 if (rankReductionStrategy ==
310 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
312 assert(
succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
313 return *rankReducingExtract;
317 rankReductionStrategy ==
319 "unknown rank reduction strategy");
322 return rewriter.
create<tensor::CollapseShapeOp>(loc, targetType, operand,
325 llvm_unreachable(
"unsupported operand type");
340 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
345 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
349 auto isUnitDim = [&](
unsigned dim) {
350 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
351 unsigned oldPosition = dimExpr.getPosition();
352 return !oldDimsToNewDimsMap.count(oldPosition);
356 if (operandShape[dim] == 1) {
357 auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
358 return constAffineExpr && constAffineExpr.getValue() == 0;
364 while (dim < operandShape.size() && isUnitDim(dim))
365 reassociationGroup.push_back(dim++);
366 while (dim < operandShape.size()) {
367 assert(!isUnitDim(dim) &&
"expected non unit-extent");
368 reassociationGroup.push_back(dim);
369 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
370 newIndexExprs.push_back(newExpr);
374 while (dim < operandShape.size() && isUnitDim(dim)) {
375 reassociationGroup.push_back(dim++);
378 reassociationGroup.clear();
382 newIndexExprs, context);
389 if (indexingMaps.empty())
398 "invalid indexing maps for operation");
404 if (allowedUnitDims.empty()) {
406 genericOp,
"control function returns no allowed unit dims to prune");
408 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
409 allowedUnitDims.end());
410 llvm::SmallDenseSet<unsigned> unitDims;
412 if (
AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
413 if (dims[dimExpr.getPosition()] == 1 &&
414 unitDimsFilter.count(expr.index()))
415 unitDims.insert(expr.index());
422 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
424 unsigned newDims = 0;
425 for (
auto [index, attr] :
427 if (unitDims.count(index)) {
428 dimReplacements.push_back(
431 newIteratorTypes.push_back(attr);
432 oldDimToNewDimMap[index] = newDims;
433 dimReplacements.push_back(
454 auto hasCollapsibleType = [](
OpOperand &operand) {
455 Type operandType = operand.get().getType();
456 if (
auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
457 return memrefOperandType.getLayout().isIdentity();
459 if (
auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
460 return tensorOperandType.getEncoding() ==
nullptr;
464 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
465 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
467 if (!hasCollapsibleType(opOperand)) {
470 newIndexingMaps.push_back(newIndexingMap);
471 targetShapes.push_back(llvm::to_vector(shape));
472 collapsed.push_back(
false);
473 reassociations.push_back({});
477 rewriter.
getContext(), genericOp, &opOperand, oldDimToNewDimMap,
479 reassociations.push_back(replacementInfo.reassociation);
480 newIndexingMaps.push_back(replacementInfo.indexMap);
481 targetShapes.push_back(replacementInfo.targetShape);
482 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
483 indexingMap.getNumResults()));
488 if (newIndexingMaps == indexingMaps ||
498 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
499 int64_t idx = opOperand.getOperandNumber();
500 if (!collapsed[idx]) {
501 newOperands.push_back(opOperand.get());
504 newOperands.push_back(
collapseValue(rewriter, loc, opOperand.get(),
505 targetShapes[idx], reassociations[idx],
506 options.rankReductionStrategy));
516 resultTypes.reserve(genericOp.getNumResults());
517 for (
unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
518 resultTypes.push_back(newOutputs[i].getType());
519 GenericOp replacementOp =
520 rewriter.
create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
521 newIndexingMaps, newIteratorTypes);
523 replacementOp.getRegion().begin());
532 for (
auto [index, result] :
llvm::enumerate(replacementOp.getResults())) {
533 unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
534 Value origDest = genericOp.getDpsInitOperand(index)->get();
535 if (!collapsed[opOperandIndex]) {
536 resultReplacements.push_back(result);
539 resultReplacements.push_back(
expandValue(rewriter, loc, result, origDest,
540 reassociations[opOperandIndex],
541 options.rankReductionStrategy));
544 rewriter.
replaceOp(genericOp, resultReplacements);
578 if (allowedUnitDims.empty()) {
580 padOp,
"control function returns no allowed unit dims to prune");
583 if (padOp.getSourceType().getEncoding()) {
585 padOp,
"cannot collapse dims of tensor with encoding");
592 Value paddingVal = padOp.getConstantPaddingValue();
595 padOp,
"unimplemented: non-constant padding value");
599 int64_t padRank = sourceShape.size();
603 return maybeInt && *maybeInt == 0;
606 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
607 allowedUnitDims.end());
608 llvm::SmallDenseSet<unsigned> unitDims;
612 for (
const auto [dim, size, low, high] :
613 zip_equal(llvm::seq(
static_cast<int64_t
>(0), padRank), sourceShape,
614 padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
615 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
616 isStaticZero(high)) {
617 unitDims.insert(dim);
619 newShape.push_back(size);
620 newLowPad.push_back(low);
621 newHighPad.push_back(high);
625 if (unitDims.empty()) {
632 while (dim < padRank && unitDims.contains(dim))
633 reassociationGroup.push_back(dim++);
634 while (dim < padRank) {
635 assert(!unitDims.contains(dim) &&
"expected non unit-extent");
636 reassociationGroup.push_back(dim);
639 while (dim < padRank && unitDims.contains(dim))
640 reassociationGroup.push_back(dim++);
641 reassociationMap.push_back(reassociationGroup);
642 reassociationGroup.clear();
645 Value collapsedSource =
646 collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
647 reassociationMap,
options.rankReductionStrategy);
649 auto newPadOp = rewriter.
create<tensor::PadOp>(
650 padOp.getLoc(),
Type(), collapsedSource, newLowPad,
651 newHighPad, paddingVal, padOp.getNofold());
653 Value dest = padOp.getResult();
654 if (
options.rankReductionStrategy ==
657 int64_t numUnitDims = 0;
658 for (
auto dim : llvm::seq(
static_cast<int64_t
>(0), padRank)) {
659 if (unitDims.contains(dim)) {
665 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
667 dest = rewriter.
create<tensor::EmptyOp>(
668 padOp.getLoc(), expandedSizes,
669 padOp.getResultType().getElementType());
672 Value expandedValue =
673 expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
674 reassociationMap,
options.rankReductionStrategy);
675 rewriter.
replaceOp(padOp, expandedValue);
686 struct RankReducedExtractSliceOp
690 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
692 RankedTensorType resultType = sliceOp.getType();
694 for (
auto size : resultType.getShape())
697 if (!reassociation ||
698 reassociation->size() ==
static_cast<size_t>(resultType.getRank()))
704 auto rankReducedType = cast<RankedTensorType>(
705 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
706 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
710 Value newSlice = rewriter.
create<tensor::ExtractSliceOp>(
711 loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
713 sliceOp, resultType, newSlice, *reassociation);
720 template <
typename InsertOpTy>
726 RankedTensorType sourceType = insertSliceOp.getSourceType();
728 for (
auto size : sourceType.getShape())
731 if (!reassociation ||
732 reassociation->size() ==
static_cast<size_t>(sourceType.getRank()))
735 Location loc = insertSliceOp.getLoc();
736 tensor::CollapseShapeOp reshapedSource;
742 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
744 reshapedSource = rewriter.
create<tensor::CollapseShapeOp>(
745 loc, insertSliceOp.getSource(), *reassociation);
748 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
749 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
750 insertSliceOp.getMixedStrides());
763 patterns.
add<DropPadUnitDims>(context,
options);
765 patterns.
add<RankReducedExtractSliceOp,
766 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
767 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
769 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
770 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
771 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
772 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
782 options.rankReductionStrategy =
785 patterns.
add<DropPadUnitDims>(context,
options);
787 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
788 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
796 if (
options.rankReductionStrategy ==
799 }
else if (
options.rankReductionStrategy ==
801 ReassociativeReshape) {
808 patterns.
add<MoveInitOperandsToInput>(patterns.
getContext());
813 struct LinalgFoldUnitExtentDimsPass
814 :
public impl::LinalgFoldUnitExtentDimsPassBase<
815 LinalgFoldUnitExtentDimsPass> {
816 using impl::LinalgFoldUnitExtentDimsPassBase<
817 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
818 void runOnOperation()
override {
823 if (useRankReducingSlices) {
static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Expand the given value so that the type matches the type of origDest.
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns that are used to canonicalize the use of unit-extent dims for broadcasting.
static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
static Value collapseValue(RewriterBase &rewriter, Location loc, Value operand, ArrayRef< int64_t > targetShape, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Collapse the given operand so that its type has the shape given by targetShape.
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
ArrayRef< AffineExpr > getResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of an extract_slice (or source of an insert_slice) operat...
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options)
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with tensor.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Compute the modified metadata for an operand of an operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Transformation to drop unit-extent dimensions from linalg.generic operations.