MLIR  22.0.0git
VectorDropLeadUnitDim.cpp
Go to the documentation of this file.
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <numeric>
10 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #define DEBUG_TYPE "vector-drop-unit-dim"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 // Trims leading one dimensions from `oldType` and returns the result type.
25 // Returns `vector<1xT>` if `oldType` only has one element.
26 static VectorType trimLeadingOneDims(VectorType oldType) {
27  ArrayRef<int64_t> oldShape = oldType.getShape();
28  ArrayRef<int64_t> newShape = oldShape;
29 
30  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
31  ArrayRef<bool> newScalableDims = oldScalableDims;
32 
33  while (!newShape.empty() && newShape.front() == 1 &&
34  !newScalableDims.front()) {
35  newShape = newShape.drop_front(1);
36  newScalableDims = newScalableDims.drop_front(1);
37  }
38 
39  // Make sure we have at least 1 dimension per vector type requirements.
40  if (newShape.empty()) {
41  newShape = oldShape.take_back();
42  newScalableDims = oldType.getScalableDims().take_back();
43  }
44  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
45 }
46 
47 /// Return a smallVector of size `rank` containing all zeros.
48 static SmallVector<int64_t> splatZero(int64_t rank) {
49  return SmallVector<int64_t>(rank, 0);
50 }
51 namespace {
52 
53 // Casts away leading one dimensions in vector.extract_strided_slice's vector
54 // input by inserting vector.broadcast.
55 struct CastAwayExtractStridedSliceLeadingOneDim
56  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
58 
59  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
60  PatternRewriter &rewriter) const override {
61  // vector.extract_strided_slice requires the input and output vector to have
62  // the same rank. Here we drop leading one dimensions from the input vector
63  // type to make sure we don't cause mismatch.
64  VectorType oldSrcType = extractOp.getSourceVectorType();
65  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
66 
67  if (newSrcType.getRank() == oldSrcType.getRank())
68  return failure();
69 
70  int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
71 
72  VectorType oldDstType = extractOp.getType();
73  VectorType newDstType =
74  VectorType::get(oldDstType.getShape().drop_front(dropCount),
75  oldDstType.getElementType(),
76  oldDstType.getScalableDims().drop_front(dropCount));
77 
78  Location loc = extractOp.getLoc();
79 
80  Value newSrcVector = vector::ExtractOp::create(
81  rewriter, loc, extractOp.getVector(), splatZero(dropCount));
82 
83  // The offsets/sizes/strides attribute can have a less number of elements
84  // than the input vector's rank: it is meant for the leading dimensions.
85  auto newOffsets = rewriter.getArrayAttr(
86  extractOp.getOffsets().getValue().drop_front(dropCount));
87  auto newSizes = rewriter.getArrayAttr(
88  extractOp.getSizes().getValue().drop_front(dropCount));
89  auto newStrides = rewriter.getArrayAttr(
90  extractOp.getStrides().getValue().drop_front(dropCount));
91 
92  auto newExtractOp = vector::ExtractStridedSliceOp::create(
93  rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
94  newStrides);
95 
96  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
97  newExtractOp);
98 
99  return success();
100  }
101 };
102 
103 // Casts away leading one dimensions in vector.insert_strided_slice's vector
104 // inputs by inserting vector.broadcast.
105 struct CastAwayInsertStridedSliceLeadingOneDim
106  : public OpRewritePattern<vector::InsertStridedSliceOp> {
108 
109  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
110  PatternRewriter &rewriter) const override {
111  VectorType oldSrcType = insertOp.getSourceVectorType();
112  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
113  VectorType oldDstType = insertOp.getDestVectorType();
114  VectorType newDstType = trimLeadingOneDims(oldDstType);
115 
116  int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
117  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
118  if (srcDropCount == 0 && dstDropCount == 0)
119  return failure();
120 
121  // Trim leading one dimensions from both operands.
122  Location loc = insertOp.getLoc();
123 
124  Value newSrcVector = vector::ExtractOp::create(
125  rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
126  Value newDstVector = vector::ExtractOp::create(
127  rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
128 
129  auto newOffsets = rewriter.getArrayAttr(
130  insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
131  auto newStrides = rewriter.getArrayAttr(
132  insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
133 
134  auto newInsertOp = vector::InsertStridedSliceOp::create(
135  rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
136  newStrides);
137 
138  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
139  newInsertOp);
140 
141  return success();
142  }
143 };
144 
145 // Casts away leading one dimensions in vector.insert's vector inputs by
146 // inserting vector.broadcast.
147 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
149 
150  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
151  PatternRewriter &rewriter) const override {
152  Type oldSrcType = insertOp.getValueToStoreType();
153  Type newSrcType = oldSrcType;
154  int64_t oldSrcRank = 0, newSrcRank = 0;
155  if (auto type = dyn_cast<VectorType>(oldSrcType)) {
156  newSrcType = trimLeadingOneDims(type);
157  oldSrcRank = type.getRank();
158  newSrcRank = cast<VectorType>(newSrcType).getRank();
159  }
160 
161  VectorType oldDstType = insertOp.getDestVectorType();
162  VectorType newDstType = trimLeadingOneDims(oldDstType);
163 
164  int64_t srcDropCount = oldSrcRank - newSrcRank;
165  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
166  if (srcDropCount == 0 && dstDropCount == 0)
167  return failure();
168 
169  // Trim leading one dimensions from both operands.
170  Location loc = insertOp.getLoc();
171 
172  Value newSrcVector = insertOp.getValueToStore();
173  if (oldSrcRank != 0) {
174  newSrcVector = vector::ExtractOp::create(
175  rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
176  }
177  Value newDstVector = vector::ExtractOp::create(
178  rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
179 
180  // New position rank needs to be computed in two steps: (1) if destination
181  // type has leading unit dims, we also trim the position array accordingly,
182  // then (2) if source type also has leading unit dims, we need to append
183  // zeroes to the position array accordingly.
184  unsigned oldPosRank = insertOp.getNumIndices();
185  unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
186  SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
187  SmallVector<OpFoldResult> newPosition =
188  llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
189  newPosition.resize(newDstType.getRank() - newSrcRank,
190  rewriter.getI64IntegerAttr(0));
191 
192  auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
193  newDstVector, newPosition);
194 
195  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
196  newInsertOp);
197 
198  return success();
199  }
200 };
201 
202 static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
203  VectorType newType, AffineMap newMap,
204  VectorType oldMaskType) {
205  // Infer the type of the new mask from the new map.
206  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
207 
208  // If the new mask is broadcastable to the old result type, we can safely
209  // use a `vector.extract` to get the new mask. Otherwise the best we can
210  // do is shape cast.
211  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
213  int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
214  return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
215  }
216  return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
217 }
218 
219 // Turns vector.transfer_read on vector with leading 1 dimensions into
220 // vector.shape_cast followed by vector.transfer_read on vector without leading
221 // 1 dimensions.
222 struct CastAwayTransferReadLeadingOneDim
223  : public OpRewritePattern<vector::TransferReadOp> {
225 
226  LogicalResult matchAndRewrite(vector::TransferReadOp read,
227  PatternRewriter &rewriter) const override {
228  // TODO(#78787): Not supported masked op yet.
229  if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
230  return failure();
231  // TODO: support 0-d corner case.
232  if (read.getTransferRank() == 0)
233  return failure();
234 
235  auto shapedType = cast<ShapedType>(read.getBase().getType());
236  if (shapedType.getElementType() != read.getVectorType().getElementType())
237  return failure();
238 
239  VectorType oldType = read.getVectorType();
240  VectorType newType = trimLeadingOneDims(oldType);
241 
242  if (newType == oldType)
243  return failure();
244 
245  AffineMap oldMap = read.getPermutationMap();
246  ArrayRef<AffineExpr> newResults =
247  oldMap.getResults().take_back(newType.getRank());
248  AffineMap newMap =
249  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
250  rewriter.getContext());
251 
252  ArrayAttr inBoundsAttr;
253  if (read.getInBounds())
254  inBoundsAttr = rewriter.getArrayAttr(
255  read.getInBoundsAttr().getValue().take_back(newType.getRank()));
256 
257  Value mask = Value();
258  if (read.getMask()) {
259  VectorType maskType = read.getMaskType();
260  mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
261  newType, newMap, maskType);
262  }
263 
264  auto newRead = vector::TransferReadOp::create(
265  rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
266  AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
267  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
268 
269  return success();
270  }
271 };
272 
273 // Turns vector.transfer_write on vector with leading 1 dimensions into
274 // vector.shape_cast followed by vector.transfer_write on vector without leading
275 // 1 dimensions.
276 struct CastAwayTransferWriteLeadingOneDim
277  : public OpRewritePattern<vector::TransferWriteOp> {
279 
280  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
281  PatternRewriter &rewriter) const override {
282  // TODO(#78787): Not supported masked op yet.
283  if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
284  return failure();
285  // TODO: support 0-d corner case.
286  if (write.getTransferRank() == 0)
287  return failure();
288 
289  auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
290  if (shapedType.getElementType() != write.getVectorType().getElementType())
291  return failure();
292 
293  VectorType oldType = write.getVectorType();
294  VectorType newType = trimLeadingOneDims(oldType);
295  if (newType == oldType)
296  return failure();
297  int64_t dropDim = oldType.getRank() - newType.getRank();
298 
299  AffineMap oldMap = write.getPermutationMap();
300  ArrayRef<AffineExpr> newResults =
301  oldMap.getResults().take_back(newType.getRank());
302  AffineMap newMap =
303  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
304  rewriter.getContext());
305 
306  ArrayAttr inBoundsAttr;
307  if (write.getInBounds())
308  inBoundsAttr = rewriter.getArrayAttr(
309  write.getInBoundsAttr().getValue().take_back(newType.getRank()));
310 
311  auto newVector = vector::ExtractOp::create(
312  rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));
313 
314  if (write.getMask()) {
315  VectorType maskType = write.getMaskType();
316  Value newMask = dropUnitDimsFromMask(
317  rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
318  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
319  write, newVector, write.getBase(), write.getIndices(),
320  AffineMapAttr::get(newMap), newMask, inBoundsAttr);
321  return success();
322  }
323 
324  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
325  write, newVector, write.getBase(), write.getIndices(),
326  AffineMapAttr::get(newMap), inBoundsAttr);
327  return success();
328  }
329 };
330 
331 } // namespace
332 
333 FailureOr<Value>
334 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
335  MaskingOpInterface maskingOp,
336  RewriterBase &rewriter) {
337  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
338  if (oldAccType == nullptr)
339  return failure();
340  if (oldAccType.getRank() < 2)
341  return failure();
342  if (oldAccType.getShape()[0] != 1)
343  return failure();
344  // currently we support only dropping one dim but the pattern can be applied
345  // greedily to drop more.
346  int64_t dropDim = 1;
347 
348  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
349  SmallVector<AffineMap> newIndexingMaps;
350 
351  auto oldIteratorTypes = contractOp.getIteratorTypes();
352  SmallVector<Attribute> newIteratorTypes;
353 
354  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
355 
356  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
357  // only parallel type iterators can be dropped.
358  return failure();
359 
360  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
361  int64_t currDim = it.index();
362  if (currDim == dimToDrop)
363  continue;
364  newIteratorTypes.push_back(it.value());
365  }
366 
367  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
368  contractOp.getAcc()};
369  SmallVector<Value> newOperands;
370  auto loc = contractOp.getLoc();
371 
372  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
373  // Check if the dim to be dropped exists as a leading dim in the operand
374  // if it does then we use vector.extract to drop it.
375  bool validExtract = false;
376  SmallVector<AffineExpr> results;
377  auto map = it.value();
378  int64_t orginalZeroDim = it.value().getDimPosition(0);
379  if (orginalZeroDim != dimToDrop) {
380  // There are two reasons to be in this path, 1. We need to
381  // transpose the operand to make the dim to be dropped
382  // leading. 2. The dim to be dropped does not exist and in
383  // that case we dont want to add a unit transpose but we must
384  // check all the indices to make sure this is the case.
385  bool transposeNeeded = false;
387  SmallVector<AffineExpr> transposeResults;
388 
389  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
390  int64_t currDim = map.getDimPosition(i);
391  if (currDim == dimToDrop) {
392  transposeNeeded = true;
393  perm.insert(perm.begin(), i);
394  auto targetExpr = rewriter.getAffineDimExpr(currDim);
395  transposeResults.insert(transposeResults.begin(), targetExpr);
396  } else {
397  perm.push_back(i);
398  auto targetExpr = rewriter.getAffineDimExpr(currDim);
399  transposeResults.push_back(targetExpr);
400  }
401  }
402 
403  // Checks if only the outer, unit dimensions (of size 1) are permuted.
404  // Such transposes do not materially effect the underlying vector and can
405  // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
406  bool transposeNonOuterUnitDims = false;
407  auto operandShape = cast<ShapedType>(operands[it.index()].getType());
408  for (auto [index, dim] :
409  llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
410  if (dim != static_cast<int64_t>(index) &&
411  operandShape.getDimSize(index) != 1) {
412  transposeNonOuterUnitDims = true;
413  break;
414  }
415  }
416 
417  // Do the transpose now if needed so that we can drop the
418  // correct dim using extract later.
419  if (transposeNeeded) {
420  map = AffineMap::get(map.getNumDims(), 0, transposeResults,
421  contractOp.getContext());
422  if (transposeNonOuterUnitDims) {
423  operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
424  loc, operands[it.index()], perm);
425  }
426  }
427  }
428  // We have taken care to have the dim to be dropped be
429  // the leading dim. If its still not leading that means it
430  // does not exist in this operand and hence we do not need
431  // an extract.
432  if (map.getDimPosition(0) == dimToDrop)
433  validExtract = true;
434 
435  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
436  int64_t currDim = map.getDimPosition(i);
437  if (currDim == dimToDrop)
438  // This is the dim we are dropping.
439  continue;
440  auto targetExpr = rewriter.getAffineDimExpr(
441  currDim < dimToDrop ? currDim : currDim - 1);
442  results.push_back(targetExpr);
443  }
444  newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
445  contractOp.getContext()));
446  // Extract if its a valid extraction, otherwise use the operand
447  // without extraction.
448  newOperands.push_back(validExtract
449  ? vector::ExtractOp::create(rewriter, loc,
450  operands[it.index()],
451  splatZero(dropDim))
452  : operands[it.index()]);
453  }
454 
455  // Depending on whether this vector.contract is masked, the replacing Op
456  // should either be a new vector.contract Op or vector.mask Op.
457  Operation *newOp = vector::ContractionOp::create(
458  rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
459  rewriter.getAffineMapArrayAttr(newIndexingMaps),
460  rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
461 
462  if (maskingOp) {
463  auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
464  splatZero(dropDim));
465 
466  newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
467  }
468 
469  return vector::BroadcastOp::create(rewriter, loc,
470  contractOp->getResultTypes()[0],
471  newOp->getResults()[0])
472  .getResult();
473 }
474 
475 namespace {
476 
477 /// Turns vector.contract on vector with leading 1 dimensions into
478 /// vector.extract followed by vector.contract on vector without leading
479 /// 1 dimensions. Also performs transpose of lhs and rhs operands if required
480 /// prior to extract.
481 struct CastAwayContractionLeadingOneDim
482  : public MaskableOpRewritePattern<vector::ContractionOp> {
483  using MaskableOpRewritePattern::MaskableOpRewritePattern;
484 
485  FailureOr<Value>
486  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
487  MaskingOpInterface maskingOp,
488  PatternRewriter &rewriter) const override {
489  return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
490  }
491 };
492 
493 /// Looks at elementwise operations on vectors with at least one leading
494 /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
495 /// and cast aways the leading one dimensions (_plural_) and then broadcasts
496 /// the results.
497 ///
498 /// Example before:
499 /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
500 /// Example after:
501 /// %2 = arith.mulf %0, %1 : vector<4x1xf32>
502 /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
503 ///
504 /// Does support scalable vectors.
505 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
506 public:
507  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
508  PatternBenefit benefit = 1)
509  : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
510 
511  LogicalResult matchAndRewrite(Operation *op,
512  PatternRewriter &rewriter) const override {
514  return failure();
515  auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
516  if (!vecType)
517  return failure();
518  VectorType newVecType = trimLeadingOneDims(vecType);
519  if (newVecType == vecType)
520  return failure();
521  int64_t dropDim = vecType.getRank() - newVecType.getRank();
522  SmallVector<Value, 4> newOperands;
523  for (Value operand : op->getOperands()) {
524  if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
525  newOperands.push_back(vector::ExtractOp::create(
526  rewriter, op->getLoc(), operand, splatZero(dropDim)));
527  } else {
528  newOperands.push_back(operand);
529  }
530  }
531  Operation *newOp =
532  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
533  newOperands, newVecType, op->getAttrs());
534  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
535  newOp->getResult(0));
536  return success();
537  }
538 };
539 
540 // Drops leading 1 dimensions from vector.constant_mask and inserts a
541 // vector.broadcast back to the original shape.
542 struct CastAwayConstantMaskLeadingOneDim
543  : public OpRewritePattern<vector::ConstantMaskOp> {
545 
546  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
547  PatternRewriter &rewriter) const override {
548  VectorType oldType = mask.getType();
549  VectorType newType = trimLeadingOneDims(oldType);
550 
551  if (newType == oldType)
552  return failure();
553 
554  int64_t dropDim = oldType.getRank() - newType.getRank();
555  ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
556 
557  // If any of the dropped unit dims has a size of `0`, the entire mask is a
558  // zero mask, else the unit dim has no effect on the mask.
559  int64_t flatLeadingSize =
560  std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
561  static_cast<int64_t>(1), std::multiplies<int64_t>());
562  SmallVector<int64_t> newDimSizes = {flatLeadingSize};
563  newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
564 
565  auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
566  newType, newDimSizes);
567  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
568  return success();
569  }
570 };
571 
572 } // namespace
573 
574 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
576  patterns
577  .add<CastAwayExtractStridedSliceLeadingOneDim,
578  CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
579  CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
580  CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
581  CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
582 }
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:111
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:363
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:317
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:519
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:456
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< Value > castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, RewriterBase &rewriter)
Cast away the leading unit dim, if exists, for the given contract op.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2781
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4685
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:149
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:163