static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}
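// Illustrative example: getNumGtOneDims({1, 8, 1}) == 1; a dynamic dimension
// (ShapedType::kDynamic) is conservatively counted as greater than one.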
/// Returns success() if packing can only happen along a single dimension:
/// the non-packed domain has at most one non-unit dimension, and there is at
/// most one non-unit inner tile size.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dim");
  }
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tile");
  }
  return success();
}
/// If `linalgOp` represents a transpose (a linalg.transpose or an equivalent
/// permutation-only linalg.generic), return its permutation vector.
/// Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // Map each result of the output indexing map back to its position in the
  // input indexing map; that gives the transpose permutation.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}
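// Illustrative example: a single-statement linalg.generic whose input map is
// (d0, d1) -> (d0, d1) and whose output map is (d0, d1) -> (d1, d0) is
// recognized as a transpose with permutation [1, 0].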
/// Rewrites a tensor.pack whose only effect is introducing unit dimensions
/// into a tensor.expand_shape.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if the pack only packs along the innermost dimension
  /// and does not permute the outer dimensions.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};
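// Illustrative rewrite performed by this pattern (shapes are hypothetical):
//
//   %pack = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
//       into %dest : tensor<8xf32> -> tensor<1x8xf32>
//
// is replaced by a reshape that is roughly
//
//   %expanded = tensor.expand_shape %src [[0, 1]]
//       : tensor<8xf32> into tensor<1x8xf32>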
/// Rewrites a tensor.unpack whose only effect is dropping unit dimensions
/// into a tensor.collapse_shape.
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if the unpack only unpacks the innermost dimension
  /// and does not permute the outer dimensions.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};
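// Illustrative rewrite (shapes are hypothetical):
//
//   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
//       into %dest : tensor<1x8xf32> -> tensor<8xf32>
//
// becomes roughly
//
//   %collapsed = tensor.collapse_shape %src [[0, 1]]
//       : tensor<1x8xf32> into tensor<8xf32>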
/// Folds a tensor.pad into a consumer tensor.pack when the pad only adds
/// high padding with a constant value that is compatible with the pack's
/// padding value.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();
    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};
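// Illustrative fold: a tensor.pad that only adds high padding with a constant
// %cst, feeding a tensor.pack, collapses into a single tensor.pack of the
// un-padded source that uses %cst as its padding_value.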
/// Folds a tensor.extract_slice into a producer tensor.unpack by unpacking
/// directly into a destination of the sliced shape.
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() !=
        unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check that all offsets are 0s and all strides are 1s.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor with the sliced sizes and unpack
    // directly into it.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};
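// Illustrative fold: an extract_slice with all-zero offsets and unit strides
// applied to an unpack result is absorbed by re-creating the tensor.unpack
// directly into a tensor.empty of the sliced shape.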
/// Applies `permutation` to the first `rank` entries, remapping each position
/// through `inVec` when `inVec` is non-empty, and appends the results to
/// `resVec`. Returns false if any of the first `rank` permutation entries
/// refers to a position >= `rank`.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {
  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }
  return true;
}
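// Worked example (illustrative): permutation = {1, 0, 3, 2}, inVec = {5, 7},
// rank = 2 yields resVec = {7, 5} and returns true; permutation = {1, 3, 0, 2}
// with rank = 2 returns false because entry 3 exceeds the rank.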
/// Folds a producer tensor.pack into a consumer linalg.transpose by building
/// a single tensor.pack with permuted metadata.
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
    return success();
  }
};
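// Illustrative fold: linalg.transpose applied to a tensor.pack result is
// rewritten as a single tensor.pack of the original source whose
// outer_dims_perm, inner_dims_pos, and inner_tiles are permuted by the
// transpose permutation; the fold is rejected when the transpose mixes tile
// and non-tile dimensions.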
/// Folds a producer linalg.transpose into a consumer tensor.pack by composing
/// the transpose permutation into the pack's metadata.
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Remap the inner dims positions one by one; the input and output of the
    // transpose have the same rank.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
    return success();
  }
};
/// Folds a producer tensor.unpack into a consumer linalg.transpose by building
/// a single tensor.unpack with permuted metadata.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Remap the inner dims positions one by one; the input and output of the
    // transpose have the same rank.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op as the unpack destination.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
    return success();
  }
};
/// Folds a producer linalg.transpose into a consumer tensor.unpack by
/// composing the inverse transpose permutation into the unpack's metadata.
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims)))
      return failure();

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the inverse permutation for the tiled inner dimensions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Type elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);
    return success();
  }
};
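// Illustrative fold: a tensor.unpack whose source is produced by a
// linalg.transpose becomes a single tensor.unpack into a fresh tensor.empty;
// since the transpose runs before the unpack, its inverse permutation is what
// gets composed into outer_dims_perm and inner_dims_pos.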
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}
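// Minimal usage sketch (hypothetical pass body; assumes the greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(&getContext());
//   tensor::populateSimplifyPackAndUnpackPatterns(patterns);
//   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();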