MLIR  21.0.0git
VectorDropLeadUnitDim.cpp
Go to the documentation of this file.
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include <algorithm>
#include <numeric>
10 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #define DEBUG_TYPE "vector-drop-unit-dim"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 // Trims leading one dimensions from `oldType` and returns the result type.
25 // Returns `vector<1xT>` if `oldType` only has one element.
26 static VectorType trimLeadingOneDims(VectorType oldType) {
27  ArrayRef<int64_t> oldShape = oldType.getShape();
28  ArrayRef<int64_t> newShape = oldShape;
29 
30  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
31  ArrayRef<bool> newScalableDims = oldScalableDims;
32 
33  while (!newShape.empty() && newShape.front() == 1 &&
34  !newScalableDims.front()) {
35  newShape = newShape.drop_front(1);
36  newScalableDims = newScalableDims.drop_front(1);
37  }
38 
39  // Make sure we have at least 1 dimension per vector type requirements.
40  if (newShape.empty()) {
41  newShape = oldShape.take_back();
42  newScalableDims = oldType.getScalableDims().take_back();
43  }
44  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
45 }
46 
47 /// Return a smallVector of size `rank` containing all zeros.
48 static SmallVector<int64_t> splatZero(int64_t rank) {
49  return SmallVector<int64_t>(rank, 0);
50 }
51 namespace {
52 
53 // Casts away leading one dimensions in vector.extract_strided_slice's vector
54 // input by inserting vector.broadcast.
55 struct CastAwayExtractStridedSliceLeadingOneDim
56  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
58 
59  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
60  PatternRewriter &rewriter) const override {
61  // vector.extract_strided_slice requires the input and output vector to have
62  // the same rank. Here we drop leading one dimensions from the input vector
63  // type to make sure we don't cause mismatch.
64  VectorType oldSrcType = extractOp.getSourceVectorType();
65  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
66 
67  if (newSrcType.getRank() == oldSrcType.getRank())
68  return failure();
69 
70  int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
71 
72  VectorType oldDstType = extractOp.getType();
73  VectorType newDstType =
74  VectorType::get(oldDstType.getShape().drop_front(dropCount),
75  oldDstType.getElementType(),
76  oldDstType.getScalableDims().drop_front(dropCount));
77 
78  Location loc = extractOp.getLoc();
79 
80  Value newSrcVector = rewriter.create<vector::ExtractOp>(
81  loc, extractOp.getVector(), splatZero(dropCount));
82 
83  // The offsets/sizes/strides attribute can have a less number of elements
84  // than the input vector's rank: it is meant for the leading dimensions.
85  auto newOffsets = rewriter.getArrayAttr(
86  extractOp.getOffsets().getValue().drop_front(dropCount));
87  auto newSizes = rewriter.getArrayAttr(
88  extractOp.getSizes().getValue().drop_front(dropCount));
89  auto newStrides = rewriter.getArrayAttr(
90  extractOp.getStrides().getValue().drop_front(dropCount));
91 
92  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
93  loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
94 
95  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
96  newExtractOp);
97 
98  return success();
99  }
100 };
101 
102 // Casts away leading one dimensions in vector.insert_strided_slice's vector
103 // inputs by inserting vector.broadcast.
104 struct CastAwayInsertStridedSliceLeadingOneDim
105  : public OpRewritePattern<vector::InsertStridedSliceOp> {
107 
108  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
109  PatternRewriter &rewriter) const override {
110  VectorType oldSrcType = insertOp.getSourceVectorType();
111  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
112  VectorType oldDstType = insertOp.getDestVectorType();
113  VectorType newDstType = trimLeadingOneDims(oldDstType);
114 
115  int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
116  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
117  if (srcDropCount == 0 && dstDropCount == 0)
118  return failure();
119 
120  // Trim leading one dimensions from both operands.
121  Location loc = insertOp.getLoc();
122 
123  Value newSrcVector = rewriter.create<vector::ExtractOp>(
124  loc, insertOp.getValueToStore(), splatZero(srcDropCount));
125  Value newDstVector = rewriter.create<vector::ExtractOp>(
126  loc, insertOp.getDest(), splatZero(dstDropCount));
127 
128  auto newOffsets = rewriter.getArrayAttr(
129  insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
130  auto newStrides = rewriter.getArrayAttr(
131  insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
132 
133  auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
134  loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
135 
136  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
137  newInsertOp);
138 
139  return success();
140  }
141 };
142 
143 // Casts away leading one dimensions in vector.insert's vector inputs by
144 // inserting vector.broadcast.
145 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
147 
148  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
149  PatternRewriter &rewriter) const override {
150  Type oldSrcType = insertOp.getValueToStoreType();
151  Type newSrcType = oldSrcType;
152  int64_t oldSrcRank = 0, newSrcRank = 0;
153  if (auto type = dyn_cast<VectorType>(oldSrcType)) {
154  newSrcType = trimLeadingOneDims(type);
155  oldSrcRank = type.getRank();
156  newSrcRank = cast<VectorType>(newSrcType).getRank();
157  }
158 
159  VectorType oldDstType = insertOp.getDestVectorType();
160  VectorType newDstType = trimLeadingOneDims(oldDstType);
161 
162  int64_t srcDropCount = oldSrcRank - newSrcRank;
163  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
164  if (srcDropCount == 0 && dstDropCount == 0)
165  return failure();
166 
167  // Trim leading one dimensions from both operands.
168  Location loc = insertOp.getLoc();
169 
170  Value newSrcVector = insertOp.getValueToStore();
171  if (oldSrcRank != 0) {
172  newSrcVector = rewriter.create<vector::ExtractOp>(
173  loc, insertOp.getValueToStore(), splatZero(srcDropCount));
174  }
175  Value newDstVector = rewriter.create<vector::ExtractOp>(
176  loc, insertOp.getDest(), splatZero(dstDropCount));
177 
178  // New position rank needs to be computed in two steps: (1) if destination
179  // type has leading unit dims, we also trim the position array accordingly,
180  // then (2) if source type also has leading unit dims, we need to append
181  // zeroes to the position array accordingly.
182  unsigned oldPosRank = insertOp.getNumIndices();
183  unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
184  SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
185  SmallVector<OpFoldResult> newPosition =
186  llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
187  newPosition.resize(newDstType.getRank() - newSrcRank,
188  rewriter.getI64IntegerAttr(0));
189 
190  auto newInsertOp = rewriter.create<vector::InsertOp>(
191  loc, newSrcVector, newDstVector, newPosition);
192 
193  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
194  newInsertOp);
195 
196  return success();
197  }
198 };
199 
200 static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
201  VectorType newType, AffineMap newMap,
202  VectorType oldMaskType) {
203  // Infer the type of the new mask from the new map.
204  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
205 
206  // If the new mask is broadcastable to the old result type, we can safely
207  // use a `vector.extract` to get the new mask. Otherwise the best we can
208  // do is shape cast.
209  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
211  int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
212  return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
213  }
214  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
215 }
216 
217 // Turns vector.transfer_read on vector with leading 1 dimensions into
218 // vector.shape_cast followed by vector.transfer_read on vector without leading
219 // 1 dimensions.
220 struct CastAwayTransferReadLeadingOneDim
221  : public OpRewritePattern<vector::TransferReadOp> {
223 
224  LogicalResult matchAndRewrite(vector::TransferReadOp read,
225  PatternRewriter &rewriter) const override {
226  // TODO(#78787): Not supported masked op yet.
227  if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
228  return failure();
229  // TODO: support 0-d corner case.
230  if (read.getTransferRank() == 0)
231  return failure();
232 
233  auto shapedType = cast<ShapedType>(read.getSource().getType());
234  if (shapedType.getElementType() != read.getVectorType().getElementType())
235  return failure();
236 
237  VectorType oldType = read.getVectorType();
238  VectorType newType = trimLeadingOneDims(oldType);
239 
240  if (newType == oldType)
241  return failure();
242 
243  AffineMap oldMap = read.getPermutationMap();
244  ArrayRef<AffineExpr> newResults =
245  oldMap.getResults().take_back(newType.getRank());
246  AffineMap newMap =
247  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
248  rewriter.getContext());
249 
250  ArrayAttr inBoundsAttr;
251  if (read.getInBounds())
252  inBoundsAttr = rewriter.getArrayAttr(
253  read.getInBoundsAttr().getValue().take_back(newType.getRank()));
254 
255  Value mask = Value();
256  if (read.getMask()) {
257  VectorType maskType = read.getMaskType();
258  mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
259  newType, newMap, maskType);
260  }
261 
262  auto newRead = rewriter.create<vector::TransferReadOp>(
263  read.getLoc(), newType, read.getSource(), read.getIndices(),
264  AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
265  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
266 
267  return success();
268  }
269 };
270 
271 // Turns vector.transfer_write on vector with leading 1 dimensions into
272 // vector.shape_cast followed by vector.transfer_write on vector without leading
273 // 1 dimensions.
274 struct CastAwayTransferWriteLeadingOneDim
275  : public OpRewritePattern<vector::TransferWriteOp> {
277 
278  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
279  PatternRewriter &rewriter) const override {
280  // TODO(#78787): Not supported masked op yet.
281  if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
282  return failure();
283  // TODO: support 0-d corner case.
284  if (write.getTransferRank() == 0)
285  return failure();
286 
287  auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
288  if (shapedType.getElementType() != write.getVectorType().getElementType())
289  return failure();
290 
291  VectorType oldType = write.getVectorType();
292  VectorType newType = trimLeadingOneDims(oldType);
293  if (newType == oldType)
294  return failure();
295  int64_t dropDim = oldType.getRank() - newType.getRank();
296 
297  AffineMap oldMap = write.getPermutationMap();
298  ArrayRef<AffineExpr> newResults =
299  oldMap.getResults().take_back(newType.getRank());
300  AffineMap newMap =
301  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
302  rewriter.getContext());
303 
304  ArrayAttr inBoundsAttr;
305  if (write.getInBounds())
306  inBoundsAttr = rewriter.getArrayAttr(
307  write.getInBoundsAttr().getValue().take_back(newType.getRank()));
308 
309  auto newVector = rewriter.create<vector::ExtractOp>(
310  write.getLoc(), write.getVector(), splatZero(dropDim));
311 
312  if (write.getMask()) {
313  VectorType maskType = write.getMaskType();
314  Value newMask = dropUnitDimsFromMask(
315  rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
316  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
317  write, newVector, write.getSource(), write.getIndices(),
318  AffineMapAttr::get(newMap), newMask, inBoundsAttr);
319  return success();
320  }
321 
322  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
323  write, newVector, write.getSource(), write.getIndices(),
324  AffineMapAttr::get(newMap), inBoundsAttr);
325  return success();
326  }
327 };
328 
329 } // namespace
330 
331 FailureOr<Value>
332 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
333  MaskingOpInterface maskingOp,
334  RewriterBase &rewriter) {
335  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
336  if (oldAccType == nullptr)
337  return failure();
338  if (oldAccType.getRank() < 2)
339  return failure();
340  if (oldAccType.getShape()[0] != 1)
341  return failure();
342  // currently we support only dropping one dim but the pattern can be applied
343  // greedily to drop more.
344  int64_t dropDim = 1;
345 
346  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
347  SmallVector<AffineMap> newIndexingMaps;
348 
349  auto oldIteratorTypes = contractOp.getIteratorTypes();
350  SmallVector<Attribute> newIteratorTypes;
351 
352  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
353 
354  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
355  // only parallel type iterators can be dropped.
356  return failure();
357 
358  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
359  int64_t currDim = it.index();
360  if (currDim == dimToDrop)
361  continue;
362  newIteratorTypes.push_back(it.value());
363  }
364 
365  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
366  contractOp.getAcc()};
367  SmallVector<Value> newOperands;
368  auto loc = contractOp.getLoc();
369 
370  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
371  // Check if the dim to be dropped exists as a leading dim in the operand
372  // if it does then we use vector.extract to drop it.
373  bool validExtract = false;
374  SmallVector<AffineExpr> results;
375  auto map = it.value();
376  int64_t orginalZeroDim = it.value().getDimPosition(0);
377  if (orginalZeroDim != dimToDrop) {
378  // There are two reasons to be in this path, 1. We need to
379  // transpose the operand to make the dim to be dropped
380  // leading. 2. The dim to be dropped does not exist and in
381  // that case we dont want to add a unit transpose but we must
382  // check all the indices to make sure this is the case.
383  bool transposeNeeded = false;
385  SmallVector<AffineExpr> transposeResults;
386 
387  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
388  int64_t currDim = map.getDimPosition(i);
389  if (currDim == dimToDrop) {
390  transposeNeeded = true;
391  perm.insert(perm.begin(), i);
392  auto targetExpr = rewriter.getAffineDimExpr(currDim);
393  transposeResults.insert(transposeResults.begin(), targetExpr);
394  } else {
395  perm.push_back(i);
396  auto targetExpr = rewriter.getAffineDimExpr(currDim);
397  transposeResults.push_back(targetExpr);
398  }
399  }
400 
401  // Checks if only the outer, unit dimensions (of size 1) are permuted.
402  // Such transposes do not materially effect the underlying vector and can
403  // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
404  bool transposeNonOuterUnitDims = false;
405  auto operandShape = cast<ShapedType>(operands[it.index()].getType());
406  for (auto [index, dim] :
407  llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
408  if (dim != static_cast<int64_t>(index) &&
409  operandShape.getDimSize(index) != 1) {
410  transposeNonOuterUnitDims = true;
411  break;
412  }
413  }
414 
415  // Do the transpose now if needed so that we can drop the
416  // correct dim using extract later.
417  if (transposeNeeded) {
418  map = AffineMap::get(map.getNumDims(), 0, transposeResults,
419  contractOp.getContext());
420  if (transposeNonOuterUnitDims) {
421  operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
422  loc, operands[it.index()], perm);
423  }
424  }
425  }
426  // We have taken care to have the dim to be dropped be
427  // the leading dim. If its still not leading that means it
428  // does not exist in this operand and hence we do not need
429  // an extract.
430  if (map.getDimPosition(0) == dimToDrop)
431  validExtract = true;
432 
433  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
434  int64_t currDim = map.getDimPosition(i);
435  if (currDim == dimToDrop)
436  // This is the dim we are dropping.
437  continue;
438  auto targetExpr = rewriter.getAffineDimExpr(
439  currDim < dimToDrop ? currDim : currDim - 1);
440  results.push_back(targetExpr);
441  }
442  newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
443  contractOp.getContext()));
444  // Extract if its a valid extraction, otherwise use the operand
445  // without extraction.
446  newOperands.push_back(
447  validExtract ? rewriter.create<vector::ExtractOp>(
448  loc, operands[it.index()], splatZero(dropDim))
449  : operands[it.index()]);
450  }
451 
452  // Depending on whether this vector.contract is masked, the replacing Op
453  // should either be a new vector.contract Op or vector.mask Op.
454  Operation *newOp = rewriter.create<vector::ContractionOp>(
455  loc, newOperands[0], newOperands[1], newOperands[2],
456  rewriter.getAffineMapArrayAttr(newIndexingMaps),
457  rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
458 
459  if (maskingOp) {
460  auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
461  splatZero(dropDim));
462 
463  newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
464  }
465 
466  return rewriter
467  .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
468  newOp->getResults()[0])
469  .getResult();
470 }
471 
472 namespace {
473 
474 /// Turns vector.contract on vector with leading 1 dimensions into
475 /// vector.extract followed by vector.contract on vector without leading
476 /// 1 dimensions. Also performs transpose of lhs and rhs operands if required
477 /// prior to extract.
478 struct CastAwayContractionLeadingOneDim
479  : public MaskableOpRewritePattern<vector::ContractionOp> {
480  using MaskableOpRewritePattern::MaskableOpRewritePattern;
481 
482  FailureOr<Value>
483  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
484  MaskingOpInterface maskingOp,
485  PatternRewriter &rewriter) const override {
486  return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
487  }
488 };
489 
490 /// Looks at elementwise operations on vectors with at least one leading
491 /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
492 /// and cast aways the leading one dimensions (_plural_) and then broadcasts
493 /// the results.
494 ///
495 /// Example before:
496 /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
497 /// Example after:
498 /// %2 = arith.mulf %0, %1 : vector<4x1xf32>
499 /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
500 ///
501 /// Does support scalable vectors.
502 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
503 public:
504  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
505  PatternBenefit benefit = 1)
506  : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
507 
508  LogicalResult matchAndRewrite(Operation *op,
509  PatternRewriter &rewriter) const override {
511  return failure();
512  auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
513  if (!vecType)
514  return failure();
515  VectorType newVecType = trimLeadingOneDims(vecType);
516  if (newVecType == vecType)
517  return failure();
518  int64_t dropDim = vecType.getRank() - newVecType.getRank();
519  SmallVector<Value, 4> newOperands;
520  for (Value operand : op->getOperands()) {
521  if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
522  newOperands.push_back(rewriter.create<vector::ExtractOp>(
523  op->getLoc(), operand, splatZero(dropDim)));
524  } else {
525  newOperands.push_back(operand);
526  }
527  }
528  Operation *newOp =
529  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
530  newOperands, newVecType, op->getAttrs());
531  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
532  newOp->getResult(0));
533  return success();
534  }
535 };
536 
537 // Drops leading 1 dimensions from vector.constant_mask and inserts a
538 // vector.broadcast back to the original shape.
539 struct CastAwayConstantMaskLeadingOneDim
540  : public OpRewritePattern<vector::ConstantMaskOp> {
542 
543  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
544  PatternRewriter &rewriter) const override {
545  VectorType oldType = mask.getType();
546  VectorType newType = trimLeadingOneDims(oldType);
547 
548  if (newType == oldType)
549  return failure();
550 
551  int64_t dropDim = oldType.getRank() - newType.getRank();
552  ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
553 
554  // If any of the dropped unit dims has a size of `0`, the entire mask is a
555  // zero mask, else the unit dim has no effect on the mask.
556  int64_t flatLeadingSize =
557  std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
558  static_cast<int64_t>(1), std::multiplies<int64_t>());
559  SmallVector<int64_t> newDimSizes = {flatLeadingSize};
560  newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
561 
562  auto newMask = rewriter.create<vector::ConstantMaskOp>(
563  mask.getLoc(), newType, newDimSizes);
564  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
565  return success();
566  }
567 };
568 
569 } // namespace
570 
573  patterns
574  .add<CastAwayExtractStridedSliceLeadingOneDim,
575  CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
576  CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
577  CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
578  CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
579 }
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:362
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< Value > castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, RewriterBase &rewriter)
Cast away the leading unit dim, if exists, for the given contract op.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2507
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4185
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:147
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:323
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:157