MLIR  22.0.0git
VectorTransforms.cpp
Go to the documentation of this file.
1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include <cassert>
16 #include <cstdint>
17 #include <functional>
18 #include <optional>
19 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/Location.h"
31 #include "mlir/IR/Matchers.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/TypeUtilities.h"
34 
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/FormatVariadic.h"
37 
38 #define DEBUG_TYPE "vector-to-vector"
39 
40 using namespace mlir;
41 using namespace mlir::vector;
42 
43 template <typename IntType>
44 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
45  return llvm::to_vector<4>(llvm::map_range(
46  arrayAttr.getAsRange<IntegerAttr>(),
47  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
48 }
49 
50 // Helper to find an index in an affine map.
51 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
52  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
53  int64_t idx = map.getDimPosition(i);
54  if (idx == index)
55  return i;
56  }
57  return std::nullopt;
58 }
59 
60 namespace {
61 
62 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
63 /// Ex:
64 /// ```
65 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
66 /// %1 = vector.multi_reduction add, %0 [1]
67 /// : vector<8x32x16xf32> to vector<8x16xf32>
68 /// ```
69 /// Gets converted to:
70 /// ```
71 /// %1 = vector.contract {indexing_maps = [
72 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
73 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
74 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
75 /// iterator_types = ["parallel", "parallel", "reduction"],
76 /// kind = add} %0, %arg1, %cst_f0
77 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
78 /// ```
79 struct MultiReduceToContract
80  : public OpRewritePattern<vector::MultiDimReductionOp> {
81  using Base::Base;
82 
83  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
84  PatternRewriter &rewriter) const override {
85  if (reduceOp.getKind() != vector::CombiningKind::ADD)
86  return failure();
87  Operation *mulOp = reduceOp.getSource().getDefiningOp();
88  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
89  return failure();
90  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
91  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
94  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
95  if (!isReduceDim.value()) {
96  iteratorTypes.push_back(vector::IteratorType::parallel);
97  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
98  } else {
99  iteratorTypes.push_back(vector::IteratorType::reduction);
100  }
101  }
102  auto dstMap =
103  AffineMap::get(/*dimCount=*/reductionMask.size(),
104  /*symbolCount=*/0, exprs, reduceOp.getContext());
105  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
106  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
107  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
108  rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
109  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
110  return IteratorTypeAttr::get(rewriter.getContext(), t);
111  }))));
112  return success();
113  }
114 };
115 
116 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
117 /// Ex:
118 /// ```
119 /// %0 = vector.transpose %arg0, [2, 0, 1]
120 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
121 /// %1 = vector.contract {indexing_maps = [
122 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
123 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
124 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
125 /// iterator_types = ["parallel", "parallel", "reduction"],
126 /// kind = add} %0, %arg1, %cst_f0
127 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
128 /// ```
129 /// Gets converted to:
130 /// ```
131 /// %1 = vector.contract {indexing_maps = [
132 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
133 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
135 /// iterator_types = ["parallel", "parallel", "reduction"],
136 /// kind = add} %arg0, %arg1, %cst_f0
137 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
138 /// ```
139 struct CombineContractABTranspose final
140  : public OpRewritePattern<vector::ContractionOp> {
141  using Base::Base;
142 
143  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
144  PatternRewriter &rewriter) const override {
146  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
147  Value lhs = contractOp.getLhs();
148  Value rhs = contractOp.getRhs();
149  size_t index = 0;
150  bool changed = false;
151  for (Value *operand : {&lhs, &rhs}) {
152  AffineMap &map = maps[index++];
153  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
154  if (!transposeOp)
155  continue;
156  AffineMap permutationMap = AffineMap::getPermutationMap(
157  transposeOp.getPermutation(), contractOp.getContext());
158  map = inversePermutation(permutationMap).compose(map);
159  *operand = transposeOp.getVector();
160  changed = true;
161  }
162  if (!changed)
163  return failure();
164  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
165  contractOp, lhs, rhs, contractOp.getAcc(),
166  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
167  return success();
168  }
169 };
170 
171 /// Merges accumulator and result transposes into contract.
172 ///
173 /// For example:
174 /// ```mlir
175 /// %accT = vector.transpose %acc, [0, 2, 1]
176 /// : vector<2x8x4xf32> to vector<2x4x8xf32>
177 /// %contract = vector.contract {
178 /// indexing_maps = [
179 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
180 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
181 /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
182 /// ],
183 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
184 /// kind = #vector.kind<add>
185 /// } %lhs, %rhs, %accT
186 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
187 /// %0 = vector.transpose %contract, [0, 2, 1]
188 /// : vector<2x4x8xf32> to vector<2x8x4>
189 /// ```
190 /// Becomes:
191 /// ```mlir
192 /// %0 = vector.contract {
193 /// indexing_maps = [
194 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
195 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
196 /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
197 /// ],
198 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
199 /// kind = #vector.kind<add>
200 /// } %lhs, %rhs, %acc
201 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
202 /// ```
203 struct CombineContractResultTranspose final
204  : public OpRewritePattern<vector::TransposeOp> {
205  using Base::Base;
206 
207  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
208  PatternRewriter &rewriter) const override {
209  auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
210  if (!contractOp || !contractOp->hasOneUse())
211  return failure();
212 
213  auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
214  if (!accTOp)
215  return failure();
216 
217  MLIRContext *context = contractOp.getContext();
218  auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
219  AffineMap contractMap = maps.back();
220 
221  // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
222  // To index into A in contract, we need revert(f)(g(C)) -> A.
223  auto accTMap =
224  AffineMap::getPermutationMap(accTOp.getPermutation(), context);
225 
226  // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
227  // To index into E in contract, we need h(g(C)) -> E.
228  auto resTMap =
229  AffineMap::getPermutationMap(resTOp.getPermutation(), context);
230  auto combinedResMap = resTMap.compose(contractMap);
231 
232  // The accumulator and result share the same indexing map. So they should be
233  // the same to be able to merge. This means combinedResMap is the same as
234  // inversePermutation(accTMap).compose(contractMap), which means
235  if (inversePermutation(accTMap) != resTMap)
236  return failure();
237  maps.back() = combinedResMap;
238 
239  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
240  resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
241  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
242  return success();
243  }
244 };
245 
246 /// Merge BroadcastOp into ContractionOp user.
247 /// Ex:
248 /// ```
249 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
250 /// %1 = vector.contract {indexing_maps = [
251 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
252 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
253 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
254 /// iterator_types = ["parallel", "parallel", "reduction"],
255 /// kind = add} %0, %arg1, %cst_f0
256 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
257 /// ```
258 /// Gets converted to:
259 /// ```
260 /// %1 = vector.contract {indexing_maps = [
261 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
262 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
263 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
264 /// iterator_types = ["parallel", "parallel", "reduction"],
265 /// kind = add} %arg0, %arg1, %cst_f0
266 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267 /// ```
268 ///
269 /// For masked vector.contract, the mask requires updating when a dimension is
270 /// dropped. In such cases, the dropped dimensions must correspond to the mask's
271 /// leading unit dimensions. Supporting more generic cases (e.g. non-unit dims)
272 /// is not supported.
273 FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
274  MaskingOpInterface maskingOp,
275  PatternRewriter &rewriter) {
277  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
278  Value lhs = contractOp.getLhs();
279  Value rhs = contractOp.getRhs();
280  size_t index = 0;
281  bool changed = false;
282  for (Value *operand : {&lhs, &rhs}) {
283  AffineMap &map = maps[index++];
284  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
285  if (!broadcast)
286  continue;
287  // contractionOp can only take vector as operands.
288  auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
289  if (!srcType ||
290  srcType.getRank() == broadcast.getResultVectorType().getRank())
291  continue;
292  int64_t rankDiff =
293  broadcast.getResultVectorType().getRank() - srcType.getRank();
294  bool innerDimBroadcast = false;
295  SmallVector<AffineExpr> originalDims;
296  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
297  if (dim.value() !=
298  broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
299  innerDimBroadcast = true;
300  break;
301  }
302  originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
303  }
304  // Contract doesn't support inner dimension broadcast. Once this is
305  // relaxed we can remove this case.
306  if (innerDimBroadcast)
307  continue;
308 
309  // It would be incorrect to fold a broadcast onto a reduction dimension
310  // of non-unit size.
311  bool nonUnitDimReductionBroadcast = false;
312  for (int64_t i = 0; i < rankDiff; ++i) {
313  if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
314  isReductionIterator(contractOp.getIteratorTypes()
315  .getValue()[map.getDimPosition(i)])) {
316  nonUnitDimReductionBroadcast = true;
317  break;
318  }
319  }
320  if (nonUnitDimReductionBroadcast)
321  continue;
322 
323  AffineMap broadcastMap =
324  AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
325  originalDims, contractOp.getContext());
326  map = broadcastMap.compose(map);
327  *operand = broadcast.getSource();
328  changed = true;
329  }
330 
331  if (!changed)
332  return failure();
333 
334  // Determine which dims are usused, now that the maps have been composed
335  // with the broadcast maps.
336  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
337  // Compress unused dims.
338  for (auto &m : maps)
339  m = compressDims(m, unusedDimsBitVector);
340  // Compute the combined iterators.
341  SmallVector<Attribute> iterators;
342  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
343  if (!unusedDimsBitVector.test(i))
344  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
345  }
346 
347  // Check whether any of the unused dims is non-unit, e.g.:
348  // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
349  // This is only required when collapsing a mask. If there is no mask, skip.
350  VectorType oldMaskType;
351  bool isAnyUnusedDimNonUnit = false;
352  if (maskingOp) {
353  oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
354  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
355  if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
356  isAnyUnusedDimNonUnit = true;
357  break;
358  }
359  }
360  }
361 
362  // Check that compressing unused dims isn't removing all reduction dimension
363  // pairs. For example, if the vector.contract had only one reduction
364  // iterator and that was a unit-dimension created by a broadcast,
365  // then we should bail here, otherwise we would create a contract without
366  // a reduction dimension pair.
367  bool hasReductionIteratorApplyingOnBothSides = false;
368  for (unsigned i = 0; i < iterators.size(); ++i) {
369  if (!isReductionIterator(iterators[i]))
370  continue;
371  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
372  hasReductionIteratorApplyingOnBothSides = true;
373  break;
374  }
375  }
376  if (!hasReductionIteratorApplyingOnBothSides)
377  return failure();
378 
379  // If the compressed maps have a dimension that is not used by either LHS or
380  // RHS then the ContractionOp verifier would fail.
381  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
382  return failure();
383 
384  Operation *newOp = vector::ContractionOp::create(
385  rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
386  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
387 
388  // Handle the mask.
389  if (maskingOp) {
390  if (isAnyUnusedDimNonUnit)
391  return rewriter.notifyMatchFailure(contractOp,
392  "Cannont drop non-unit mask dim.");
393  assert(unusedDimsBitVector.size() ==
394  static_cast<size_t>(oldMaskType.getRank()) &&
395  "The mask rank is incorrect!");
396 
397  // If a dimension has been dropped, update the mask accordingly. Otherwise,
398  // keep it as is.
399  Value mask = maskingOp.getMask();
400  if (unusedDimsBitVector.count() != 0) {
401  // At this point, two assumptions are made:
402  // * The unused dimensions are the leading mask dimensions
403  // (vector.contract does not support inner dim broadcasting).
404  // * The unused dimensions are all unit.
405  // These conditions are effectively verified in the blocks preceeding this
406  // one.
407  auto newShape =
408  oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
409  auto newShapeScalableDims =
410  oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
411  VectorType maskOpType =
412  VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
413  mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
414  maskOpType, maskingOp.getMask())
415  .getResult();
416  }
417 
418  newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
419  }
420  return newOp->getResult(0);
421 }
422 
423 struct CombineContractBroadcastMask
424  : public MaskableOpRewritePattern<vector::ContractionOp> {
425  using MaskableOpRewritePattern::MaskableOpRewritePattern;
426  FailureOr<Value>
427 
428  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
429  MaskingOpInterface maskingOp,
430  PatternRewriter &rewriter) const override {
431  return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
432  }
433 };
434 
435 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
436 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
437 /// casting ops are around these operations.
438 /// Ex:
439 /// ```
440 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
441 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
442 /// ```
443 /// Gets converted to:
444 /// ```
445 /// %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
446 /// %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
447 /// ```
448 struct ReorderCastOpsOnBroadcast
449  : public OpInterfaceRewritePattern<CastOpInterface> {
451 
452  LogicalResult matchAndRewrite(CastOpInterface op,
453  PatternRewriter &rewriter) const override {
454  if (op->getNumOperands() != 1)
455  return failure();
456  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
457  if (!bcastOp)
458  return failure();
459 
460  Type castResTy = getElementTypeOrSelf(op->getResult(0));
461  if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
462  castResTy = vecTy.clone(castResTy);
463  auto *castOp =
464  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
465  bcastOp.getSource(), castResTy, op->getAttrs());
466  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
467  op, op->getResult(0).getType(), castOp->getResult(0));
468  return success();
469  }
470 };
471 
472 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
473 /// transpose ops and contraction ops closer, which kicks in
474 /// CombineContractABTranspose pattern when elementwise ops are between these
475 /// operations. Ex:
476 /// ```
477 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
478 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
479 /// %r = arith.addf %at, %bt : vector<2x4xf32>
480 /// ```
481 /// Gets converted to:
482 /// ```
483 /// %0 = arith.addf %a, %b : vector<4x2xf32>
484 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
485 /// ```
486 struct ReorderElementwiseOpsOnTranspose final
487  : public OpTraitRewritePattern<OpTrait::Elementwise> {
489  LogicalResult matchAndRewrite(Operation *op,
490  PatternRewriter &rewriter) const override {
491  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
492  return failure();
493 
494  // Make sure all operands are transpose/constant ops and collect their
495  // transposition maps.
496  SmallVector<ArrayRef<int64_t>> transposeMaps;
497  transposeMaps.reserve(op->getNumOperands());
498  // Record the initial type before transposition. We'll use its shape later.
499  // Any type will do here as we will check all transpose maps are the same.
500  VectorType srcType;
501  for (Value operand : op->getOperands()) {
502  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
503  if (transposeOp) {
504  transposeMaps.push_back(transposeOp.getPermutation());
505  srcType = transposeOp.getSourceVectorType();
506  } else if (!matchPattern(operand, m_Constant())) {
507  return failure();
508  }
509  }
510  if (transposeMaps.empty())
511  return failure();
512  // This is an elementwise op, so all transposed operands should have the
513  // same type. We need to additionally check that all transposes uses the
514  // same map.
515  if (!llvm::all_equal(transposeMaps))
516  return rewriter.notifyMatchFailure(op, "different transpose map");
517 
518  SmallVector<Value> srcValues;
519  srcValues.reserve(op->getNumOperands());
520 
521  // If there are constant operands, we need to insert inverse transposes for
522  // them. Calculate the inverse order first.
523  auto order = transposeMaps.front();
524  SmallVector<int64_t> invOrder(order.size());
525  for (int i = 0, e = order.size(); i < e; ++i)
526  invOrder[order[i]] = i;
527 
528  for (Value operand : op->getOperands()) {
529  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
530  if (transposeOp) {
531  srcValues.push_back(transposeOp.getVector());
532  } else {
533  // This is a constant. Create a reverse transpose op for it.
534  auto vectorType =
535  srcType.clone(cast<VectorType>(operand.getType()).getElementType());
536  srcValues.push_back(vector::TransposeOp::create(
537  rewriter, operand.getLoc(), vectorType, operand, invOrder));
538  }
539  }
540 
541  auto vectorType = srcType.clone(
542  cast<VectorType>(op->getResultTypes()[0]).getElementType());
543  Operation *elementwiseOp =
544  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
545  vectorType, op->getAttrs());
546  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
547  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
548  transposeMaps.front());
549  return success();
550  }
551 };
552 
553 // Returns the values in `arrayAttr` as an integer vector.
554 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
555  return llvm::to_vector<4>(
556  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
557  [](IntegerAttr attr) { return attr.getInt(); }));
558 }
559 
560 // Shuffles vector.bitcast op after vector.extract op.
561 //
562 // This transforms IR like:
563 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
564 // %1 = vector.extract %0[3] : f16 from vector<8xf16>
565 // Into:
566 // %0 = vector.extract %src[1] : f32 from vector<4xf32>
567 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
568 // %2 = vector.extract %1[1] : f16 from vector<2xf16>
569 struct BubbleDownVectorBitCastForExtract
570  : public OpRewritePattern<vector::ExtractOp> {
571  using Base::Base;
572 
573  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
574  PatternRewriter &rewriter) const override {
575  // Only support extracting scalars for now.
576  if (extractOp.getSourceVectorType().getRank() != 1)
577  return failure();
578 
579  auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
580  if (!castOp)
581  return failure();
582 
583  VectorType castSrcType = castOp.getSourceVectorType();
584  VectorType castDstType = castOp.getResultVectorType();
585  assert(castSrcType.getRank() == castDstType.getRank());
586 
587  // Fail to match if we only have one element in the cast op source.
588  // This is to avoid infinite loop given that this pattern can generate
589  // such cases.
590  if (castSrcType.getNumElements() == 1)
591  return failure();
592 
593  // Only support casting to a larger number of elements or now.
594  // E.g., vector<4xf32> -> vector<8xf16>.
595  if (castSrcType.getNumElements() > castDstType.getNumElements())
596  return failure();
597 
598  unsigned expandRatio =
599  castDstType.getNumElements() / castSrcType.getNumElements();
600 
601  // Get the first element of the mixed position as integer.
602  auto mixedPos = extractOp.getMixedPosition();
603  if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
604  return failure();
605  uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
606 
607  // Get the single scalar (as a vector) in the source value that packs the
608  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
609  Location loc = extractOp.getLoc();
610  Value packedValue = vector::ExtractOp::create(
611  rewriter, loc, castOp.getSource(), index / expandRatio);
612  Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
613  Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
614  rewriter.getZeroAttr(packedVecType));
615  packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
616  /*position=*/0);
617 
618  // Cast it to a vector with the desired scalar's type.
619  // E.g. f32 -> vector<2xf16>
620  VectorType packedType =
621  VectorType::get({expandRatio}, castDstType.getElementType());
622  Value castedValue =
623  vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
624 
625  // Finally extract the desired scalar.
626  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
627  index % expandRatio);
628  return success();
629  }
630 };
631 
632 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
633 //
634 // This transforms IR like:
635 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
636 // %0 = vector.extract_strided_slice %cast {
637 // offsets = [4], sizes = [4], strides = [1]
638 // } : vector<8xf16> to vector<4xf16>
639 // Into:
640 // %0 = vector.extract_strided_slice %src {
641 // offsets = [2], sizes = [2], strides = [1]
642 // } : vector<4xf32> to vector<2xf32>
643 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
644 struct BubbleDownBitCastForStridedSliceExtract
645  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
646  using Base::Base;
647 
648  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
649  PatternRewriter &rewriter) const override {
650  auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
651  if (!castOp)
652  return failure();
653 
654  VectorType castSrcType = castOp.getSourceVectorType();
655  VectorType castDstType = castOp.getResultVectorType();
656  assert(castSrcType.getRank() == castDstType.getRank());
657 
658  int64_t castSrcLastDim = castSrcType.getShape().back();
659  int64_t castDstLastDim = castDstType.getShape().back();
660  // Require casting to more elements for now; other cases to be implemented.
661  if (castSrcLastDim > castDstLastDim)
662  return failure();
663 
664  // Only accept all one strides for now.
665  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
666  [](const APInt &val) { return !val.isOne(); }))
667  return failure();
668 
669  unsigned rank = extractOp.getSourceVectorType().getRank();
670  assert(castDstLastDim % castSrcLastDim == 0);
671  int64_t expandRatio = castDstLastDim / castSrcLastDim;
672 
673  // If we have a less number of offsets than the rank, then implicitly we
674  // are selecting the full range for the last bitcasted dimension; other
675  // dimensions aren't affected. Otherwise, we need to scale down the last
676  // dimension's offset given we are extracting from less elements now.
677  ArrayAttr newOffsets = extractOp.getOffsets();
678  if (newOffsets.size() == rank) {
679  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
680  if (offsets.back() % expandRatio != 0)
681  return failure();
682  offsets.back() = offsets.back() / expandRatio;
683  newOffsets = rewriter.getI64ArrayAttr(offsets);
684  }
685 
686  // Similarly for sizes.
687  ArrayAttr newSizes = extractOp.getSizes();
688  if (newSizes.size() == rank) {
689  SmallVector<int64_t> sizes = getIntValueVector(newSizes);
690  if (sizes.back() % expandRatio != 0)
691  return failure();
692  sizes.back() = sizes.back() / expandRatio;
693  newSizes = rewriter.getI64ArrayAttr(sizes);
694  }
695 
696  SmallVector<int64_t> dims =
697  llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
698  dims.back() = dims.back() / expandRatio;
699  VectorType newExtractType =
700  VectorType::get(dims, castSrcType.getElementType());
701 
702  auto newExtractOp = vector::ExtractStridedSliceOp::create(
703  rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
704  newOffsets, newSizes, extractOp.getStrides());
705 
706  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
707  extractOp, extractOp.getType(), newExtractOp);
708 
709  return success();
710  }
711 };
712 
713 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
714 //
715 // This transforms IR like:
716 // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
717 // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
718 // Into:
719 // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
720 // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
721 // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
722 //
723 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
724  using Base::Base;
725 
726  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
727  PatternRewriter &rewriter) const override {
728  VectorType castSrcType = bitcastOp.getSourceVectorType();
729  VectorType castDstType = bitcastOp.getResultVectorType();
730 
731  // 0-D and scalable vectors are not supported yet.
732  if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
733  castDstType.isScalable())
734  return failure();
735 
736  int64_t castSrcLastDim = castSrcType.getShape().back();
737  int64_t castDstLastDim = castDstType.getShape().back();
738  bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
739  int64_t ratio;
740  if (isNumElemsShrink) {
741  assert(castSrcLastDim % castDstLastDim == 0);
742  ratio = castSrcLastDim / castDstLastDim;
743  } else {
744  assert(castDstLastDim % castSrcLastDim == 0);
745  ratio = castDstLastDim / castSrcLastDim;
746  }
747 
748  auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
749  if (!insertOp)
750  return failure();
751 
752  // Only vector sources are supported for now.
753  auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
754  if (!insertSrcType)
755  return failure();
756 
757  // Bitcast the source.
758  SmallVector<int64_t> srcDims(insertSrcType.getShape());
759  srcDims.back() =
760  isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
761  VectorType newCastSrcType =
762  VectorType::get(srcDims, castDstType.getElementType());
763  auto newCastSrcOp =
764  vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
765  insertOp.getValueToStore());
766 
767  SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
768  dstDims.back() =
769  isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
770  VectorType newCastDstType =
771  VectorType::get(dstDims, castDstType.getElementType());
772 
773  // Bitcast the destination.
774  auto newCastDstOp = vector::BitCastOp::create(
775  rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
776 
777  // Generate new insert.
778  rewriter.replaceOpWithNewOp<vector::InsertOp>(
779  bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
780  return success();
781  }
782 };
783 
784 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
785 //
786 // This transforms IR like:
787 // %0 = vector.insert_strided_slice %src, %dst {
788 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
789 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
790 // Into:
791 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
792 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
793 // %2 = vector.insert_strided_slice %src, %dst {
794 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
795 struct BubbleUpBitCastForStridedSliceInsert
796  : public OpRewritePattern<vector::BitCastOp> {
797  using Base::Base;
798 
799  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
800  PatternRewriter &rewriter) const override {
801  VectorType castSrcType = bitcastOp.getSourceVectorType();
802  VectorType castDstType = bitcastOp.getResultVectorType();
803  assert(castSrcType.getRank() == castDstType.getRank());
804  // Skip 0-D vector which will not from InsertStridedSliceOp.
805  if (castSrcType.getRank() == 0)
806  return failure();
807 
808  int64_t castSrcLastDim = castSrcType.getShape().back();
809  int64_t castDstLastDim = castDstType.getShape().back();
810  // Require casting to less elements for now; other cases to be implemented.
811  if (castSrcLastDim < castDstLastDim)
812  return failure();
813 
814  assert(castSrcLastDim % castDstLastDim == 0);
815  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
816 
817  auto insertOp =
818  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
819  if (!insertOp)
820  return failure();
821 
822  // Only accept all one strides for now.
823  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
824  [](const APInt &val) { return !val.isOne(); }))
825  return failure();
826 
827  unsigned rank = insertOp.getSourceVectorType().getRank();
828  // Require insert op to have the same rank for the source and destination
829  // vector; other cases to be implemented.
830  if (rank != insertOp.getDestVectorType().getRank())
831  return failure();
832 
833  // Requires that shape of insert op src is castable to dstType.
834  unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
835  unsigned destinationWidth =
836  castDstType.getElementType().getIntOrFloatBitWidth();
837  unsigned numElements = destinationWidth / sourceWidth;
838  if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
839  return failure();
840 
841  ArrayAttr newOffsets = insertOp.getOffsets();
842  assert(newOffsets.size() == rank);
843  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
844  if (offsets.back() % shrinkRatio != 0)
845  return failure();
846  offsets.back() = offsets.back() / shrinkRatio;
847  newOffsets = rewriter.getI64ArrayAttr(offsets);
848 
849  SmallVector<int64_t> srcDims =
850  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
851  srcDims.back() = srcDims.back() / shrinkRatio;
852  VectorType newCastSrcType =
853  VectorType::get(srcDims, castDstType.getElementType());
854 
855  auto newCastSrcOp =
856  vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
857  insertOp.getValueToStore());
858 
859  SmallVector<int64_t> dstDims =
860  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
861  dstDims.back() = dstDims.back() / shrinkRatio;
862  VectorType newCastDstType =
863  VectorType::get(dstDims, castDstType.getElementType());
864 
865  auto newCastDstOp = vector::BitCastOp::create(
866  rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
867 
868  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
869  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
870  insertOp.getStrides());
871 
872  return success();
873  }
874 };
875 
876 // Breaks down vector.bitcast op
877 //
878 // This transforms IR like:
879 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
880 // Into:
881 // %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
882 // %1 = vector.extract_strided_slice %0 {
883 // offsets = [0], sizes = [4], strides = [1]
884 // } : vector<8xf16> to vector<4xf16>
885 // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
886 // %4 = vector.insert_strided_slice %2, %cst {
887 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
888 // %5 = vector.extract_strided_slice %0 {
889 // offsets = [4], sizes = [4], strides = [1]
890 // } : vector<8xf16> to vector<4xf16>
891 // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
892 // %7 = vector.insert_strided_slice %6, %cst {
893 // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
894 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
895  using Base::Base;
896 
897 public:
898  BreakDownVectorBitCast(MLIRContext *context,
899  std::function<bool(vector::BitCastOp)> controlFn,
900  PatternBenefit benefit)
901  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
902 
903  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
904  PatternRewriter &rewriter) const override {
905 
906  if (controlFn && !controlFn(bitcastOp))
907  return failure();
908 
909  VectorType castSrcType = bitcastOp.getSourceVectorType();
910  VectorType castDstType = bitcastOp.getResultVectorType();
911  assert(castSrcType.getRank() == castDstType.getRank());
912 
913  // This transformation builds on top of
914  // vector.{extract|insert}_strided_slice, which do not support
915  // extracting/inserting "scallable sub-vectors". Bail out.
916  if (castSrcType.isScalable())
917  return rewriter.notifyMatchFailure(bitcastOp,
918  "Scalable vectors are not supported");
919 
920  // Only support rank 1 case for now.
921  if (castSrcType.getRank() != 1)
922  return failure();
923 
924  int64_t castSrcLastDim = castSrcType.getShape().back();
925  int64_t castDstLastDim = castDstType.getShape().back();
926  // Require casting to less elements for now; other cases to be implemented.
927  if (castSrcLastDim < castDstLastDim)
928  return failure();
929 
930  assert(castSrcLastDim % castDstLastDim == 0);
931  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
932  // Nothing to do if it is already bitcasting to a single element.
933  if (castSrcLastDim == shrinkRatio)
934  return failure();
935 
936  Location loc = bitcastOp.getLoc();
937  Type elemType = castDstType.getElementType();
938  assert(elemType.isSignlessIntOrIndexOrFloat());
939 
940  Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
941  rewriter.getZeroAttr(elemType));
942  Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
943 
944  SmallVector<int64_t> sliceShape = {castDstLastDim};
945  SmallVector<int64_t> strides = {1};
946  VectorType newCastDstType =
947  VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
948  castDstType.getElementType());
949 
950  for (int i = 0, e = shrinkRatio; i < e; ++i) {
951  Value extracted = ExtractStridedSliceOp::create(
952  rewriter, loc, bitcastOp.getSource(),
953  ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
954  Value bitcast =
955  BitCastOp::create(rewriter, loc, newCastDstType, extracted);
956  res = InsertStridedSliceOp::create(
957  rewriter, loc, bitcast, res,
958  ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
959  }
960  rewriter.replaceOp(bitcastOp, res);
961  return success();
962  }
963 
964 private:
965  std::function<bool(BitCastOp)> controlFn;
966 };
967 
968 static bool haveSameShapeAndScaling(Type t, Type u) {
969  auto tVec = dyn_cast<VectorType>(t);
970  auto uVec = dyn_cast<VectorType>(u);
971  if (!tVec) {
972  return !uVec;
973  }
974  if (!uVec) {
975  return false;
976  }
977  return tVec.getShape() == uVec.getShape() &&
978  tVec.getScalableDims() == uVec.getScalableDims();
979 }
980 
981 /// If `type` is shaped, clone it with `newElementType`. Otherwise,
982 /// return `newElementType`.
983 static Type cloneOrReplace(Type type, Type newElementType) {
984  if (auto shapedType = dyn_cast<ShapedType>(type)) {
985  return shapedType.clone(newElementType);
986  }
987  return newElementType;
988 }
989 
990 /// If `value` is the result of a broadcast operation, return the input
991 /// of the broadcast operation.
992 static Value getBroadcastLikeSource(Value value) {
993 
994  Operation *op = value.getDefiningOp();
995  if (!op)
996  return {};
997 
998  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
999  return broadcast.getSource();
1000 
1001  return {};
1002 }
1003 
1004 /// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
1005 ///
1006 /// Example:
1007 /// ```
1008 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
1009 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
1010 /// %r = arith.addi %a, %b : vector<1x4xindex>
1011 /// ```
1012 /// Gets converted to:
1013 /// ```
1014 /// %r = arith.addi %arg0, %arg1 : index
1015 /// %b = vector.broadcast %r : index to vector<1x4xindex>
1016 /// ```
1017 struct ReorderElementwiseOpsOnBroadcast final
1018  : public OpTraitRewritePattern<OpTrait::Elementwise> {
1020  LogicalResult matchAndRewrite(Operation *op,
1021  PatternRewriter &rewriter) const override {
1022  if (op->getNumResults() != 1)
1023  return failure();
1024  auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
1025  if (!resultType)
1026  return failure();
1028  return rewriter.notifyMatchFailure(
1029  op, "Op doesn't have ElementwiseMappableTraits");
1030  if (op->getNumOperands() == 0)
1031  return failure();
1032  if (isa<vector::FMAOp>(op)) {
1033  return rewriter.notifyMatchFailure(
1034  op,
1035  "Op only accepts vector types - not supported as broadcast source "
1036  "might be a scalar");
1037  }
1038 
1039  Type resultElemType = resultType.getElementType();
1040 
1041  // Get the type of the first non-constant operand
1042  Value broadcastSource;
1043  for (Value operand : op->getOperands()) {
1044  Operation *definingOp = operand.getDefiningOp();
1045  if (!definingOp)
1046  return failure();
1047  if (definingOp->hasTrait<OpTrait::ConstantLike>())
1048  continue;
1049  broadcastSource = getBroadcastLikeSource(operand);
1050  break;
1051  }
1052  if (!broadcastSource)
1053  return failure();
1054  Type unbroadcastResultType =
1055  cloneOrReplace(broadcastSource.getType(), resultElemType);
1056 
1057  // Make sure that all operands are broadcast from identically-shaped types:
1058  // * scalar (`vector.broadcast`), or
1059  // * vector (`vector.broadcast`).
1060  // Otherwise the re-ordering wouldn't be safe.
1061  if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
1062  if (auto source = getBroadcastLikeSource(val))
1063  return haveSameShapeAndScaling(source.getType(),
1064  broadcastSource.getType());
1065  SplatElementsAttr splatConst;
1066  return matchPattern(val, m_Constant(&splatConst));
1067  })) {
1068  return rewriter.notifyMatchFailure(
1069  op,
1070  "not all operands are constants or broadcasts from the same type");
1071  }
1072 
1073  // Collect the source values before broadcasting
1074  SmallVector<Value> srcValues;
1075  srcValues.reserve(op->getNumOperands());
1076  for (Value operand : op->getOperands()) {
1077  SplatElementsAttr splatConst;
1078  if (matchPattern(operand, m_Constant(&splatConst))) {
1079  Attribute newConst;
1080  Type elementType = getElementTypeOrSelf(operand.getType());
1081  Type newType = cloneOrReplace(unbroadcastResultType, elementType);
1082  if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
1083  newConst = splatConst.resizeSplat(newTypeShaped);
1084  } else {
1085  newConst = splatConst.getSplatValue<Attribute>();
1086  }
1087  Operation *newConstOp =
1088  operand.getDefiningOp()->getDialect()->materializeConstant(
1089  rewriter, newConst, newType, operand.getLoc());
1090  srcValues.push_back(newConstOp->getResult(0));
1091  } else {
1092  srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1093  }
1094  }
1095 
1096  // Create the "elementwise" Op
1097  Operation *elementwiseOp =
1098  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1099  unbroadcastResultType, op->getAttrs());
1100 
1101  // Replace the original Op with the elementwise Op
1102  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1103  op, resultType, elementwiseOp->getResults());
1104 
1105  return success();
1106  }
1107 };
1108 
1109 /// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1110 /// This may result in cleaner code when extracting a single value
1111 /// from multi-element vector and also to help canonicalize 1-element vectors to
1112 /// scalars.
1113 ///
1114 /// Example:
1115 /// ```
1116 /// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
1117 /// %1 = vector.extract %0[1] : f32 from vector<4xf32>
1118 /// ```
1119 /// Gets converted to:
1120 /// ```
1121 /// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1122 /// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1123 /// %2 = arith.addf %0, %1 : f32
1124 /// ```
1125 class ExtractOpFromElementwise final
1126  : public OpRewritePattern<vector::ExtractOp> {
1127 public:
1128  using Base::Base;
1129 
1130  LogicalResult matchAndRewrite(vector::ExtractOp op,
1131  PatternRewriter &rewriter) const override {
1132  Operation *eltwise = op.getSource().getDefiningOp();
1133 
1134  // TODO: vector::FMAOp is not an ElemetwiseMappable even if it claims to be,
1135  // as it doesn't support scalars.
1136  if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1137  isa<vector::FMAOp>(eltwise))
1138  return rewriter.notifyMatchFailure(op, "not an elementwise op");
1139 
1140  if (eltwise->getNumResults() != 1)
1141  return rewriter.notifyMatchFailure(op, "expected single result");
1142 
1143  if (!eltwise->hasOneUse())
1144  return rewriter.notifyMatchFailure(op, "expected single op use");
1145 
1146  if (!llvm::all_equal(eltwise->getOperandTypes()))
1147  return rewriter.notifyMatchFailure(op, "operand types are different");
1148 
1149  // Dynamic position can cause dominance issues, so conservatively fail for
1150  // now.
1151  if (!op.getDynamicPosition().empty())
1152  return rewriter.notifyMatchFailure(
1153  op, "dynamic position not yet implemented");
1154 
1155  Type dstType = op.getType();
1156 
1157  OpBuilder::InsertionGuard g(rewriter);
1158  rewriter.setInsertionPoint(eltwise);
1159 
1160  IRMapping mapping;
1161  Location loc = eltwise->getLoc();
1162  SmallVector<OpFoldResult> pos = op.getMixedPosition();
1163  for (Value arg : eltwise->getOperands()) {
1164  Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
1165  mapping.map(arg, newArg);
1166  }
1167 
1168  Operation *newEltwise = rewriter.clone(*eltwise, mapping);
1169  newEltwise->getResult(0).setType(dstType);
1170 
1171  rewriter.replaceOp(op, newEltwise);
1172  rewriter.eraseOp(eltwise);
1173  return success();
1174  }
1175 };
1176 
1177 /// Check if the element type is suitable for vector.load/store sinking.
1178 /// Element type must be index or byte-aligned integer or floating-point type.
1179 static bool isSupportedMemSinkElementType(Type type) {
1180  if (isa<IndexType>(type))
1181  return true;
1182 
1183  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
1184 }
1185 
1186 /// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load.
1187 /// Only index and byte-aligned integer and floating-point element types are
1188 /// supported for now.
1189 ///
1190 /// Example:
1191 /// ```
1192 /// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
1193 /// vector.extract %0[1] : f32 from vector<4xf32>
1194 /// ```
1195 /// Gets converted to:
1196 /// ```
1197 /// %c1 = arith.constant 1 : index
1198 /// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
1199 /// %1 = memref.load %arg0[%0] : memref<?xf32>
1200 /// ```
1201 class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
1202 public:
1203  using Base::Base;
1204 
1205  LogicalResult matchAndRewrite(vector::ExtractOp op,
1206  PatternRewriter &rewriter) const override {
1207  auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
1208  if (!loadOp)
1209  return rewriter.notifyMatchFailure(op, "expected a load op");
1210 
1211  // Checking for single use so we won't duplicate load ops.
1212  if (!loadOp->hasOneUse())
1213  return rewriter.notifyMatchFailure(op, "expected single op use");
1214 
1215  VectorType loadVecType = loadOp.getVectorType();
1216  if (loadVecType.isScalable())
1217  return rewriter.notifyMatchFailure(op,
1218  "scalable vectors are not supported");
1219 
1220  MemRefType memType = loadOp.getMemRefType();
1221 
1222  // Non-byte-aligned types are tricky and may require special handling,
1223  // ignore them for now.
1224  if (!isSupportedMemSinkElementType(memType.getElementType()))
1225  return rewriter.notifyMatchFailure(op, "unsupported element type");
1226 
1227  int64_t rankOffset = memType.getRank() - loadVecType.getRank();
1228  if (rankOffset < 0)
1229  return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
1230 
1231  auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
1232  int64_t finalRank = 0;
1233  if (extractVecType)
1234  finalRank = extractVecType.getRank();
1235 
1236  SmallVector<Value> indices = loadOp.getIndices();
1237  SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
1238 
1239  // There may be memory stores between the load and the extract op, so we
1240  // need to make sure that the new load op is inserted at the same place as
1241  // the original load op.
1242  OpBuilder::InsertionGuard g(rewriter);
1243  rewriter.setInsertionPoint(loadOp);
1244  Location loc = loadOp.getLoc();
1245  ArithIndexingBuilder idxBuilderf(rewriter, loc);
1246  for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
1247  OpFoldResult pos = extractPos[i - rankOffset];
1248  if (isZeroInteger(pos))
1249  continue;
1250 
1251  Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
1252  indices[i] = idxBuilderf.add(indices[i], offset);
1253  }
1254 
1255  Value base = loadOp.getBase();
1256  if (extractVecType) {
1257  rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
1258  indices);
1259  } else {
1260  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
1261  }
1262  // We checked for single use so we can safely erase the load op.
1263  rewriter.eraseOp(loadOp);
1264  return success();
1265  }
1266 };
1267 
1268 /// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
1269 ///
1270 /// Example:
1271 /// ```
1272 /// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
1273 /// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
1274 /// ```
1275 /// Gets converted to:
1276 /// ```
1277 /// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
1278 /// ```
1279 class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
1280 public:
1281  using Base::Base;
1282 
1283  LogicalResult matchAndRewrite(vector::StoreOp op,
1284  PatternRewriter &rewriter) const override {
1285  VectorType vecType = op.getVectorType();
1286  if (vecType.isScalable())
1287  return rewriter.notifyMatchFailure(op,
1288  "scalable vectors are not supported");
1289 
1290  if (isa<VectorType>(op.getMemRefType().getElementType()))
1291  return rewriter.notifyMatchFailure(
1292  op, "memrefs of vectors are not supported");
1293 
1294  if (vecType.getNumElements() != 1)
1295  return rewriter.notifyMatchFailure(
1296  op, "only 1-element vectors are supported");
1297 
1298  Value toStore = op.getValueToStore();
1299  Value source = getBroadcastLikeSource(toStore);
1300  if (!source)
1301  return rewriter.notifyMatchFailure(
1302  op, "value to store is not from a broadcast");
1303 
1304  // Checking for single use so we can remove broadcast.
1305  Operation *broadcast = toStore.getDefiningOp();
1306  if (!broadcast->hasOneUse())
1307  return rewriter.notifyMatchFailure(op, "expected single op use");
1308 
1309  Value base = op.getBase();
1310  ValueRange indices = op.getIndices();
1311 
1312  if (isa<VectorType>(source.getType())) {
1313  rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
1314  } else {
1315  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
1316  }
1317  rewriter.eraseOp(broadcast);
1318  return success();
1319  }
1320 };
1321 
1322 // Helper that returns a vector comparison that constructs a mask:
1323 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1324 //
1325 // If `dim == 0` then the result will be a 0-D vector.
1326 //
1327 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1328 // much more compact, IR for this operation, but LLVM eventually
1329 // generates more elaborate instructions for this intrinsic since it
1330 // is very conservative on the boundary conditions.
1331 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1332  bool force32BitVectorIndices, int64_t dim,
1333  Value b, Value *off = nullptr) {
1334  auto loc = op->getLoc();
1335  // If we can assume all indices fit in 32-bit, we perform the vector
1336  // comparison in 32-bit to get a higher degree of SIMD parallelism.
1337  // Otherwise we perform the vector comparison using 64-bit indices.
1338  Type idxType =
1339  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1340  DenseIntElementsAttr indicesAttr;
1341  if (dim == 0 && force32BitVectorIndices) {
1342  indicesAttr = DenseIntElementsAttr::get(
1344  } else if (dim == 0) {
1345  indicesAttr = DenseIntElementsAttr::get(
1347  } else if (force32BitVectorIndices) {
1348  indicesAttr = rewriter.getI32VectorAttr(
1349  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1350  } else {
1351  indicesAttr = rewriter.getI64VectorAttr(
1352  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1353  }
1354  Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
1355  // Add in an offset if requested.
1356  if (off) {
1357  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1358  Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
1359  indices = arith::AddIOp::create(rewriter, loc, ov, indices);
1360  }
1361  // Construct the vector comparison.
1362  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1363  Value bounds =
1364  vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1365  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1366  indices, bounds);
1367 }
1368 
1369 template <typename ConcreteOp>
1370 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1371 public:
1372  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1373  PatternBenefit benefit = 1)
1374  : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1375  force32BitVectorIndices(enableIndexOpt) {}
1376 
1377  LogicalResult matchAndRewrite(ConcreteOp xferOp,
1378  PatternRewriter &rewriter) const override {
1379  if (!xferOp.hasOutOfBoundsDim())
1380  return failure();
1381 
1382  if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1383  return failure();
1384 
1385  Location loc = xferOp->getLoc();
1386  VectorType vtp = xferOp.getVectorType();
1387 
1388  // Create the in-bounds mask with all elements between [0 .. dim - offset)
1389  // set and [dim - offset .. vector_length) unset.
1390  //
1391  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1392  // dimensions here.
1393  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1394  Value off = xferOp.getIndices()[lastIndex];
1395  Value dim =
1396  vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
1397  Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
1398  Value mask = vector::CreateMaskOp::create(
1399  rewriter, loc,
1400  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1401  vtp.getScalableDims()),
1402  b);
1403  if (xferOp.getMask()) {
1404  // Intersect the in-bounds with the mask specified as an op parameter.
1405  mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
1406  }
1407 
1408  rewriter.modifyOpInPlace(xferOp, [&]() {
1409  xferOp.getMaskMutable().assign(mask);
1410  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1411  });
1412 
1413  return success();
1414  }
1415 
1416 private:
1417  const bool force32BitVectorIndices;
1418 };
1419 
1420 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
1421 class VectorCreateMaskOpConversion
1422  : public OpRewritePattern<vector::CreateMaskOp> {
1423 public:
1424  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1425  bool enableIndexOpt,
1426  PatternBenefit benefit = 1)
1427  : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1428  force32BitVectorIndices(enableIndexOpt) {}
1429 
1430  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1431  PatternRewriter &rewriter) const override {
1432  auto dstType = op.getType();
1433  if (cast<VectorType>(dstType).isScalable())
1434  return failure();
1435  int64_t rank = dstType.getRank();
1436  if (rank > 1)
1437  return failure();
1438  rewriter.replaceOp(
1439  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1440  rank == 0 ? 0 : dstType.getDimSize(0),
1441  op.getOperand(0)));
1442  return success();
1443  }
1444 
1445 private:
1446  const bool force32BitVectorIndices;
1447 };
1448 
1449 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1450 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1451  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1452  // TODO: Support non-dense constant.
1453  if (!denseAttr)
1454  return false;
1455 
1456  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1457  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1458 }
1459 
1460 /// Folds a select operation between an all-true and all-false vector. For now,
1461 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1462 ///
1463 /// %true = arith.constant dense<true> : vector<1xi1>
1464 /// %false = arith.constant dense<false> : vector<1xi1>
1465 /// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1466 /// =>
1467 /// %result = vector.broadcast %cond : i1 to vector<1xi1>
1468 ///
1469 /// InstCombine seems to handle vectors with multiple elements but not the
1470 /// single element ones.
1471 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1472  using Base::Base;
1473 
1474  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1475  PatternRewriter &rewriter) const override {
1476  auto vecType = dyn_cast<VectorType>(selectOp.getType());
1477  if (!vecType || !vecType.getElementType().isInteger(1))
1478  return failure();
1479 
1480  // Only scalar conditions can be folded.
1481  Value cond = selectOp.getCondition();
1482  if (isa<VectorType>(cond.getType()))
1483  return failure();
1484 
1485  // TODO: Support n-D and scalable vectors.
1486  if (vecType.getRank() != 1 || vecType.isScalable())
1487  return failure();
1488 
1489  // TODO: Support vectors with multiple elements.
1490  if (vecType.getShape()[0] != 1)
1491  return failure();
1492 
1493  auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1494  if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1495  return failure();
1496 
1497  auto falseConst =
1498  selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1499  if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1500  return failure();
1501 
1502  // Replace select with its condition broadcasted to single element vector.
1503  auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1504  auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1505  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1506  return success();
1507  }
1508 };
1509 
1510 /// Returns the number of dims can be folded away from transfer ops. It returns
1511 /// a failure if it can not determine the number of dims to be folded.
1512 ///
1513 /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1514 /// `vectorType` is vector<16x16x1x1xf32>
1515 /// (there two inner most dims can be dropped by memref.subview ops)
1516 ///
1517 /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1518 /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1519 /// (only the inner most unit dim of `srcType` can be dropped)
1520 ///
1521 /// Ex 3: return "0" if `srcType` is memref<512x16x1x1xf32> and
1522 /// `vectorType` is vector<16x16x1x[1]xf32>
1523 /// (the most inner dim in `vectorType` is not a unit dim (it's a "scalable
1524 /// unit")
1525 static FailureOr<size_t>
1526 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1527  SmallVector<int64_t> srcStrides;
1528  int64_t srcOffset;
1529  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1530  return failure();
1531 
1532  auto isUnitDim = [](VectorType type, int dim) {
1533  return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1534  };
1535 
1536  // According to vector.transfer_read/write semantics, the vector can be a
1537  // slice. Thus, we have to offset the check index with `rankDiff` in
1538  // `srcStrides` and source dim sizes.
1539  size_t result = 0;
1540  int rankDiff = srcType.getRank() - vectorType.getRank();
1541  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1542  // Check that the inner dim size is 1 for both memref type and vector slice.
1543  // It can be folded only if they are 1 and the stride is 1.
1544  int dim = vectorType.getRank() - i - 1;
1545  if (srcStrides[dim + rankDiff] != 1 ||
1546  srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1547  break;
1548  result++;
1549  }
1550  return result;
1551 }
1552 
1553 /// Drop inner most contiguous unit dimensions from transfer_read operand.
1554 class DropInnerMostUnitDimsTransferRead
1555  : public OpRewritePattern<vector::TransferReadOp> {
1556  using Base::Base;
1557 
1558  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1559  PatternRewriter &rewriter) const override {
1560  // TODO: support 0-d corner case.
1561  if (readOp.getTransferRank() == 0)
1562  return failure();
1563 
1564  // TODO: support mask.
1565  if (readOp.getMask())
1566  return failure();
1567 
1568  auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
1569  if (!srcType)
1570  return failure();
1571 
1572  if (!readOp.getPermutationMap().isMinorIdentity())
1573  return failure();
1574 
1575  auto targetType = readOp.getVectorType();
1576  if (targetType.getRank() <= 1)
1577  return failure();
1578 
1579  FailureOr<size_t> maybeDimsToDrop =
1580  getTransferFoldableInnerUnitDims(srcType, targetType);
1581  if (failed(maybeDimsToDrop))
1582  return failure();
1583 
1584  size_t dimsToDrop = maybeDimsToDrop.value();
1585  if (dimsToDrop == 0)
1586  return failure();
1587 
1588  auto inBounds = readOp.getInBoundsValues();
1589  auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1590  if (llvm::is_contained(droppedInBounds, false))
1591  return failure();
1592 
1593  auto resultTargetVecType =
1594  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1595  targetType.getElementType(),
1596  targetType.getScalableDims().drop_back(dimsToDrop));
1597 
1598  auto loc = readOp.getLoc();
1600  memref::getMixedSizes(rewriter, loc, readOp.getBase());
1601  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1602  rewriter.getIndexAttr(0));
1603  SmallVector<OpFoldResult> strides(srcType.getRank(),
1604  rewriter.getIndexAttr(1));
1605  MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1606  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1607  strides);
1608  ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1609  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1610  Value rankedReducedView =
1611  memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1612  readOp.getBase(), offsets, sizes, strides);
1613  auto permMap = getTransferMinorIdentityMap(
1614  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1615  Value result = vector::TransferReadOp::create(
1616  rewriter, loc, resultTargetVecType, rankedReducedView,
1617  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1618  readOp.getPadding(),
1619  // TODO: support mask.
1620  /*mask=*/Value(), inBoundsAttr);
1621  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1622  result);
1623  return success();
1624  }
1625 };
1626 
1627 /// Drop inner most contiguous unit dimensions from transfer_write operand.
1628 /// E.g.,
1629 /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1630 /// {in_bounds = [true, true, true, true, true]}
1631 /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1632 ///
1633 /// will be replaced with
1634 ///
1635 /// %subview = memref.subview %arg0
1636 /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1637 /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1638 /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1639 /// to vector<1x16x16xf32>
1640 /// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1641 /// {in_bounds = [true, true, true]}
1642 /// : vector<1x16x16xf32>, memref<1x512x16xf32>
1643 ///
1644 /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1645 class DropInnerMostUnitDimsTransferWrite
1646  : public OpRewritePattern<vector::TransferWriteOp> {
1647  using Base::Base;
1648 
1649  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1650  PatternRewriter &rewriter) const override {
1651  // TODO: support 0-d corner case.
1652  if (writeOp.getTransferRank() == 0)
1653  return failure();
1654 
1655  // TODO: support mask.
1656  if (writeOp.getMask())
1657  return failure();
1658 
1659  auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
1660  if (!srcType)
1661  return failure();
1662 
1663  if (!writeOp.getPermutationMap().isMinorIdentity())
1664  return failure();
1665 
1666  auto targetType = writeOp.getVectorType();
1667  if (targetType.getRank() <= 1)
1668  return failure();
1669 
1670  FailureOr<size_t> maybeDimsToDrop =
1671  getTransferFoldableInnerUnitDims(srcType, targetType);
1672  if (failed(maybeDimsToDrop))
1673  return failure();
1674 
1675  size_t dimsToDrop = maybeDimsToDrop.value();
1676  if (dimsToDrop == 0)
1677  return failure();
1678 
1679  auto inBounds = writeOp.getInBoundsValues();
1680  auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1681  if (llvm::is_contained(droppedInBounds, false))
1682  return failure();
1683 
1684  auto resultTargetVecType =
1685  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1686  targetType.getElementType(),
1687  targetType.getScalableDims().drop_back(dimsToDrop));
1688 
1689  Location loc = writeOp.getLoc();
1691  memref::getMixedSizes(rewriter, loc, writeOp.getBase());
1692  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1693  rewriter.getIndexAttr(0));
1694  SmallVector<OpFoldResult> strides(srcType.getRank(),
1695  rewriter.getIndexAttr(1));
1696  MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1697  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1698  strides);
1699  ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1700  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1701 
1702  Value rankedReducedView =
1703  memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1704  writeOp.getBase(), offsets, sizes, strides);
1705  auto permMap = getTransferMinorIdentityMap(
1706  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1707 
1708  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1709  loc, resultTargetVecType, writeOp.getVector());
1710  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1711  writeOp, shapeCast, rankedReducedView,
1712  writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1713  // TODO: support mask.
1714  /*mask=*/Value(), inBoundsAttr);
1715  return success();
1716  }
1717 };
1718 
1719 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1720 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1721 /// with the RHS transposed) lowering.
1722 struct CanonicalizeContractMatmulToMMT final
1723  : OpRewritePattern<vector::ContractionOp> {
1724  using Base::Base;
1725 
1726  using FilterConstraintType =
1727  std::function<LogicalResult(vector::ContractionOp op)>;
1728 
1729  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1730  FilterConstraintType constraint)
1731  : OpRewritePattern<vector::ContractionOp>(context, benefit),
1732  filter(std::move(constraint)) {}
1733 
1734  LogicalResult matchAndRewrite(vector::ContractionOp op,
1735  PatternRewriter &rewriter) const override {
1736  if (failed(filter(op)))
1737  return failure();
1738 
1739  Location loc = op.getLoc();
1740  Value lhs = op.getLhs();
1741  Value rhs = op.getRhs();
1742  Value res = op.getAcc();
1743 
1744  // Set up the parallel/reduction structure in right form.
1745  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1746  auto infer = [&](MapList m) {
1747  return AffineMap::inferFromExprList(m, op.getContext());
1748  };
1749  AffineExpr m;
1750  AffineExpr n;
1751  AffineExpr k;
1752  bindDims(rewriter.getContext(), m, n, k);
1753  static constexpr std::array<int64_t, 2> perm = {1, 0};
1754  auto iteratorTypes = op.getIteratorTypes().getValue();
1755  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1756  if (iteratorTypes.size() != 3 ||
1757  !vector::isParallelIterator(iteratorTypes[0]) ||
1758  !vector::isParallelIterator(iteratorTypes[1]) ||
1759  !vector::isReductionIterator(iteratorTypes[2]))
1760  return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1761 
1762  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1763  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1764  if (maps == canonicalForm)
1765  return rewriter.notifyMatchFailure(op, "already in the canonical form");
1766 
1767  // Create a vector transpose making sure to emit zero/sign-extend at the
1768  // end.
1769  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1770  if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1771  Value trans =
1772  vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
1773  VectorType newType =
1774  cast<VectorType>(trans.getType())
1775  .clone(cast<VectorType>(mat.getType()).getElementType());
1776  return arith::ExtSIOp::create(rewriter, loc, newType, trans);
1777  }
1778  if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1779  Value trans =
1780  vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
1781  VectorType newType =
1782  VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1783  cast<VectorType>(mat.getType()).getElementType());
1784  return arith::ExtUIOp::create(rewriter, loc, newType, trans);
1785  }
1786  return vector::TransposeOp::create(rewriter, loc, mat, perm);
1787  };
1788 
1789  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1790  rhs = createTranspose(rhs);
1791  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1792  lhs = createTranspose(lhs);
1793  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1794  rhs = createTranspose(rhs);
1795  lhs = createTranspose(lhs);
1796  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1797  std::swap(rhs, lhs);
1798  rhs = createTranspose(rhs);
1799  lhs = createTranspose(lhs);
1800  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1801  std::swap(rhs, lhs);
1802  rhs = createTranspose(rhs);
1803  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1804  std::swap(lhs, rhs);
1805  lhs = createTranspose(lhs);
1806  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1807  std::swap(lhs, rhs);
1808  } else {
1809  return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1810  }
1811  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1812  op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1813  op.getIteratorTypes());
1814  return success();
1815  };
1816 
1817 private:
1818  FilterConstraintType filter;
1819 };
1820 
1821 /// Pattern to fold arithmetic extensions on floating point data types into
1822 /// vector contraction operations. linalg.matmul introduces arithmetic
1823 /// extensions on its operands. Please mlir snippets below for more details.
1824 /// ```mlir
1825 /// "linalg.matmul"(%lhs, %rhs, %acc) ({
1826 /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1827 /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1828 /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1829 /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1830 /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1831 /// "linalg.yield"(%acc) : (f32) -> ()
1832 /// })
1833 /// ```
1834 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1835 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1836 /// This pattern folds the arithmetic extensions into the vector contraction and
1837 /// enables the usage of native mixed precision Tensor Core instructions.
1838 template <typename ExtOp>
1839 struct FoldArithExtIntoContractionOp
1840  : public OpRewritePattern<vector::ContractionOp> {
1841  using Base::Base;
1842 
1843  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1844  PatternRewriter &rewriter) const override {
1845 
1846  auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1847  auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1848 
1849  if (!lhsDefOp || !rhsDefOp) {
1850  return rewriter.notifyMatchFailure(contractOp,
1851  "no defining op on contract operands");
1852  }
1853 
1854  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1855  contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1856  contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1857  contractOp.getIteratorTypesAttr());
1858 
1859  return success();
1860  }
1861 };
1862 
1863 /// Pattern to fold chained reduction to a series of vector additions and a
1864 /// final reduction. This form should require fewer subgroup operations.
1865 ///
1866 /// ```mlir
1867 /// %a = vector.reduction <add> %x, %acc
1868 /// %b = vector.reduction <add> %y, %a
1869 /// ==>
1870 /// %a = arith.addf %x, %y
1871 /// %b = vector.reduction <add> %a, %acc
1872 /// ```
1873 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1874  using Base::Base;
1875 
1876  LogicalResult matchAndRewrite(vector::ReductionOp op,
1877  PatternRewriter &rewriter) const override {
1878  // TODO: Handle other combining kinds.
1879  if (op.getKind() != vector::CombiningKind::ADD)
1880  return failure();
1881 
1882  // Accumulator is optional.
1883  Value acc = op.getAcc();
1884  if (!acc)
1885  return failure();
1886 
1887  if (!acc.getType().isIntOrFloat())
1888  return failure();
1889 
1890  auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1891  if (!parentReduction)
1892  return failure();
1893 
1894  Location loc = op.getLoc();
1895  Value vAdd;
1896  if (isa<IntegerType>(acc.getType())) {
1897  vAdd = rewriter.createOrFold<arith::AddIOp>(
1898  loc, parentReduction.getVector(), op.getVector());
1899  } else {
1900  vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
1901  op.getVector());
1902  }
1903  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1904  parentReduction.getAcc());
1905  return success();
1906  }
1907 };
1908 
1909 // Helper function dropping unit non-scalable dimension from a VectorType
1910 // keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
1911 // dimensions are not dropped. Folding such dimensions would require "shifting"
1912 // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1913 // vector<[4]xf32>). This could be implemented in the future.
1914 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1915  auto inVecShape = inVecTy.getShape();
1916  SmallVector<int64_t> newShape;
1917  SmallVector<bool> newScalableDims;
1918  for (auto [dim, isScalable] :
1919  llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1920  if (dim == 1 && !isScalable)
1921  continue;
1922 
1923  newShape.push_back(dim);
1924  newScalableDims.push_back(isScalable);
1925  }
1926  // All dims have been dropped, return vector<1xeType>.
1927  if (newShape.empty()) {
1928  newShape.push_back(1);
1929  newScalableDims.push_back(false);
1930  }
1931 
1932  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1933 }
1934 
1935 /// For vectors with at least one unit dim, replaces:
1936 /// elementwise(a, b)
1937 /// with:
1938 /// sc_a = shape_cast(a)
1939 /// sc_b = shape_cast(b)
1940 /// res = elementwise(sc_a, sc_b)
1941 /// return shape_cast(res)
1942 /// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1943 /// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1944 /// required to be rank > 1.
1945 ///
1946 /// Ex:
1947 /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1948 /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1949 ///
1950 /// gets converted to:
1951 ///
1952 /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1953 /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1954 /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1955 /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1956 /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1957 ///
1958 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1959 /// `%cast`.
1960 struct DropUnitDimFromElementwiseOps final
1961  : public OpTraitRewritePattern<OpTrait::Elementwise> {
1963  LogicalResult matchAndRewrite(Operation *op,
1964  PatternRewriter &rewriter) const override {
1965  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1966  return failure();
1967 
1968  auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1969  if (!resultVectorType)
1970  return failure();
1971 
1972  // Check the operand pre-conditions. For `Elementwise` ops all operands are
1973  // guaranteed to have identical shapes (with some exceptions such as
1974  // `arith.select`) and it suffices to only check one of them.
1975  auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1976  if (!sourceVectorType)
1977  return failure();
1978  if (sourceVectorType.getRank() < 2)
1979  return failure();
1980 
1981  SmallVector<Value> newOperands;
1982  auto loc = op->getLoc();
1983  for (auto operand : op->getOperands()) {
1984  auto opVectorType = cast<VectorType>(operand.getType());
1985  auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1986  if (newVType == opVectorType)
1987  return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1988 
1989  auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
1990  newOperands.push_back(opSC);
1991  }
1992 
1993  VectorType newResultVectorType =
1994  dropNonScalableUnitDimFromType(resultVectorType);
1995  // Create an updated elementwise Op without unit dim.
1996  Operation *elementwiseOp =
1997  rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1998  newResultVectorType, op->getAttrs());
1999 
2000  // Restore the unit dim by applying vector.shape_cast to the result.
2001  rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
2002  elementwiseOp->getResult(0));
2003 
2004  return success();
2005  }
2006 };
2007 
2008 /// A pattern to drop unit dims from vector.transpose.
2009 ///
2010 /// Example:
2011 ///
2012 /// BEFORE:
2013 /// ```mlir
2014 /// %transpose = vector.transpose %vector, [3, 0, 1, 2]
2015 /// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
2016 /// ```
2017 ///
2018 /// AFTER:
2019 /// ```mlir
2020 /// %dropDims = vector.shape_cast %vector
2021 /// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
2022 /// %transpose = vector.transpose %0, [1, 0]
2023 /// : vector<4x[4]xf32> to vector<[4]x4xf32>
2024 /// %restoreDims = vector.shape_cast %transpose
2025 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2026 /// ```
2027 struct DropUnitDimsFromTransposeOp final
2028  : OpRewritePattern<vector::TransposeOp> {
2029  using Base::Base;
2030 
2031  LogicalResult matchAndRewrite(vector::TransposeOp op,
2032  PatternRewriter &rewriter) const override {
2033  VectorType sourceType = op.getSourceVectorType();
2034  VectorType sourceTypeWithoutUnitDims =
2035  dropNonScalableUnitDimFromType(sourceType);
2036 
2037  if (sourceType == sourceTypeWithoutUnitDims)
2038  return failure();
2039 
2040  // Construct a map from dimIdx -> number of dims dropped before dimIdx.
2041  auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
2042  SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
2043  int64_t droppedDims = 0;
2044  for (auto [i, dim] : llvm::enumerate(sourceDims)) {
2045  droppedDimsBefore[i] = droppedDims;
2046  if (dim == std::make_tuple(1, false))
2047  ++droppedDims;
2048  }
2049 
2050  // Drop unit dims from transpose permutation.
2051  ArrayRef<int64_t> perm = op.getPermutation();
2052  SmallVector<int64_t> newPerm;
2053  for (int64_t idx : perm) {
2054  if (sourceDims[idx] == std::make_tuple(1, false))
2055  continue;
2056  newPerm.push_back(idx - droppedDimsBefore[idx]);
2057  }
2058 
2059  // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
2060  // type when the dimensions are unit dimensions. In this case, the newPerm
2061  // should be [0].
2062  if (newPerm.empty()) {
2063  newPerm.push_back(0);
2064  }
2065 
2066  Location loc = op.getLoc();
2067  // Drop the unit dims via shape_cast.
2068  auto dropDimsShapeCast = vector::ShapeCastOp::create(
2069  rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
2070  // Create the new transpose.
2071  auto transposeWithoutUnitDims =
2072  vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
2073  // Restore the unit dims via shape cast.
2074  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2075  op, op.getResultVectorType(), transposeWithoutUnitDims);
2076 
2077  return success();
2078  }
2079 };
2080 
2081 /// A pattern to drop unit dims from the iter_args of an scf.for.
2082 ///
2083 /// Example:
2084 ///
2085 /// BEFORE:
2086 /// ```mlir
2087 /// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
2088 /// ...
2089 /// scf.yield %
2090 /// }
2091 /// ```
2092 ///
2093 /// AFTER:
2094 /// ```mlir
2095 /// %drop = vector.shape_cast %init
2096 /// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
2097 /// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
2098 /// %new_iter = vector.shape_cast %iter
2099 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2100 /// ...
2101 /// }
2102 /// %res = vector.shape_cast %new_loop
2103 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2104 /// ```
2105 struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
2106  using Base::Base;
2107 
2108  LogicalResult matchAndRewrite(scf::ForOp forOp,
2109  PatternRewriter &rewriter) const override {
2110  /// Find the first iter_arg with droppable unit dims. Further applications
2111  /// of this pattern will apply to later arguments.
2112  for (OpOperand &operand : forOp.getInitArgsMutable()) {
2113  auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2114  if (!vectorType)
2115  continue;
2116 
2117  VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
2118  if (vectorType == newVectorType)
2119  continue;
2120 
2121  // Create a new ForOp with that iter operand replaced.
2122  auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
2123  return vector::ShapeCastOp::create(b, loc, type, source);
2124  };
2125 
2126  Value replacement =
2127  castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2128  rewriter.replaceOp(forOp,
2129  replaceAndCastForOpIterArg(rewriter, forOp, operand,
2130  replacement, castFn));
2131  return success();
2132  }
2133  return failure();
2134  }
2135 };
2136 
2137 /// Pattern to eliminate redundant zero-constants added to reduction operands.
2138 /// It's enough for there to be one initial zero value, so we can eliminate the
2139 /// extra ones that feed into `vector.reduction <add>`. These get created by the
2140 /// `ChainedReduction` pattern.
2141 ///
2142 /// ```mlir
2143 /// %a = arith.addf %x, %zero
2144 /// %b = arith.addf %a, %y
2145 /// %c = vector.reduction <add> %b, %acc
2146 /// ==>
2147 /// %b = arith.addf %a, %y
2148 /// %c = vector.reduction <add> %b, %acc
2149 /// ```
2150 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
2151  using Base::Base;
2152 
2153  LogicalResult matchAndRewrite(vector::ReductionOp op,
2154  PatternRewriter &rewriter) const override {
2155  // TODO: Handle other reduction kinds and their identity values.
2156  if (op.getKind() != vector::CombiningKind::ADD)
2157  return failure();
2158 
2159  Type elemType = op.getSourceVectorType().getElementType();
2160  // The integer case should be handled by `arith.addi` folders, only check
2161  // for floats here.
2162  if (!isa<FloatType>(elemType))
2163  return failure();
2164 
2165  auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2166  if (!vAdd)
2167  return failure();
2168  auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2169  if (!addLhs)
2170  return failure();
2171 
2172  if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
2173  return failure();
2174 
2175  auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
2176  addLhs.getLhs(), vAdd.getRhs());
2177  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
2178  op.getAcc());
2179  return success();
2180  }
2181 };
2182 
2183 /// Example:
2184 /// ```
2185 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
2186 /// ```
2187 /// is transformed into:
2188 /// ```
2189 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
2190 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
2191 /// %a = arith.addf %y, %z : f32
2192 /// ```
2193 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
2194  BreakDownVectorReduction(MLIRContext *context,
2195  unsigned maxNumElementsToExtract,
2196  PatternBenefit benefit)
2197  : OpRewritePattern(context, benefit),
2198  maxNumElementsToExtract(maxNumElementsToExtract) {}
2199 
2200  LogicalResult matchAndRewrite(vector::ReductionOp op,
2201  PatternRewriter &rewriter) const override {
2202  VectorType type = op.getSourceVectorType();
2203  if (type.isScalable() || op.isMasked())
2204  return failure();
2205  assert(type.getRank() == 1 && "Expected a 1-d vector");
2206 
2207  int64_t numElems = type.getNumElements();
2208  if (numElems > maxNumElementsToExtract) {
2209  return rewriter.notifyMatchFailure(
2210  op, llvm::formatv("has too many vector elements ({0}) to break down "
2211  "(max allowed: {1})",
2212  numElems, maxNumElementsToExtract));
2213  }
2214 
2215  Location loc = op.getLoc();
2216  SmallVector<Value> extracted(numElems, nullptr);
2217  for (auto [idx, extractedElem] : llvm::enumerate(extracted))
2218  extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
2219  static_cast<int64_t>(idx));
2220 
2221  Value res = extracted.front();
2222  for (auto extractedElem : llvm::drop_begin(extracted))
2223  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
2224  extractedElem, op.getFastmathAttr());
2225  if (Value acc = op.getAcc())
2226  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
2227  op.getFastmathAttr());
2228 
2229  rewriter.replaceOp(op, res);
2230  return success();
2231  }
2232 
2233 private:
2234  unsigned maxNumElementsToExtract = 0;
2235 };
2236 
2237 /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
2238 /// B)`.
2239 /// Example:
2240 /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
2241 /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
2242 /// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
2243 /// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
2244 ///
2245 /// Becomes :
2246 ///
2247 /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
2248 ///
2249 /// Supports only 1D-to-2D broadcasts. The following cases are not supported.
2250 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
2251 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
2252 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
2253 template <typename MulOpType>
2254 struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
2256  // Returns whether a vector.broadcast matches requirements for an outerproduct
2257  // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
2258  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
2259  // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
2260  // shape_casts/broadcasts which does not belong in this pattern.
2261  if (!broadcastOp.computeBroadcastedUnitDims().empty())
2262  return false;
2263  // Avoid broadcast like f32 or vector<f32> -> ResType
2264  auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2265  return srcType && srcType.getRank() != 2;
2266  }
2267 
2268  LogicalResult matchAndRewrite(MulOpType mulOp,
2269  PatternRewriter &rewriter) const override {
2270  auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2271  if (!resType)
2272  return failure();
2273  if (resType.getRank() != 2)
2274  return failure();
2275  /// If operandA can be written as tr(broadcast(A)) and operandB as
2276  /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2277  /// vector.outerproduct(A, B). Returns failure() otherwise.
2278  auto matchOuterProduct =
2279  [&](Value operandA,
2280  Value operandB) -> FailureOr<vector::OuterProductOp> {
2281  auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2282  if (!transposedLhs)
2283  return failure();
2284  // Fail unless this is a true 2-D matrix transpose.
2285  ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2286  if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2287  return failure();
2288 
2289  auto broadcastedLhs =
2290  transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2291  if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2292  return failure();
2293 
2294  auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2295  if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2296  return failure();
2297 
2298  return vector::OuterProductOp::create(
2299  rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2300  broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2301  };
2302 
2303  Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2304  auto maybeOuterP = matchOuterProduct(lhs, rhs);
2305  // Handle commutativity, the transposed op is the outerproduct LHS.
2306  if (failed(maybeOuterP))
2307  maybeOuterP = matchOuterProduct(rhs, lhs);
2308  if (failed(maybeOuterP))
2309  return failure();
2310  rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2311  return success();
2312  }
2313 };
2314 
2315 } // namespace
2316 
2319  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2320  FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2321  patterns.getContext());
2322 }
2323 
2324 void mlir::vector::populateVectorMaskMaterializationPatterns(
2325  RewritePatternSet &patterns, bool force32BitVectorIndices,
2326  PatternBenefit benefit) {
2327  patterns.add<VectorCreateMaskOpConversion,
2328  MaterializeTransferMask<vector::TransferReadOp>,
2329  MaterializeTransferMask<vector::TransferWriteOp>>(
2330  patterns.getContext(), force32BitVectorIndices, benefit);
2331  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2332 }
2333 
2334 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2336  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2337  DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
2338 }
2339 
2340 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2342  patterns.add<BubbleDownVectorBitCastForExtract,
2343  BubbleDownBitCastForStridedSliceExtract,
2344  BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2345  patterns.getContext(), benefit);
2346 }
2347 
2348 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2350  std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2351  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2352  std::move(controlFn), benefit);
2353 }
2354 
2357  std::function<LogicalResult(vector::ContractionOp)> constraint,
2358  PatternBenefit benefit) {
2359  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2360  std::move(constraint));
2361 }
2362 
2365  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
2366  CombineContractABTranspose, CombineContractResultTranspose>(
2367  patterns.getContext(), benefit);
2368 }
2369 
2372  patterns.add<DropInnerMostUnitDimsTransferRead,
2373  DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
2374  benefit);
2375 }
2376 
2378  PatternBenefit benefit) {
2379  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2380  ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2381  patterns.getContext(), benefit);
2382 }
2383 
2384 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
2385  PatternBenefit benefit) {
2386  // TODO: Consider converting these patterns to canonicalizations.
2387  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
2388  benefit);
2389 }
2390 
2391 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2393  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2394  patterns.add<ReduceRedundantZero>(patterns.getContext(),
2395  PatternBenefit(benefit.getBenefit() + 1));
2396 }
2397 
2398 void mlir::vector::populateBreakDownVectorReductionPatterns(
2399  RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2400  PatternBenefit benefit) {
2401  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2402  maxNumElementsToExtract, benefit);
2403 }
2404 
2407  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2408  FoldArithToVectorOuterProduct<arith::MulIOp>>(
2409  patterns.getContext());
2410 }
2411 
2412 //===----------------------------------------------------------------------===//
2413 // TableGen'd enum attribute definitions
2414 //===----------------------------------------------------------------------===//
2415 
2416 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
static uint64_t zext(uint32_t arg)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:411
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
Definition: AffineMap.cpp:398
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:387
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:122
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:128
IntegerType getI1Type()
Definition: Builders.cpp:53
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:281
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:270
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:318
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
Definition: PatternMatch.h:348
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:354
This class provides the API for a sub-set of ops that are known to be constant-like.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:116
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1395
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:833
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:155
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Definition: VectorUtils.h:130
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:189
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the innermost unit dims in xfer Ops.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:150
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g. elementwise Ops.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Definition: VectorUtils.cpp:39
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:119
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:399
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
Definition: AffineMap.cpp:710
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:923
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
ArithBuilder specialized specifically for tensor/memref indexing calculations.
Definition: Utils.h:126
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:333
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:163