//===- PackAndUnpackPatterns.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace linalg {
namespace {

/// Returns the number of shape sizes that are either dynamic or greater than 1.
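/// For example (illustrative): {1, 4, 1} has one such size, {2, 3} has two,
/// and a dynamic size always counts.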
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if at most one dimension size in the non-packed domain
/// is greater than 1 and packing only happens on that dimension.
/// Note: this method should only be used by the pack/unpack to reshape
/// conversion. It assumes that a non-unit inner tile size must be used by the
/// non-unit dimension.
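/// Illustrative example: srcShape {1, 1, 32} with innerPackTileSize {16}
/// satisfies both checks, whereas srcShape {8, 32} does not because two of the
/// source dimensions are greater than 1.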
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dim");
  }
  // The non-unit inner tile size must be used by the non-unit dimension. If
  // not, it will fail on getting reassociation maps.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tile");
  }
  return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
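// Illustrative example: a rank-2 elementwise linalg.generic whose only body op
// is the yield, with input map (d0, d1) -> (d0, d1) and output map
// (d0, d1) -> (d1, d0), is recognized as a transpose with permutation [1, 0].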
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor can be expressed as an expand_shape op.
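///
/// Illustrative example (shapes chosen for illustration, syntax abridged):
/// packing tensor<64xf32> with inner_tiles = [8] into tensor<8x8xf32> is
/// equivalent to a tensor.expand_shape with reassociation [[0, 1]].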
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

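/// Unpacking on the innermost dimension, or unpacking a tensor that is
/// effectively one-dimensional, can be expressed as a collapse_shape op.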
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` when the pad op has zero low padding and
/// either both ops use the same padding value or `pack` has no padding value.
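///
/// Illustrative example: a tensor.pad that pads the innermost dimension from
/// 60 to 64 with a constant %cst (a hypothetical value), followed by a pack
/// with inner_tiles = [8] and no padding value, folds into a single pack of
/// the original source with padding_value(%cst).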
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
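///
/// Illustrative example: an unpack producing tensor<128x256xf32> followed by a
/// tensor.extract_slice with zero offsets, unit strides, and sizes [127, 250]
/// is rewritten to unpack directly into a new tensor.empty of size 127x250.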
struct FoldUnpackWithExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check that all offsets are zeros and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};

// Applies `permutation` to `inVec` and stores the result in `resVec`.
// `inVec` may be empty; in that case the permutation values are used directly
// (a one-to-one mapping). `rank` bounds the permutation: if any value in
// permutation[:rank] is greater than or equal to `rank`, return false.
// For example, permutation {1, 0, 3, 2} with rank 2 is allowed since the
// values in permutation[:rank] are all less than 2, whereas permutation
// {1, 3, 0, 2} is not allowed since `3` exceeds the rank in that range.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {

  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}

/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
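///
/// Illustrative example: with a rank-2 pack source, inner_dims_pos = [0, 1],
/// inner_tiles = [8, 16], and no outer_dims_perm, a consumer linalg.transpose
/// with permutation [1, 0, 3, 2] on the packed tensor folds into a single pack
/// with outer_dims_perm = [1, 0], inner_dims_pos = [1, 0], and
/// inner_tiles = [16, 8].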
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
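///
/// Illustrative example: a producer linalg.transpose with permutation [1, 0]
/// feeding a pack with inner_dims_pos = [0] and no outer_dims_perm folds into
/// a pack of the transpose's input with outer_dims_perm = [1, 0] and
/// inner_dims_pos = [1].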
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
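///
/// Illustrative example: an unpack with inner_dims_pos = [0] and no
/// outer_dims_perm followed by a linalg.transpose with permutation [1, 0]
/// folds into an unpack with outer_dims_perm = [1, 0] and inner_dims_pos = [1]
/// that writes directly into the transpose's init operand.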
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    auto elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};

/// tensor.empty does not define any tensor contents, so an unpadded pack
/// can be folded away.
struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // Check for a tensor.empty source.
    auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
    if (!emptyOp)
      return failure();

    // Check for padding.
    // Packing with padding cannot simply be removed.
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    // Replace the pack directly with its destination.
    rewriter.replaceOp(packOp, packOp.getDest());

    return success();
  }
};

/// tensor.empty does not define any tensor contents, so an unpack
/// can be folded away.
struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    // Check for a tensor.empty source.
    auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
    if (!emptyOp)
      return failure();

    // Replace the unpack directly with its destination.
    rewriter.replaceOp(unPackOp, unPackOp.getDest());

    return success();
  }
};

} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

void populateFoldPackUnpackIntoTensorEmptyPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
      patterns.getContext());
}
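
// A minimal usage sketch (illustrative, not part of the original file): the
// populate entry points above are typically combined inside a pass and handed
// to the greedy rewrite driver. This assumes applyPatternsGreedily from
// "mlir/Transforms/GreedyPatternRewriteDriver.h" and a pass rooted at `op`:
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateSimplifyPackAndUnpackPatterns(patterns);
//   linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();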
543 
544 } // namespace linalg
545 } // namespace mlir