19 static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
25 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
26 return llvm::count_if(
27 shape, [](int64_t v) {
return ShapedType::isDynamic(v) || v > 1; });
35 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
36 ArrayRef<int64_t> srcShape,
37 ArrayRef<int64_t> innerPackTileSize) {
38 if (getNumGtOneDims(srcShape) > 1) {
39 return rewriter.notifyMatchFailure(
40 op,
"expects non-packed domain to have at most one non-unit dims");
44 if (getNumGtOneDims(innerPackTileSize) > 1) {
45 return rewriter.notifyMatchFailure(
46 op,
"expects at most one non-unit inner tiles");
53 static FailureOr<SmallVector<int64_t>>
54 getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
55 if (
auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
56 return SmallVector<int64_t>(transposeOp.getPermutation());
57 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
60 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
62 auto mapRange = linalgOp.getIndexingMapsArray();
63 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
64 mapRange.front() == mapRange.back()) {
67 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
69 AffineMap outMap = mapRange.back();
70 AffineMap inMap = mapRange.front();
73 return llvm::map_to_vector(outMap.getResults(),
74 [&](AffineExpr expr) -> int64_t {
75 return *inMap.getResultPosition(expr);
80 struct SimplifyPackToExpandShape :
public OpRewritePattern<PackOp> {
84 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
86 ArrayRef<ReassociationIndices> reassociation)
const {
87 if (operand.getType() == newOperandType)
90 .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
96 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
97 PackOp packOp)
const {
98 auto outerDimsPerm = packOp.getOuterDimsPerm();
100 return rewriter.notifyMatchFailure(
102 "expects outer_dims_perm is empty or an identity permutation");
105 int64_t srcRank = packOp.getSourceRank();
106 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
107 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
108 return rewriter.notifyMatchFailure(
109 packOp,
"expects packing at the innermost dimension");
114 LogicalResult matchAndRewrite(PackOp packOp,
115 PatternRewriter &rewriter)
const override {
116 if (packOp.getPaddingValue())
117 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
119 RankedTensorType sourceType = packOp.getSourceType();
120 if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
121 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
122 packOp.getStaticTiles())) &&
123 !packOp.isLikePad()) {
127 RankedTensorType destType = packOp.getDestType();
132 FailureOr<Value> expanded =
133 insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
135 if (failed(expanded)) {
136 return rewriter.notifyMatchFailure(
137 packOp,
"unable to expand source of tensor.pack");
139 rewriter.replaceOp(packOp, *expanded);
144 struct SimplifyUnPackToCollapseShape :
public OpRewritePattern<UnPackOp> {
147 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
148 Type newOperandType, ArrayAttr reassociation)
const {
149 if (operand.getType() == newOperandType)
151 return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
152 operand, reassociation);
156 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
157 UnPackOp unpackOp)
const {
158 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
160 return rewriter.notifyMatchFailure(
162 "expects outer_dims_perm is empty or an identity permutation");
165 RankedTensorType sourceType = unpackOp.getSourceType();
166 RankedTensorType destType = unpackOp.getDestType();
167 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
168 return rewriter.notifyMatchFailure(unpackOp,
"expects static shapes");
170 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
171 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
172 return rewriter.notifyMatchFailure(
173 unpackOp,
"expects unpacking on the innermost dimension");
179 LogicalResult matchAndRewrite(UnPackOp unpackOp,
180 PatternRewriter &rewriter)
const override {
181 RankedTensorType destType = unpackOp.getDestType();
182 if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
183 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
184 unpackOp.getStaticTiles())) &&
185 !unpackOp.isLikeUnPad()) {
189 RankedTensorType sourceType = unpackOp.getSourceType();
194 Value collapsed = insertCollapse(
195 rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
197 rewriter.replaceOp(unpackOp, collapsed);
204 struct FoldPadWithPackOp :
public OpRewritePattern<PackOp> {
207 LogicalResult matchAndRewrite(PackOp packOp,
208 PatternRewriter &rewriter)
const override {
209 auto padOp = packOp.getSource().getDefiningOp<PadOp>();
211 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
214 Value constantPaddingValue = padOp.getConstantPaddingValue();
215 if (!constantPaddingValue)
218 if (
auto paddingValue = packOp.getPaddingValue())
222 rewriter.replaceOpWithNewOp<PackOp>(
223 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
224 packOp.getMixedTiles(), constantPaddingValue,
225 packOp.getOuterDimsPerm());
232 struct FoldUnpackWithExtractSliceOp :
public OpRewritePattern<ExtractSliceOp> {
235 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
236 PatternRewriter &rewriter)
const override {
237 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
241 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
242 return rewriter.notifyMatchFailure(
243 sliceOp,
"rank-reduced folding is not supported");
247 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
248 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
249 return rewriter.notifyMatchFailure(
250 sliceOp,
"expects offsets to be 0s and strides to be 1s");
254 Type elementType = unpackOp.getDestType().getElementType();
255 Value output = rewriter.create<EmptyOp>(
256 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
257 rewriter.replaceOpWithNewOp<UnPackOp>(
258 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
259 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
271 static bool checkAndPermute(ArrayRef<int64_t> permutation,
272 ArrayRef<int64_t> inVec,
273 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
275 for (
unsigned int i = 0; i < rank; ++i) {
276 int64_t remappedPosition = permutation[i];
277 if (remappedPosition >= rank)
280 remappedPosition = inVec[remappedPosition];
281 resVec.push_back(remappedPosition);
289 struct FoldProducerPackWithConsumerLinalgTransposeOp
290 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
293 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
294 PatternRewriter &rewriter)
const override {
295 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
300 FailureOr<SmallVector<int64_t>> maybePerm =
301 getTransposeOpPermutation(linalgOp);
302 if (failed(maybePerm))
305 auto innerDimsPos = packOp.getInnerDimsPos();
306 auto mixedInnerTiles = packOp.getMixedTiles();
307 auto outerDimsPerm = packOp.getOuterDimsPerm();
308 auto transposePerm = maybePerm.value();
309 SmallVector<int64_t> newOuterDimsPermVec;
310 SmallVector<int64_t> newInnerDimsPosVec;
311 SmallVector<OpFoldResult> newMixedInnerTilesVec;
312 int64_t srcRank = packOp.getSourceRank();
314 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
316 return rewriter.notifyMatchFailure(
318 "Cannot fold in tensor.pack if a tile dimension was transposed "
319 "with a non-tile dimension in linalg.transpose.");
322 for (
unsigned int i = srcRank; i < transposePerm.size(); ++i) {
323 int64_t remappedPosition = transposePerm[i] - srcRank;
324 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
325 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
328 Value output = packOp.createDestinationTensor(
329 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
330 newInnerDimsPosVec, newOuterDimsPermVec);
332 rewriter.replaceOpWithNewOp<PackOp>(
333 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
334 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
342 struct FoldConsumerPackWithProducerLinalgTransposeOp
343 :
public OpRewritePattern<PackOp> {
346 LogicalResult matchAndRewrite(PackOp packOp,
347 PatternRewriter &rewriter)
const override {
348 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
352 FailureOr<SmallVector<int64_t>> maybePerm =
353 getTransposeOpPermutation(linalgOp);
354 if (failed(maybePerm))
357 auto transposePermutation = maybePerm.value();
358 auto outerDimsPerm = packOp.getOuterDimsPerm();
359 auto innerDimsPos = packOp.getInnerDimsPos();
360 SmallVector<int64_t> newInnerDimsPosVec;
361 SmallVector<int64_t> newOuterDimsPermVec =
362 llvm::to_vector(transposePermutation);
364 if (!outerDimsPerm.empty())
369 for (
auto dim : innerDimsPos)
370 newInnerDimsPosVec.push_back(transposePermutation[dim]);
372 Value output = packOp.createDestinationTensor(
373 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
374 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
376 rewriter.replaceOpWithNewOp<PackOp>(
377 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
378 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
386 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
387 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
390 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
391 PatternRewriter &rewriter)
const override {
392 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
397 FailureOr<SmallVector<int64_t>> maybePerm =
398 getTransposeOpPermutation(linalgOp);
399 if (failed(maybePerm))
402 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
403 auto innerDimsPos = unPackOp.getInnerDimsPos();
404 SmallVector<int64_t> newInnerDimsPosVec;
405 SmallVector<int64_t> newOuterDimsPermVec =
410 for (
auto dim : innerDimsPos)
411 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
413 if (!outerDimsPerm.empty())
417 rewriter.replaceOpWithNewOp<UnPackOp>(
418 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
419 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
427 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
428 :
public OpRewritePattern<UnPackOp> {
431 LogicalResult matchAndRewrite(UnPackOp unPackOp,
432 PatternRewriter &rewriter)
const override {
433 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
437 FailureOr<SmallVector<int64_t>> maybePerm =
438 getTransposeOpPermutation(linalgOp);
439 if (failed(maybePerm))
442 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
447 SmallVector<int64_t> inverseTransposePerm =
449 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
450 auto innerDimsPos = unPackOp.getInnerDimsPos();
451 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
452 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
453 SmallVector<int64_t> newOuterDimsPermVec;
454 SmallVector<int64_t> newInnerDimsPosVec;
455 SmallVector<OpFoldResult> newMixedInnerTilesVec;
456 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
457 newOuterDimsPermVec, destRank))
458 return rewriter.notifyMatchFailure(
460 "Cannot fold in tensor.unpack if a tile dimension was transposed "
461 "with a non-tile dimension in linalg.transpose.");
464 for (
unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
465 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
466 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
467 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
471 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
472 Value output = rewriter.create<tensor::EmptyOp>(
473 unPackOp->getLoc(), unpackOpResultDims[0], elemType);
475 rewriter.replaceOpWithNewOp<UnPackOp>(
476 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
477 newMixedInnerTilesVec, newOuterDimsPermVec);
485 patterns.
insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
486 FoldProducerPackWithConsumerLinalgTransposeOp,
487 FoldConsumerPackWithProducerLinalgTransposeOp,
488 FoldConsumerUnPackWithProducerLinalgTransposeOp,
489 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
494 patterns.
add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
@ Type
An inlay hint that is for a type annotation.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into tensor.pack and tensor.unpack operations.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type, when possible.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...