/// Predicate over a list of OpFoldResults and an integer `value`.
/// NOTE(review): the body is not visible in this excerpt. Judging from the
/// name and from its use for the "offsets to be 0s and strides to be 1s"
/// check, it presumably returns true only when every entry of `ofrs` is a
/// constant integer equal to `value` — confirm against the full definition.
19 static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
/// Returns how many entries of `shape` are either dynamic or strictly greater
/// than 1.
25 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
26 return llvm::count_if(
27 shape, [](int64_t v) {
// Dynamic extents are counted alongside extents greater than 1.
return ShapedType::isDynamic(v) || v > 1; });
/// Checks that both `srcShape` and `innerPackTileSize` contain at most one
/// entry that is dynamic or greater than 1 (as counted by getNumGtOneDims).
/// When either check is violated, a match failure is reported on `op` through
/// `rewriter` with a message naming the violated expectation.
/// NOTE(review): the success return is not visible in this excerpt.
35 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
36 ArrayRef<int64_t> srcShape,
37 ArrayRef<int64_t> innerPackTileSize) {
// More than one non-unit (or dynamic) dimension in the non-packed shape.
38 if (getNumGtOneDims(srcShape) > 1) {
39 return rewriter.notifyMatchFailure(
40 op,
"expects non-packed domain to have at most one non-unit dims");
// More than one non-unit (or dynamic) inner tile size.
44 if (getNumGtOneDims(innerPackTileSize) > 1) {
45 return rewriter.notifyMatchFailure(
46 op,
"expects at most one non-unit inner tiles");
/// Rewrite pattern on PackOp. From the visible code: it reports a match
/// failure when the pack carries a padding value, evaluates the
/// isPackOnInnerMostDim and isPackOn1D checks, and finally replaces the pack
/// with the value returned by insertExpand.
/// NOTE(review): several statements are elided from this excerpt (for example
/// how the reassociation is computed and what happens when both checks fail);
/// confirm those details against the full file.
52 struct SimplifyPackToExpandShape :
public OpRewritePattern<PackOp> {
/// Creates a tensor::ExpandShapeOp at `loc` that produces `newOperandType`
/// from `operand`.
/// NOTE(review): the statement guarded by the type-equality check is not
/// visible; presumably `operand` is returned as-is when the types already
/// match — confirm.
55 Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
56 Type newOperandType, ArrayAttr reassociation)
const {
57 if (operand.getType() == newOperandType)
59 return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
/// Reports a match failure on `packOp` when the outer_dims_perm expectation
/// quoted in the message below is violated (the condition itself is elided),
/// or when the pack does not have exactly one inner dimension position that
/// is the last source dimension.
64 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
65 PackOp packOp)
const {
66 auto outerDimsPerm = packOp.getOuterDimsPerm();
68 return rewriter.notifyMatchFailure(
70 "expects outer_dims_perm is empty or an identity permutation");
73 int64_t srcRank = packOp.getSourceRank();
74 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
// Exactly one packed dimension, and it must be index srcRank - 1.
75 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
76 return rewriter.notifyMatchFailure(
77 packOp,
"expects packing at the innermost dimension");
/// Entry point of the pattern; see the struct-level comment for the visible
/// sequence of checks.
82 LogicalResult matchAndRewrite(PackOp packOp,
83 PatternRewriter &rewriter)
const override {
// Packs with a padding value are not handled by this pattern.
84 if (packOp.getPaddingValue())
85 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
87 RankedTensorType sourceType = packOp.getSourceType();
// The guarded body (elided) is entered only when BOTH checks fail, so
// passing either check is sufficient to continue.
88 if (
failed(isPackOnInnerMostDim(rewriter, packOp)) &&
89 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
90 packOp.getStaticTiles()))) {
94 RankedTensorType destType = packOp.getDestType();
// Reshape the pack source to the destination type and use it in place of
// the pack result.
99 Value expanded = insertExpand(
100 rewriter, packOp.getLoc(), packOp.getSource(), destType,
102 rewriter.replaceOp(packOp, expanded);
/// Rewrite pattern on UnPackOp. From the visible code: it evaluates the
/// isUnpackOnInnerMostDim and isPackOn1D checks and then replaces the unpack
/// with the value returned by insertCollapse.
/// NOTE(review): several statements are elided from this excerpt (for example
/// how the reassociation is computed and what happens when both checks fail);
/// confirm those details against the full file.
107 struct SimplifyUnPackToCollapseShape :
public OpRewritePattern<UnPackOp> {
/// Creates a tensor::CollapseShapeOp at `loc` that produces `newOperandType`
/// from `operand` using `reassociation`.
/// NOTE(review): the statement guarded by the type-equality check is not
/// visible; presumably `operand` is returned as-is when the types already
/// match — confirm.
110 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
111 Type newOperandType, ArrayAttr reassociation)
const {
112 if (operand.getType() == newOperandType)
114 return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
115 operand, reassociation);
/// Reports a match failure on `unpackOp` when the outer_dims_perm expectation
/// quoted in the message below is violated (the condition itself is elided),
/// when the source or destination type is not statically shaped, or when the
/// unpack does not have exactly one inner dimension position that is the last
/// destination dimension.
119 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
120 UnPackOp unpackOp)
const {
121 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
123 return rewriter.notifyMatchFailure(
125 "expects outer_dims_perm is empty or an identity permutation");
128 RankedTensorType sourceType = unpackOp.getSourceType();
129 RankedTensorType destType = unpackOp.getDestType();
// Both the packed source and the unpacked destination must be static.
130 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
131 return rewriter.notifyMatchFailure(unpackOp,
"expects static shapes");
133 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
// Exactly one unpacked dimension, and it must be index destRank - 1.
134 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
135 return rewriter.notifyMatchFailure(
136 unpackOp,
"expects unpacking on the innermost dimension");
/// Entry point of the pattern; see the struct-level comment for the visible
/// sequence of checks.
142 LogicalResult matchAndRewrite(UnPackOp unpackOp,
143 PatternRewriter &rewriter)
const override {
144 RankedTensorType destType = unpackOp.getDestType();
// The guarded body (elided) is entered only when BOTH checks fail, so
// passing either check is sufficient to continue. The 1-D check is applied
// to the destination (unpacked) shape.
145 if (
failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
146 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
147 unpackOp.getStaticTiles()))) {
151 RankedTensorType sourceType = unpackOp.getSourceType();
// Reshape the unpack source to the destination type and use it in place of
// the unpack result.
156 Value collapsed = insertCollapse(
157 rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
159 rewriter.replaceOp(unpackOp, collapsed);
/// Rewrite pattern on PackOp whose source is produced by a PadOp. On the
/// visible success path the pack is replaced by a new PackOp that reads the
/// pad's source directly and uses the pad's constant padding value as the
/// pack padding value; destination, inner dims, tiles and outer permutation
/// are taken from the original pack.
/// NOTE(review): the statements guarded by the visible conditions are elided
/// from this excerpt; confirm the bail-out behaviour against the full file.
166 struct FoldPadWithPackOp :
public OpRewritePattern<PackOp> {
169 LogicalResult matchAndRewrite(PackOp packOp,
170 PatternRewriter &rewriter)
const override {
171 auto padOp = packOp.getSource().getDefiningOp<PadOp>();
// Condition covers: no producing PadOp, a pad marked nofold, or a pad with
// non-zero low padding.
173 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
// The pad must have a constant padding value.
176 Value constantPaddingValue = padOp.getConstantPaddingValue();
177 if (!constantPaddingValue)
// NOTE(review): the handling of a pack that already has its own padding
// value is elided here.
180 if (
auto paddingValue = packOp.getPaddingValue())
184 rewriter.replaceOpWithNewOp<PackOp>(
185 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
186 packOp.getMixedTiles(), constantPaddingValue,
187 packOp.getOuterDimsPerm());
/// Rewrite pattern on ExtractSliceOp whose source is produced by an UnPackOp.
/// On the visible success path the slice is replaced by a new UnPackOp that
/// unpacks the original unpack's source into a fresh EmptyOp sized by the
/// slice's mixed sizes.
/// Match failures are reported for rank-reducing slices and for slices whose
/// offsets are not all 0 or whose strides are not all 1.
/// NOTE(review): the null check on `unpackOp` is not visible in this excerpt.
194 struct FoldUnpackWithExtractSliceOp :
public OpRewritePattern<ExtractSliceOp> {
197 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
198 PatternRewriter &rewriter)
const override {
199 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
// The slice result must keep the rank of the unpack destination.
203 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
204 return rewriter.notifyMatchFailure(
205 sliceOp,
"rank-reduced folding is not supported");
// Only slices that start at offset 0 with unit strides are folded.
209 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
210 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
211 return rewriter.notifyMatchFailure(
212 sliceOp,
"expects offsets to be 0s and strides to be 1s");
// New destination: an empty tensor with the slice sizes and the element type
// of the unpack destination.
216 Type elementType = unpackOp.getDestType().getElementType();
217 Value output = rewriter.create<EmptyOp>(
218 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
219 rewriter.replaceOpWithNewOp<UnPackOp>(
220 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
221 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
/// Appends `rank` entries to `resVec`, one for each of the first `rank`
/// entries of `permutation`. When `inVec` is empty the permutation entry is
/// appended as-is; otherwise the entry is used as an index into `inVec` and
/// the looked-up value is appended.
/// NOTE(review): the return statements are not visible in this excerpt.
/// Presumably the function returns false when a permutation entry is
/// `>= rank` (only checked when `inVec` is non-empty) and true otherwise —
/// confirm against the full definition.
233 static bool checkAndPermute(ArrayRef<int64_t> permutation,
234 ArrayRef<int64_t> inVec,
235 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
237 for (
unsigned int i = 0; i < rank; ++i) {
238 int64_t remappedPosition = permutation[i];
// Compose with `inVec` only when one is provided.
240 if (!inVec.empty()) {
// An entry at or beyond `rank` is treated specially (body elided).
241 if (remappedPosition >= rank) {
244 remappedPosition = inVec[remappedPosition];
247 resVec.push_back(remappedPosition);
/// Rewrite pattern on linalg::TransposeOp whose first operand is produced by
/// a PackOp. On the visible success path the transpose is replaced by a single
/// new PackOp on the original pack source, with outer permutation, inner
/// dimension positions and inner tiles recomputed from the transpose
/// permutation.
/// NOTE(review): the null check on `packOp` and parts of the checkAndPermute
/// call are elided from this excerpt.
255 struct FoldProducerPackWithConsumerLinalgTransposeOp
256 :
public OpRewritePattern<linalg::TransposeOp> {
259 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
260 PatternRewriter &rewriter)
const override {
261 auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
266 auto innerDimsPos = packOp.getInnerDimsPos();
267 auto mixedInnerTiles = packOp.getMixedTiles();
268 auto outerDimsPerm = packOp.getOuterDimsPerm();
269 auto transposePerm = transposeOp.getPermutation();
270 SmallVector<int64_t> newOuterDimsPermVec;
271 SmallVector<int64_t> newInnerDimsPosVec;
272 SmallVector<OpFoldResult> newMixedInnerTilesVec;
273 int64_t srcRank = packOp.getSourceRank();
// Compute the new outer permutation; a false result aborts the fold with the
// message below.
275 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
277 return rewriter.notifyMatchFailure(
279 "Cannot fold in tensor.pack if a tile dimension was transposed "
280 "with a non-tile dimension in linalg.transpose.");
// Permutation entries from index srcRank onward are rebased by srcRank and
// used to reorder the inner tiles and inner dimension positions.
283 for (
unsigned int i = srcRank; i < transposePerm.size(); ++i) {
284 int64_t remappedPosition = transposePerm[i] - srcRank;
285 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
286 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
// Build a destination tensor for the recomputed layout, then swap the
// transpose for the new pack (the original padding value is kept).
289 Value output = packOp.createDestinationTensor(
290 rewriter, transposeOp.getLoc(), packOp.getSource(),
291 newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
293 rewriter.replaceOpWithNewOp<PackOp>(
294 transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
295 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
/// Rewrite pattern on PackOp whose source is produced by a
/// linalg::TransposeOp. On the visible success path the pack is replaced by a
/// new PackOp that reads the transpose's first operand directly, with inner
/// dimension positions remapped through the transpose permutation and the
/// original tiles and padding value kept.
/// NOTE(review): the null check on `transposeOp` and the handling of a
/// non-empty `outerDimsPerm` are elided from this excerpt.
303 struct FoldConsumerPackWithProducerLinalgTransposeOp
304 :
public OpRewritePattern<PackOp> {
307 LogicalResult matchAndRewrite(PackOp packOp,
308 PatternRewriter &rewriter)
const override {
309 auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
314 auto transposePermutation = transposeOp.getPermutation();
315 auto outerDimsPerm = packOp.getOuterDimsPerm();
316 auto innerDimsPos = packOp.getInnerDimsPos();
317 SmallVector<int64_t> newInnerDimsPosVec;
// The new outer permutation starts out as a copy of the transpose
// permutation.
318 SmallVector<int64_t> newOuterDimsPermVec =
319 llvm::to_vector(transposePermutation);
// NOTE(review): statement guarded by this condition is not visible.
321 if (!outerDimsPerm.empty())
// Each inner dimension position is looked up in the transpose permutation.
326 for (
auto dim : innerDimsPos)
327 newInnerDimsPosVec.push_back(transposePermutation[dim]);
329 Value output = packOp.createDestinationTensor(
330 rewriter, packOp.getLoc(), transposeOp.getOperand(0),
331 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
333 rewriter.replaceOpWithNewOp<PackOp>(
334 packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
335 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
/// Rewrite pattern on linalg::TransposeOp whose first operand is produced by
/// an UnPackOp. On the visible success path the transpose is replaced by a
/// new UnPackOp on the original unpack source, with inner dimension positions
/// remapped through the transpose permutation and the original tiles kept.
/// NOTE(review): the null check on `unPackOp` and the handling of a non-empty
/// `outerDimsPerm` are elided from this excerpt.
343 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
344 :
public OpRewritePattern<linalg::TransposeOp> {
347 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
348 PatternRewriter &rewriter)
const override {
349 auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
354 auto transposePermutation = transposeOp.getPermutation();
355 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
356 auto innerDimsPos = unPackOp.getInnerDimsPos();
357 SmallVector<int64_t> newInnerDimsPosVec;
// The new outer permutation starts out as a copy of the transpose
// permutation.
358 SmallVector<int64_t> newOuterDimsPermVec =
359 llvm::to_vector(transposePermutation);
// NOTE(review): statement guarded by this condition is not visible.
361 if (!outerDimsPerm.empty())
// Each inner dimension position is looked up in the transpose permutation.
366 for (
auto dim : innerDimsPos)
367 newInnerDimsPosVec.push_back(transposePermutation[dim]);
369 Value output = unPackOp.createDestinationTensor(
370 rewriter, transposeOp.getLoc(), unPackOp.getSource(),
371 unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
373 rewriter.replaceOpWithNewOp<UnPackOp>(
374 transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec,
375 unPackOp.getMixedTiles(), newOuterDimsPermVec);
/// Rewrite pattern on UnPackOp whose source is produced by a
/// linalg::TransposeOp. On the visible success path the unpack is replaced by
/// a new UnPackOp that reads the transpose's first operand directly, with
/// outer permutation, inner dimension positions and inner tiles recomputed
/// from the transpose permutation.
/// NOTE(review): the declaration that binds `transposeOp` and its null check
/// are only partially visible in this excerpt.
383 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
384 :
public OpRewritePattern<UnPackOp> {
387 LogicalResult matchAndRewrite(UnPackOp unPackOp,
388 PatternRewriter &rewriter)
const override {
390 unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
395 auto transposePermutation = transposeOp.getPermutation();
396 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
397 auto innerDimsPos = unPackOp.getInnerDimsPos();
// Destination rank = packed source rank minus the number of inner (tile)
// dimensions.
398 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
399 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
400 SmallVector<int64_t> newOuterDimsPermVec;
401 SmallVector<int64_t> newInnerDimsPosVec;
402 SmallVector<OpFoldResult> newMixedInnerTilesVec;
// Compute the new outer permutation; a false result aborts the fold with the
// message below.
404 if (!checkAndPermute(transposePermutation, outerDimsPerm,
405 newOuterDimsPermVec, destRank))
406 return rewriter.notifyMatchFailure(
408 "Cannot fold in tensor.unpack if a tile dimension was transposed "
409 "with a non-tile dimension in linalg.transpose.");
// Permutation entries from index destRank onward are rebased by destRank and
// used to reorder the inner tiles and inner dimension positions.
412 for (
unsigned int i = destRank; i < transposePermutation.size(); ++i) {
413 int64_t remappedPosition = transposePermutation[i] - destRank;
414 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
415 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
418 Value output = unPackOp.createDestinationTensor(
419 rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
420 newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
422 rewriter.replaceOpWithNewOp<UnPackOp>(
423 unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
424 newMixedInnerTilesVec, newOuterDimsPermVec);
// Registers the six fold patterns listed below in `patterns`.
// NOTE(review): the enclosing function header and the constructor arguments
// of this call are not visible in this excerpt.
432 patterns.
insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
433 FoldProducerPackWithConsumerLinalgTransposeOp,
434 FoldConsumerPackWithProducerLinalgTransposeOp,
435 FoldConsumerUnPackWithProducerLinalgTransposeOp,
436 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
// Registers the two simplification patterns listed below in `patterns`.
// NOTE(review): the enclosing function header and the constructor arguments
// of this call are not visible in this excerpt.
441 patterns.
add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
@ Type
An inlay hint that is for a type annotation.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into tensor.pack and tensor.unpack operations respectively.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is a constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type when possible.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.