//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "vector-drop-unit-dim"

using namespace mlir;
using namespace mlir::vector;

// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
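// For example (illustrative shapes):
//   vector<1x1x8xf32>  -> vector<8xf32>
//   vector<[1]x4xf32>  -> vector<[1]x4xf32> (scalable unit dims are kept)
//   vector<1x1xf32>    -> vector<1xf32>     (at least one dim is kept)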
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }

  // Make sure we have at least 1 dimension per vector type requirements.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Return a SmallVector of size `rank` containing all zeros.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}

namespace {

// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.broadcast.
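// For example (illustrative IR with hypothetical values and shapes):
//
//   %0 = vector.extract_strided_slice %src
//          {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]}
//          : vector<1x8xf32> to vector<1x4xf32>
//
// becomes
//
//   %e = vector.extract %src[0] : vector<8xf32> from vector<1x8xf32>
//   %s = vector.extract_strided_slice %e
//          {offsets = [0], sizes = [4], strides = [1]}
//          : vector<8xf32> to vector<4xf32>
//   %0 = vector.broadcast %s : vector<4xf32> to vector<1x4xf32>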
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // vector.extract_strided_slice requires the input and output vector to
    // have the same rank. Here we drop leading one dimensions from the input
    // vector type to make sure we don't cause a mismatch.
    VectorType oldSrcType = extractOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);

    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType(),
                        oldDstType.getScalableDims().drop_front(dropCount));

    Location loc = extractOp.getLoc();

    Value newSrcVector = vector::ExtractOp::create(
        rewriter, loc, extractOp.getSource(), splatZero(dropCount));

    // The offsets/sizes/strides attributes can have fewer elements than the
    // input vector's rank; they apply to the leading dimensions.
    auto newOffsets = rewriter.getArrayAttr(
        extractOp.getOffsets().getValue().drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(
        extractOp.getSizes().getValue().drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(
        extractOp.getStrides().getValue().drop_front(dropCount));

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
        newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                     newExtractOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert_strided_slice's vector
// inputs by inserting vector.broadcast.
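// For example (illustrative IR with hypothetical values and shapes):
//
//   %0 = vector.insert_strided_slice %src, %dst
//          {offsets = [0, 0, 2], strides = [1, 1]}
//          : vector<1x4xf32> into vector<1x1x8xf32>
//
// becomes
//
//   %s = vector.extract %src[0] : vector<4xf32> from vector<1x4xf32>
//   %d = vector.extract %dst[0, 0] : vector<8xf32> from vector<1x1x8xf32>
//   %i = vector.insert_strided_slice %s, %d {offsets = [2], strides = [1]}
//          : vector<4xf32> into vector<8xf32>
//   %0 = vector.broadcast %i : vector<8xf32> to vector<1x1x8xf32>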
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = vector::ExtractOp::create(
        rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    Value newDstVector = vector::ExtractOp::create(
        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));

    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = vector::InsertStridedSliceOp::create(
        rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
        newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
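// For example (illustrative IR with hypothetical values and shapes):
//
//   %0 = vector.insert %v, %dst [0, 1] : vector<4xf32> into vector<1x2x4xf32>
//
// becomes
//
//   %d = vector.extract %dst[0] : vector<2x4xf32> from vector<1x2x4xf32>
//   %i = vector.insert %v, %d [1] : vector<4xf32> into vector<2x4xf32>
//   %0 = vector.broadcast %i : vector<2x4xf32> to vector<1x2x4xf32>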
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    Type oldSrcType = insertOp.getValueToStoreType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getValueToStore();
    if (oldSrcRank != 0) {
      newSrcVector = vector::ExtractOp::create(
          rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    }
    Value newDstVector = vector::ExtractOp::create(
        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank is computed in two steps: (1) if the destination
    // type has leading unit dims, trim the position array accordingly; then
    // (2) if the source type also has leading unit dims, append zeros to the
    // position array accordingly.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
                                                newDstVector, newPosition);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

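// Trims leading unit dims from a transfer op mask. The new mask type is
// inferred from the trimmed vector type `newType` and permutation map
// `newMap`. When the new mask type is broadcastable to the old one, the mask
// is trimmed with vector.extract; otherwise vector.shape_cast is used.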
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the type of the new mask from the new map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);

  // If the new mask is broadcastable to the old result type, we can safely
  // use a `vector.extract` to get the new mask. Otherwise the best we can
  // do is shape cast.
  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
      BroadcastableToResult::Success) {
    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
    return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
  }
  return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
}

// Turns a vector.transfer_read that produces a vector with leading unit
// dimensions into a vector.transfer_read without those dimensions followed by
// a vector.broadcast back to the original type.
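// For example (illustrative IR with hypothetical values and shapes):
//
//   %0 = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//          : memref<4x8xf32>, vector<1x8xf32>
//
// becomes
//
//   %r = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true]}
//          : memref<4x8xf32>, vector<8xf32>
//   %0 = vector.broadcast %r : vector<8xf32> to vector<1x8xf32>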
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    // TODO: Support the 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(read.getBase().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    Value mask = Value();
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    auto newRead = vector::TransferReadOp::create(
        rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

    return success();
  }
};

// Turns a vector.transfer_write of a vector with leading unit dimensions into
// a vector.extract that drops those dimensions followed by a
// vector.transfer_write of the trimmed vector.
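// For example (illustrative IR with hypothetical values and shapes):
//
//   vector.transfer_write %v, %mem[%c0, %c0] {in_bounds = [true, true]}
//     : vector<1x8xf32>, memref<4x8xf32>
//
// becomes
//
//   %e = vector.extract %v[0] : vector<8xf32> from vector<1x8xf32>
//   vector.transfer_write %e, %mem[%c0, %c0] {in_bounds = [true]}
//     : vector<8xf32>, memref<4x8xf32>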
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
      return failure();
    // TODO: Support the 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    int64_t dropDim = oldType.getRank() - newType.getRank();

    AffineMap oldMap = write.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (write.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.getInBoundsAttr().getValue().take_back(newType.getRank()));

    auto newVector = vector::ExtractOp::create(
        rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));

    if (write.getMask()) {
      VectorType maskType = write.getMaskType();
      Value newMask = dropUnitDimsFromMask(
          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
          write, newVector, write.getBase(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getBase(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};

} // namespace

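/// Cast away the leading unit dim, if it exists, for the given contract op.
/// Returns failure if the op cannot be rewritten. When `maskingOp` is present,
/// the trimmed contraction is re-masked with a correspondingly trimmed mask.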
FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               MaskingOpInterface maskingOp,
                                               RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Currently we support dropping only one dim, but the pattern can be applied
  // greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    // Only parallel-type iterators can be dropped.
    return failure();

  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;
  auto loc = contractOp.getLoc();

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand;
    // if it does, then we use vector.extract to drop it.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t originalZeroDim = it.value().getDimPosition(0);
    if (originalZeroDim != dimToDrop) {
      // There are two reasons to be in this path: 1. We need to transpose the
      // operand to make the dim to be dropped leading. 2. The dim to be
      // dropped does not exist, and in that case we don't want to add a unit
      // transpose, but we must check all the indices to make sure this is the
      // case.
      bool transposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          transposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }

      // Checks if only the outer, unit dimensions (of size 1) are permuted.
      // Such transposes do not materially affect the underlying vector and can
      // be omitted, e.g. perm [1, 0, 2] applied to vector<1x1x8xi32>.
      bool transposeNonOuterUnitDims = false;
      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
      for (auto [index, dim] :
           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
        if (dim != static_cast<int64_t>(index) &&
            operandShape.getDimSize(index) != 1) {
          transposeNonOuterUnitDims = true;
          break;
        }
      }

      // Do the transpose now if needed so that we can drop the
      // correct dim using extract later.
      if (transposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        if (transposeNonOuterUnitDims) {
          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
              loc, operands[it.index()], perm);
        }
      }
    }
    // We have taken care to make the dim to be dropped the leading dim. If it
    // is still not leading, that means it does not exist in this operand and
    // hence we do not need an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if it's a valid extraction; otherwise use the operand without
    // extraction.
    newOperands.push_back(validExtract
                              ? vector::ExtractOp::create(rewriter, loc,
                                                          operands[it.index()],
                                                          splatZero(dropDim))
                              : operands[it.index()]);
  }

  // Depending on whether this vector.contract is masked, the replacing op
  // should either be a new vector.contract op or a vector.mask op.
  Operation *newOp = vector::ContractionOp::create(
      rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());

  if (maskingOp) {
    auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
                                             splatZero(dropDim));

    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
  }

  return vector::BroadcastOp::create(rewriter, loc,
                                     contractOp->getResultTypes()[0],
                                     newOp->getResults()[0])
      .getResult();
}

namespace {

/// Turns vector.contract on a vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on a vector without leading
/// 1 dimensions. Also performs a transpose of the lhs and rhs operands if
/// required prior to the extract.
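/// For example (illustrative shapes and maps, not from the tests): a
/// contraction accumulating into vector<1x8xf32> with indexing maps
///   lhs: (d0, d1, d2) -> (d0, d2), rhs: (d0, d1, d2) -> (d2, d1),
///   acc: (d0, d1, d2) -> (d0, d1)
/// and iterators ["parallel", "parallel", "reduction"] is rewritten to extract
/// the leading unit dim from the lhs and acc (vector<1x4xf32> -> vector<4xf32>,
/// vector<1x8xf32> -> vector<8xf32>), contract with maps
///   lhs: (d0, d1) -> (d1), rhs: (d0, d1) -> (d1, d0), acc: (d0, d1) -> (d0)
/// and iterators ["parallel", "reduction"], and broadcast the vector<8xf32>
/// result back to vector<1x8xf32>.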
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  }
};

/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (but not
/// vector<2x[4]x1xf32>), casts away the leading one dimensions (_plural_), and
/// then broadcasts the result.
///
/// Example before:
///   %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
/// Example after:
///   %2 = arith.mulf %0, %1 : vector<4x1xf32>
///   %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
///
/// Supports scalable vectors.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
        newOperands.push_back(vector::ExtractOp::create(
            rewriter, op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
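// For example (illustrative IR):
//
//   %0 = vector.constant_mask [1, 1, 4] : vector<1x1x8xi1>
//
// becomes
//
//   %m = vector.constant_mask [4] : vector<8xi1>
//   %0 = vector.broadcast %m : vector<8xi1> to vector<1x1x8xi1>
//
// If any dropped unit dim has mask size 0 (e.g. [0, 1, 4]), the folded leading
// size becomes 0 and the new mask is an all-false mask.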
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

    // If any of the dropped unit dims has a size of `0`, the entire mask is a
    // zero mask; else the unit dim has no effect on the mask.
    int64_t flatLeadingSize =
        llvm::product_of(dimSizes.take_front(dropDim + 1));
    SmallVector<int64_t> newDimSizes = {flatLeadingSize};
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
                                                  newType, newDimSizes);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
    return success();
  }
};

} // namespace

void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
}