31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
36 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
37 #include "mlir/Dialect/Linalg/Passes.h.inc"
40 #define DEBUG_TYPE "linalg-drop-unit-dims"
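// This file implements a pass that canonicalizes away unit-extent (size-1)
// dimensions in linalg.generic operands and results, typically introduced for
// broadcasting. Rough sketch of the effect (illustrative IR, not a test from
// this file): a tensor<1x4xf32> operand is collapsed before the op and the
// result is re-expanded afterwards,
//
//   %in  = tensor.collapse_shape %arg0 [[0, 1]]
//            : tensor<1x4xf32> into tensor<4xf32>
//   %res = linalg.generic ... ins(%in : tensor<4xf32>) ...
//   %out = tensor.expand_shape %res [[0, 1]]
//            : tensor<4xf32> into tensor<1x4xf32>
//
// Depending on ControlDropUnitDims::RankReductionStrategy, rank-reducing
// tensor.extract_slice / tensor.insert_slice pairs are used instead of the
// reshapes.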
/// Pattern to move init operands to inputs when all loops are parallel and
/// the block argument corresponding to the init is used inside the region.
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    llvm::SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }
    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps: the candidate inits become inputs.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands =
        llvm::to_vector(genericOp.getDpsInputs());
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));

    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = new Block();
    region.push_back(block);
    IRMapping mapper;
    rewriter.setInsertionPointToStart(block);
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    // The moved init arguments map to the newly appended input arguments.
    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }
    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations())
      rewriter.clone(op, mapper);
    rewriter.replaceOp(genericOp, newOp.getResults());
    return success();
  }
};
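// Illustrative effect of MoveInitOperandsToInput (sketch, not a test from
// this file): an init whose block argument is read becomes an input, and the
// init is replaced with a fresh tensor.empty of the same shape:
//
//   %r = linalg.generic ... ins(%in : tensor<4xf32>)
//          outs(%init : tensor<4xf32>) { ^bb0(%a: f32, %b: f32): ... }
// becomes
//   %e = tensor.empty() : tensor<4xf32>
//   %r = linalg.generic ... ins(%in, %init : tensor<4xf32>, tensor<4xf32>)
//          outs(%e : tensor<4xf32>) { ^bb0(%a: f32, %b: f32, %c: f32): ... }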
/// Update `linalg.index` operations in the body after the given set of
/// unit-extent loop dimensions has been dropped.
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}
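// Worked example (sketch): with unitDims = {1} in a 3-d iteration space,
//   %i = linalg.index 1 : index   becomes   %c0 = arith.constant 0 : index
//   %j = linalg.index 2 : index   becomes   %j = linalg.index 1 : index
// since one dropped unit dimension lies below dimension 2.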
/// Expand the given `result` so that the type matches the type of `origDest`.
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }
  assert(rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
                                                reassociation);
}
/// Collapse the given `operand` so that its type matches `targetShape`, using
/// either rank-reducing slices or a reassociative reshape.
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
                                                    reassociation);
  }
  llvm_unreachable("unsupported operand type");
}
/// Metadata for rewriting one operand whose unit extents are dropped: the new
/// indexing map, the collapsed operand shape, and the reassociation from the
/// original shape to the collapsed one.
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};

static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition);
    }
    // A unit dimension may also be accessed with a constant 0 expression.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following unit-extent dimensions into the current group.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}
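// Worked example (sketch): for an operand of type tensor<1x4x1x8xf32> indexed
// by (d0, d1, d2, d3) -> (d0, d1, d2, d3) where d0 and d2 are dropped unit
// loop dimensions, this computes targetShape = [4, 8], reassociation =
// [[0, 1, 2], [3]], and the new indexing map (d0, d1) -> (d0, d1); leading and
// trailing unit dimensions are folded into the group of the nearest kept
// dimension.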
LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                                   const ControlDropUnitDims &options) {
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count.
  AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
  if (!invertedMap)
    return rewriter.notifyMatchFailure(genericOp,
                                       "invalid indexing maps for operation");
  SmallVector<int64_t> dims = genericOp.getStaticShape();

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        genericOp, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : llvm::enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (dims[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the iterator types of the modified op by dropping the one-trip
  //    count loops.
  SmallVector<utils::IteratorType> newIteratorTypes;
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto [index, attr] :
       llvm::enumerate(genericOp.getIteratorTypesArray())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      newIteratorTypes.push_back(attr);
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }
  // 3. For each operand, compute the modified indexing map, the collapsed
  //    shape, and the reassociation from the original shape to it.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    } else if (auto tensorOperandType =
                   dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }
  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(concatAffineMaps(newIndexingMaps)))
    return failure();

  Location loc = genericOp.getLoc();
  // 4. For each operand, collapse it to the computed target shape if needed,
  //    either through reshapes or rank-reducing slices as specified in
  //    `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  // 5. Create the replacement `linalg.generic` with the new operands,
  //    indexing maps, iterator types and result types.
  ArrayRef<Value> newInputs =
      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
  ArrayRef<Value> newOutputs =
      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
  SmallVector<Type> resultTypes;
  resultTypes.reserve(genericOp.getNumResults());
  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
    resultTypes.push_back(newOutputs[i].getType());
  GenericOp replacementOp =
      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
                                 newIndexingMaps, newIteratorTypes);
  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                              replacementOp.getRegion().begin());
  // 5a. Replace `linalg.index` operations that refer to the dropped unit
  //     dimensions.
  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);

  // 6. If any result type changed, expand it back to the original type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
    Value origDest = genericOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
                                             reassociations[opOperandIndex],
                                             options.rankReductionStrategy));
  }

  rewriter.replaceOp(genericOp, resultReplacements);
  return success();
}
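// Usage sketch (hypothetical caller, not part of this file): dropUnitDims can
// be invoked directly, restricting which loop dimensions may be pruned through
// the control function:
//
//   linalg::ControlDropUnitDims options;
//   options.controlFn = [](Operation *op) {
//     // Illustrative policy: only allow dropping loop dimension 0.
//     return SmallVector<unsigned>{0};
//   };
//   IRRewriter rewriter(genericOp.getContext());
//   if (failed(linalg::dropUnitDims(rewriter, genericOp, options)))
//     /* the op is left unchanged */;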
/// Convert `tensor.extract_slice` ops to rank-reduced versions, followed by a
/// `tensor.expand_shape` back to the original result type.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};
/// Convert `insert_slice` operations to rank-reduced versions. This works for
/// both tensor.insert_slice and tensor.parallel_insert_slice.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // For ParallelInsertSliceOp the collapse must be created outside the
      // enclosing parallel combining op.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
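// Illustrative effect (sketch): a unit-extent insert_slice source is collapsed
// and the insert becomes rank-reducing:
//
//   %r = tensor.insert_slice %src into %dst[0, 0] [1, 4] [1, 1]
//          : tensor<1x4xf32> into tensor<8x4xf32>
// becomes
//   %c = tensor.collapse_shape %src [[0, 1]]
//          : tensor<1x4xf32> into tensor<4xf32>
//   %r = tensor.insert_slice %c into %dst[0, 0] [1, 4] [1, 1]
//          : tensor<4xf32> into tensor<8x4xf32>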
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting, using reassociative reshapes.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  options.rankReductionStrategy =
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  patterns.add<DropUnitDims>(context, options);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
}
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}
/// Pass that folds unit-extent dimensions within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy =
          ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
    }
    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
    populateMoveInitOperandsToInputPattern(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}
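// Usage sketch (assuming the usual registration in Linalg's Passes.td, where
// this pass is exposed as `linalg-fold-unit-extent-dims` with a
// `use-rank-reducing-slices` option):
//
//   mlir-opt -linalg-fold-unit-extent-dims input.mlir
//   mlir-opt -linalg-fold-unit-extent-dims="use-rank-reducing-slices=true" input.mlir
//
// or from C++, e.g.:
//
//   pm.addPass(mlir::createLinalgFoldUnitExtentDimsPass());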