20 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
21 return llvm::count_if(
22 shape, [](int64_t v) {
return ShapedType::isDynamic(v) || v > 1; });
30 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
31 ArrayRef<int64_t> srcShape,
32 ArrayRef<int64_t> innerPackTileSize) {
33 if (getNumGtOneDims(srcShape) > 1) {
34 return rewriter.notifyMatchFailure(
35 op,
"expects non-packed domain to have at most one non-unit dims");
39 if (getNumGtOneDims(innerPackTileSize) > 1) {
40 return rewriter.notifyMatchFailure(
41 op,
"expects at most one non-unit inner tiles");
48 static FailureOr<SmallVector<int64_t>>
49 getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
50 if (
auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
51 return SmallVector<int64_t>(transposeOp.getPermutation());
52 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
55 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
57 auto mapRange = linalgOp.getIndexingMapsArray();
58 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
59 mapRange.front() == mapRange.back()) {
62 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
64 AffineMap outMap = mapRange.back();
65 AffineMap inMap = mapRange.front();
68 return llvm::map_to_vector(outMap.getResults(),
69 [&](AffineExpr expr) -> int64_t {
70 return *inMap.getResultPosition(expr);
75 struct SimplifyPackToExpandShape :
public OpRewritePattern<PackOp> {
79 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
81 ArrayRef<ReassociationIndices> reassociation)
const {
82 if (operand.getType() == newOperandType)
85 .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
91 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
92 PackOp packOp)
const {
95 return rewriter.notifyMatchFailure(
97 "expects outer_dims_perm is empty or an identity permutation");
100 int64_t srcRank = packOp.getSourceRank();
101 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
102 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
103 return rewriter.notifyMatchFailure(
104 packOp,
"expects packing at the innermost dimension");
109 LogicalResult matchAndRewrite(PackOp packOp,
110 PatternRewriter &rewriter)
const override {
111 if (packOp.getPaddingValue())
112 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
114 RankedTensorType sourceType = packOp.getSourceType();
115 if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
116 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
117 packOp.getStaticTiles())) &&
118 !packOp.isLikePad()) {
122 RankedTensorType destType = packOp.getDestType();
127 FailureOr<Value> expanded =
128 insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
130 if (failed(expanded)) {
131 return rewriter.notifyMatchFailure(
132 packOp,
"unable to expand source of tensor.pack");
134 rewriter.replaceOp(packOp, *expanded);
139 struct SimplifyUnPackToCollapseShape :
public OpRewritePattern<UnPackOp> {
142 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
143 Type newOperandType, ArrayAttr reassociation)
const {
144 if (operand.getType() == newOperandType)
146 return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
147 operand, reassociation);
151 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
152 UnPackOp unpackOp)
const {
155 return rewriter.notifyMatchFailure(
157 "expects outer_dims_perm is empty or an identity permutation");
160 RankedTensorType sourceType = unpackOp.getSourceType();
161 RankedTensorType destType = unpackOp.getDestType();
162 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
163 return rewriter.notifyMatchFailure(unpackOp,
"expects static shapes");
165 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
166 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
167 return rewriter.notifyMatchFailure(
168 unpackOp,
"expects unpacking on the innermost dimension");
174 LogicalResult matchAndRewrite(UnPackOp unpackOp,
175 PatternRewriter &rewriter)
const override {
176 RankedTensorType destType = unpackOp.getDestType();
177 if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
178 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
179 unpackOp.getStaticTiles())) &&
180 !unpackOp.isLikeUnPad()) {
184 RankedTensorType sourceType = unpackOp.getSourceType();
189 Value collapsed = insertCollapse(
190 rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
192 rewriter.replaceOp(unpackOp, collapsed);
199 struct FoldPadWithPackOp :
public OpRewritePattern<PackOp> {
202 LogicalResult matchAndRewrite(PackOp packOp,
203 PatternRewriter &rewriter)
const override {
204 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
206 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
209 Value constantPaddingValue = padOp.getConstantPaddingValue();
210 if (!constantPaddingValue)
213 if (
auto paddingValue = packOp.getPaddingValue())
217 rewriter.replaceOpWithNewOp<PackOp>(
218 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
219 packOp.getMixedTiles(), constantPaddingValue,
220 packOp.getOuterDimsPerm());
227 struct FoldUnpackWithExtractSliceOp
228 :
public OpRewritePattern<tensor::ExtractSliceOp> {
231 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
232 PatternRewriter &rewriter)
const override {
233 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
237 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
238 return rewriter.notifyMatchFailure(
239 sliceOp,
"rank-reduced folding is not supported");
245 return rewriter.notifyMatchFailure(
246 sliceOp,
"expects offsets to be 0s and strides to be 1s");
250 Type elementType = unpackOp.getDestType().getElementType();
251 Value output = rewriter.create<tensor::EmptyOp>(
252 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
253 rewriter.replaceOpWithNewOp<UnPackOp>(
254 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
255 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
267 static bool checkAndPermute(ArrayRef<int64_t> permutation,
268 ArrayRef<int64_t> inVec,
269 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
271 for (
unsigned int i = 0; i < rank; ++i) {
272 int64_t remappedPosition = permutation[i];
273 if (remappedPosition >= rank)
276 remappedPosition = inVec[remappedPosition];
277 resVec.push_back(remappedPosition);
285 struct FoldProducerPackWithConsumerLinalgTransposeOp
286 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
289 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
290 PatternRewriter &rewriter)
const override {
291 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
296 FailureOr<SmallVector<int64_t>> maybePerm =
297 getTransposeOpPermutation(linalgOp);
298 if (failed(maybePerm))
302 auto mixedInnerTiles = packOp.getMixedTiles();
304 auto transposePerm = maybePerm.value();
305 SmallVector<int64_t> newOuterDimsPermVec;
306 SmallVector<int64_t> newInnerDimsPosVec;
307 SmallVector<OpFoldResult> newMixedInnerTilesVec;
308 int64_t srcRank = packOp.getSourceRank();
310 if (!checkAndPermute(transposePerm,
outerDimsPerm, newOuterDimsPermVec,
312 return rewriter.notifyMatchFailure(
314 "Cannot fold in tensor.pack if a tile dimension was transposed "
315 "with a non-tile dimension in linalg.transpose.");
318 for (
unsigned int i = srcRank; i < transposePerm.size(); ++i) {
319 int64_t remappedPosition = transposePerm[i] - srcRank;
320 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
321 newInnerDimsPosVec.push_back(
innerDimsPos[remappedPosition]);
324 Value output = packOp.createDestinationTensor(
325 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
326 newInnerDimsPosVec, newOuterDimsPermVec);
328 rewriter.replaceOpWithNewOp<PackOp>(
329 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
330 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
338 struct FoldConsumerPackWithProducerLinalgTransposeOp
339 :
public OpRewritePattern<PackOp> {
342 LogicalResult matchAndRewrite(PackOp packOp,
343 PatternRewriter &rewriter)
const override {
344 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
348 FailureOr<SmallVector<int64_t>> maybePerm =
349 getTransposeOpPermutation(linalgOp);
350 if (failed(maybePerm))
353 auto transposePermutation = maybePerm.value();
356 SmallVector<int64_t> newInnerDimsPosVec;
357 SmallVector<int64_t> newOuterDimsPermVec =
358 llvm::to_vector(transposePermutation);
366 newInnerDimsPosVec.push_back(transposePermutation[dim]);
368 Value output = packOp.createDestinationTensor(
369 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
370 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
372 rewriter.replaceOpWithNewOp<PackOp>(
373 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
374 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
382 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
386 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
387 PatternRewriter &rewriter)
const override {
388 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
393 FailureOr<SmallVector<int64_t>> maybePerm =
394 getTransposeOpPermutation(linalgOp);
395 if (failed(maybePerm))
400 SmallVector<int64_t> newInnerDimsPosVec;
401 SmallVector<int64_t> newOuterDimsPermVec =
407 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
413 rewriter.replaceOpWithNewOp<UnPackOp>(
414 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
415 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
423 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424 :
public OpRewritePattern<UnPackOp> {
427 LogicalResult matchAndRewrite(UnPackOp unPackOp,
428 PatternRewriter &rewriter)
const override {
429 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
433 FailureOr<SmallVector<int64_t>> maybePerm =
434 getTransposeOpPermutation(linalgOp);
435 if (failed(maybePerm))
438 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
443 SmallVector<int64_t> inverseTransposePerm =
447 int64_t destRank = unPackOp.getSourceRank() -
innerDimsPos.size();
448 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
449 SmallVector<int64_t> newOuterDimsPermVec;
450 SmallVector<int64_t> newInnerDimsPosVec;
451 SmallVector<OpFoldResult> newMixedInnerTilesVec;
453 newOuterDimsPermVec, destRank))
454 return rewriter.notifyMatchFailure(
456 "Cannot fold in tensor.unpack if a tile dimension was transposed "
457 "with a non-tile dimension in linalg.transpose.");
460 for (
unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
461 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
462 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
463 newInnerDimsPosVec.push_back(
innerDimsPos[remappedPosition]);
467 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
468 Value output = rewriter.create<tensor::EmptyOp>(
469 unPackOp->getLoc(), unpackOpResultDims[0], elemType);
471 rewriter.replaceOpWithNewOp<UnPackOp>(
472 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
473 newMixedInnerTilesVec, newOuterDimsPermVec);
481 struct FoldEmptyTensorWithPackOp :
public OpRewritePattern<PackOp> {
484 LogicalResult matchAndRewrite(PackOp packOp,
485 PatternRewriter &rewriter)
const override {
487 auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
493 if (packOp.getPaddingValue())
494 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
497 rewriter.replaceOp(packOp, packOp.getDest());
505 struct FoldEmptyTensorWithUnPackOp :
public OpRewritePattern<UnPackOp> {
508 LogicalResult matchAndRewrite(UnPackOp unPackOp,
509 PatternRewriter &rewriter)
const override {
511 auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
516 rewriter.replaceOp(unPackOp, unPackOp.getDest());
525 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526 FoldProducerPackWithConsumerLinalgTransposeOp,
527 FoldConsumerPackWithProducerLinalgTransposeOp,
528 FoldConsumerUnPackWithProducerLinalgTransposeOp,
529 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
534 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
540 patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
SmallVector< int64_t > outerDimsPerm
SmallVector< int64_t > innerDimsPos
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
@ Type
An inlay hint that for a type annotation.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
const FrozenRewritePatternSet & patterns
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...