1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cassert>
16 #include <cstdint>
17 #include <functional>
18 #include <optional>
19 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/Location.h"
31 #include "mlir/IR/Matchers.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/TypeUtilities.h"
34 
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/FormatVariadic.h"
37 
38 #define DEBUG_TYPE "vector-to-vector"
39 
40 using namespace mlir;
41 using namespace mlir::vector;
42 
43 template <typename IntType>
44 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
45  return llvm::to_vector<4>(llvm::map_range(
46  arrayAttr.getAsRange<IntegerAttr>(),
47  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
48 }
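// Illustrative example (names are hypothetical): for an attribute built with
// rewriter.getI64ArrayAttr({2, 3}), extractVector<int64_t>(attr) returns the
// SmallVector {2, 3}.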
49 
50 // Helper to find the result position in `map` that maps to dimension `index`.
51 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
52  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
53  int64_t idx = map.getDimPosition(i);
54  if (idx == index)
55  return i;
56  }
57  return std::nullopt;
58 }
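// Illustrative example (hypothetical map): for
//   affine_map<(d0, d1, d2) -> (d2, d0)>
// getResultIndex(map, /*index=*/0) returns 1 (d0 is result #1), while
// getResultIndex(map, /*index=*/1) returns std::nullopt (d1 is not a result).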
59 
60 namespace {
61 
62 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
63 /// Ex:
64 /// ```
65 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
66 /// %1 = vector.multi_reduction add, %0 [1]
67 /// : vector<8x32x16xf32> to vector<8x16xf32>
68 /// ```
69 /// Gets converted to:
70 /// ```
71 /// %1 = vector.contract {indexing_maps = [
72 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
73 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
74 /// affine_map<(d0, d1, d2) -> (d0, d2)>],
75 /// iterator_types = ["parallel", "reduction", "parallel"],
76 /// kind = add} %arg0, %arg1, %cst_f0
77 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
78 /// ```
79 struct MultiReduceToContract
80  : public OpRewritePattern<vector::MultiDimReductionOp> {
81  using OpRewritePattern::OpRewritePattern;
82 
83  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
84  PatternRewriter &rewriter) const override {
85  if (reduceOp.getKind() != vector::CombiningKind::ADD)
86  return failure();
87  Operation *mulOp = reduceOp.getSource().getDefiningOp();
88  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
89  return failure();
90  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
91  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
92  SmallVector<AffineExpr> exprs;
93  SmallVector<vector::IteratorType> iteratorTypes;
94  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
95  if (!isReduceDim.value()) {
96  iteratorTypes.push_back(vector::IteratorType::parallel);
97  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
98  } else {
99  iteratorTypes.push_back(vector::IteratorType::reduction);
100  }
101  }
102  auto dstMap =
103  AffineMap::get(/*dimCount=*/reductionMask.size(),
104  /*symbolCount=*/0, exprs, reduceOp.getContext());
105  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
106  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
107  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
108  rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
109  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
110  return IteratorTypeAttr::get(rewriter.getContext(), t);
111  }))));
112  return success();
113  }
114 };
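// Usage sketch (an assumption for exposition; the populate* helpers appear
// later in this file): patterns in this anonymous namespace are collected
// into a RewritePatternSet and applied greedily, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<MultiReduceToContract, CombineContractABTranspose,
//                CombineContractResultTranspose>(ctx);
//   (void)applyPatternsGreedily(rootOp, std::move(patterns));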
115 
116 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
117 /// Ex:
118 /// ```
119 /// %0 = vector.transpose %arg0, [2, 0, 1]
120 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
121 /// %1 = vector.contract {indexing_maps = [
122 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
123 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
124 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
125 /// iterator_types = ["parallel", "parallel", "reduction"],
126 /// kind = add} %0, %arg1, %cst_f0
127 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
128 /// ```
129 /// Gets converted to:
130 /// ```
131 /// %1 = vector.contract {indexing_maps = [
132 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
133 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
135 /// iterator_types = ["parallel", "parallel", "reduction"],
136 /// kind = add} %arg0, %arg1, %cst_f0
137 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
138 /// ```
139 struct CombineContractABTranspose final
140  : public OpRewritePattern<vector::ContractionOp> {
141  using OpRewritePattern::OpRewritePattern;
142 
143  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
144  PatternRewriter &rewriter) const override {
145  SmallVector<AffineMap> maps =
146  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
147  Value lhs = contractOp.getLhs();
148  Value rhs = contractOp.getRhs();
149  size_t index = 0;
150  bool changed = false;
151  for (Value *operand : {&lhs, &rhs}) {
152  AffineMap &map = maps[index++];
153  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
154  if (!transposeOp)
155  continue;
156  AffineMap permutationMap = AffineMap::getPermutationMap(
157  transposeOp.getPermutation(), contractOp.getContext());
158  map = inversePermutation(permutationMap).compose(map);
159  *operand = transposeOp.getVector();
160  changed = true;
161  }
162  if (!changed)
163  return failure();
164  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
165  contractOp, lhs, rhs, contractOp.getAcc(),
166  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
167  return success();
168  }
169 };
170 
171 /// Merges accumulator and result transposes into contract.
172 ///
173 /// For example:
174 /// ```mlir
175 /// %accT = vector.transpose %acc, [0, 2, 1]
176 /// : vector<2x8x4xf32> to vector<2x4x8xf32>
177 /// %contract = vector.contract {
178 /// indexing_maps = [
179 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
180 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
181 /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
182 /// ],
183 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
184 /// kind = #vector.kind<add>
185 /// } %lhs, %rhs, %accT
186 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
187 /// %0 = vector.transpose %contract, [0, 2, 1]
188 /// : vector<2x4x8xf32> to vector<2x8x4>
189 /// ```
190 /// Becomes:
191 /// ```mlir
192 /// %0 = vector.contract {
193 /// indexing_maps = [
194 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
195 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
196 /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
197 /// ],
198 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
199 /// kind = #vector.kind<add>
200 /// } %lhs, %rhs, %acc
201 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
202 /// ```
203 struct CombineContractResultTranspose final
204  : public OpRewritePattern<vector::TransposeOp> {
205  using OpRewritePattern::OpRewritePattern;
206 
207  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
208  PatternRewriter &rewriter) const override {
209  auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
210  if (!contractOp || !contractOp->hasOneUse())
211  return failure();
212 
213  auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
214  if (!accTOp)
215  return failure();
216 
217  MLIRContext *context = contractOp.getContext();
218  auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
219  AffineMap contractMap = maps.back();
220 
221  // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
222  // To index into A in the contract, we need inverse(f)(g(C)) -> A.
223  auto accTMap =
224  AffineMap::getPermutationMap(accTOp.getPermutation(), context);
225 
226  // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
227  // To index into E in contract, we need h(g(C)) -> E.
228  auto resTMap =
229  AffineMap::getPermutationMap(resTOp.getPermutation(), context);
230  auto combinedResMap = resTMap.compose(contractMap);
231 
232  // The accumulator and result share the same indexing map, so they must use
233  // the same permutation for the merge to be valid. This means combinedResMap
234  // must equal inversePermutation(accTMap).compose(contractMap), i.e.:
235  if (inversePermutation(accTMap) != resTMap)
236  return failure();
237  maps.back() = combinedResMap;
238 
239  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
240  resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
241  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
242  return success();
243  }
244 };
245 
246 /// Merge BroadcastOp into ContractionOp user.
247 /// Ex:
248 /// ```
249 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
250 /// %1 = vector.contract {indexing_maps = [
251 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
252 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
253 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
254 /// iterator_types = ["parallel", "parallel", "reduction"],
255 /// kind = add} %0, %arg1, %cst_f0
256 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
257 /// ```
258 /// Gets converted to:
259 /// ```
260 /// %1 = vector.contract {indexing_maps = [
261 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
262 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
263 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
264 /// iterator_types = ["parallel", "parallel", "reduction"],
265 /// kind = add} %arg0, %arg1, %cst_f0
266 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267 /// ```
268 struct CombineContractBroadcast
269  : public OpRewritePattern<vector::ContractionOp> {
270  using OpRewritePattern::OpRewritePattern;
271 
272  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
273  PatternRewriter &rewriter) const override {
274  SmallVector<AffineMap> maps =
275  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
276  Value lhs = contractOp.getLhs();
277  Value rhs = contractOp.getRhs();
278  size_t index = 0;
279  bool changed = false;
280  for (Value *operand : {&lhs, &rhs}) {
281  AffineMap &map = maps[index++];
282  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
283  if (!broadcast)
284  continue;
285  // ContractionOp can only take vectors as operands.
286  auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
287  if (!srcType ||
288  srcType.getRank() == broadcast.getResultVectorType().getRank())
289  continue;
290  int64_t rankDiff =
291  broadcast.getResultVectorType().getRank() - srcType.getRank();
292  bool innerDimBroadcast = false;
293  SmallVector<AffineExpr> originalDims;
294  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
295  if (dim.value() != broadcast.getResultVectorType().getDimSize(
296  rankDiff + dim.index())) {
297  innerDimBroadcast = true;
298  break;
299  }
300  originalDims.push_back(
301  rewriter.getAffineDimExpr(dim.index() + rankDiff));
302  }
303  // Contract doesn't support inner dimension broadcast. Once this is
304  // relaxed we can remove this case.
305  if (innerDimBroadcast)
306  continue;
307 
308  // It would be incorrect to fold a broadcast onto a reduction dimension
309  // of non-unit size.
310  bool nonUnitDimReductionBroadcast = false;
311  for (int64_t i = 0; i < rankDiff; ++i) {
312  if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
313  isReductionIterator(contractOp.getIteratorTypes()
314  .getValue()[map.getDimPosition(i)])) {
315  nonUnitDimReductionBroadcast = true;
316  break;
317  }
318  }
319  if (nonUnitDimReductionBroadcast)
320  continue;
321 
322  AffineMap broadcastMap =
323  AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
324  originalDims, contractOp.getContext());
325  map = broadcastMap.compose(map);
326  *operand = broadcast.getSource();
327  changed = true;
328  }
329 
330  if (!changed)
331  return failure();
332 
333  // Determine which dims are unused, now that the maps have been composed
334  // with the broadcast maps.
335  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
336  // Compress unused dims.
337  for (auto &m : maps)
338  m = compressDims(m, unusedDimsBitVector);
339  // Compute the combined iterators.
340  SmallVector<Attribute> iterators;
341  for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
342  if (!unusedDimsBitVector.test(i))
343  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
344  }
345  // Check that compressing unused dims isn't removing all reduction dimension
346  // pairs. For example, if the vector.contract had only one reduction
347  // iterator and that was a unit-dimension created by a broadcast,
348  // then we should bail here, otherwise we would create a contract without
349  // a reduction dimension pair.
350  bool hasReductionIteratorApplyingOnBothSides = false;
351  for (unsigned i = 0; i < iterators.size(); ++i) {
352  if (!isReductionIterator(iterators[i]))
353  continue;
354  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
355  hasReductionIteratorApplyingOnBothSides = true;
356  break;
357  }
358  }
359  if (!hasReductionIteratorApplyingOnBothSides)
360  return failure();
361 
362  // If the compressed maps have a dimension that is not used by either LHS or
363  // RHS then the ContractionOp verifier would fail.
364  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
365  return failure();
366  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
367  contractOp, lhs, rhs, contractOp.getAcc(),
368  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
369  return success();
370  }
371 };
372 
373 /// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops and
374 /// contraction ops closer together, so that the CombineContractBroadcast
375 /// pattern can kick in when cast ops sit between these operations.
376 /// Ex:
377 /// ```
378 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
379 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
380 /// ```
381 /// Gets converted to:
382 /// ```
383 /// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
384 /// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
385 /// ```
386 struct ReorderCastOpsOnBroadcast
387  : public OpInterfaceRewritePattern<CastOpInterface> {
388  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
389 
390  LogicalResult matchAndRewrite(CastOpInterface op,
391  PatternRewriter &rewriter) const override {
392  if (op->getNumOperands() != 1)
393  return failure();
394  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
395  if (!bcastOp)
396  return failure();
397 
398  Type castResTy = getElementTypeOrSelf(op->getResult(0));
399  if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
400  castResTy = vecTy.clone(castResTy);
401  auto *castOp =
402  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
403  bcastOp.getSource(), castResTy, op->getAttrs());
404  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
405  op, op->getResult(0).getType(), castOp->getResult(0));
406  return success();
407  }
408 };
409 
410 /// Reorders elementwise(transpose) to transpose(elementwise). This moves
411 /// transpose ops and contraction ops closer together, so that the
412 /// CombineContractABTranspose pattern can kick in when elementwise ops sit
413 /// between these operations. Ex:
414 /// ```
415 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
416 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
417 /// %r = arith.addf %at, %bt : vector<2x4xf32>
418 /// ```
419 /// Gets converted to:
420 /// ```
421 /// %0 = arith.addf %a, %b : vector<4x2xf32>
422 /// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
423 /// ```
424 struct ReorderElementwiseOpsOnTranspose final
425  : public OpTraitRewritePattern<OpTrait::Elementwise> {
426  using OpTraitRewritePattern::OpTraitRewritePattern;
427  LogicalResult matchAndRewrite(Operation *op,
428  PatternRewriter &rewriter) const override {
429  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
430  return failure();
431 
432  // Make sure all operands are transpose/constant ops and collect their
433  // transposition maps.
434  SmallVector<ArrayRef<int64_t>> transposeMaps;
435  transposeMaps.reserve(op->getNumOperands());
436  // Record the initial type before transposition. We'll use its shape later.
437  // Any type will do here as we will check all transpose maps are the same.
438  VectorType srcType;
439  for (Value operand : op->getOperands()) {
440  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
441  if (transposeOp) {
442  transposeMaps.push_back(transposeOp.getPermutation());
443  srcType = transposeOp.getSourceVectorType();
444  } else if (!matchPattern(operand, m_Constant())) {
445  return failure();
446  }
447  }
448  if (transposeMaps.empty())
449  return failure();
450  // This is an elementwise op, so all transposed operands should have the
451  // same type. We need to additionally check that all transposes use the
452  // same map.
453  if (!llvm::all_equal(transposeMaps))
454  return rewriter.notifyMatchFailure(op, "different transpose map");
455 
456  SmallVector<Value> srcValues;
457  srcValues.reserve(op->getNumOperands());
458 
459  // If there are constant operands, we need to insert inverse transposes for
460  // them. Calculate the inverse order first.
461  auto order = transposeMaps.front();
462  SmallVector<int64_t> invOrder(order.size());
463  for (int i = 0, e = order.size(); i < e; ++i)
464  invOrder[order[i]] = i;
465 
466  for (Value operand : op->getOperands()) {
467  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
468  if (transposeOp) {
469  srcValues.push_back(transposeOp.getVector());
470  } else {
471  // This is a constant. Create a reverse transpose op for it.
472  auto vectorType =
473  srcType.clone(cast<VectorType>(operand.getType()).getElementType());
474  srcValues.push_back(rewriter.create<vector::TransposeOp>(
475  operand.getLoc(), vectorType, operand, invOrder));
476  }
477  }
478 
479  auto vectorType = srcType.clone(
480  cast<VectorType>(op->getResultTypes()[0]).getElementType());
481  Operation *elementwiseOp =
482  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
483  vectorType, op->getAttrs());
484  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
485  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
486  transposeMaps.front());
487  return success();
488  }
489 };
490 
491 // Returns the values in `arrayAttr` as an integer vector.
492 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
493  return llvm::to_vector<4>(
494  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
495  [](IntegerAttr attr) { return attr.getInt(); }));
496 }
497 
498 // Shuffles vector.bitcast op after vector.extract op.
499 //
500 // This transforms IR like:
501 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
502 // %1 = vector.extract %0[3] : f16 from vector<8xf16>
503 // Into:
504 // %0 = vector.extract %src[1] : f32 from vector<4xf32>
505 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
506 // %2 = vector.extract %1[1] : f16 from vector<2xf16>
507 struct BubbleDownVectorBitCastForExtract
508  : public OpRewritePattern<vector::ExtractOp> {
509  using OpRewritePattern::OpRewritePattern;
510 
511  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
512  PatternRewriter &rewriter) const override {
513  // Only support extracting scalars for now.
514  if (extractOp.getSourceVectorType().getRank() != 1)
515  return failure();
516 
517  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
518  if (!castOp)
519  return failure();
520 
521  VectorType castSrcType = castOp.getSourceVectorType();
522  VectorType castDstType = castOp.getResultVectorType();
523  assert(castSrcType.getRank() == castDstType.getRank());
524 
525  // Fail to match if we only have one element in the cast op source.
526  // This is to avoid an infinite loop given that this pattern can generate
527  // such cases.
528  if (castSrcType.getNumElements() == 1)
529  return failure();
530 
531  // Only support casting to a larger number of elements for now.
532  // E.g., vector<4xf32> -> vector<8xf16>.
533  if (castSrcType.getNumElements() > castDstType.getNumElements())
534  return failure();
535 
536  unsigned expandRatio =
537  castDstType.getNumElements() / castSrcType.getNumElements();
538 
539  // Get the first element of the mixed position as integer.
540  auto mixedPos = extractOp.getMixedPosition();
541  if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
542  return failure();
543  uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
544 
545  // Get the single scalar (as a vector) in the source value that packs the
546  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
547  Location loc = extractOp.getLoc();
548  Value packedValue = rewriter.create<vector::ExtractOp>(
549  loc, castOp.getSource(), index / expandRatio);
550  Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
551  Value zero = rewriter.create<arith::ConstantOp>(
552  loc, packedVecType, rewriter.getZeroAttr(packedVecType));
553  packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
554  /*position=*/0);
555 
556  // Cast it to a vector with the desired scalar's type.
557  // E.g. f32 -> vector<2xf16>
558  VectorType packedType =
559  VectorType::get({expandRatio}, castDstType.getElementType());
560  Value castedValue =
561  rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
562 
563  // Finally extract the desired scalar.
564  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
565  index % expandRatio);
566  return success();
567  }
568 };
569 
570 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
571 //
572 // This transforms IR like:
573 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
574 // %0 = vector.extract_strided_slice %cast {
575 // offsets = [4], sizes = [4], strides = [1]
576 // } : vector<8xf16> to vector<4xf16>
577 // Into:
578 // %0 = vector.extract_strided_slice %src {
579 // offsets = [2], sizes = [2], strides = [1]
580 // } : vector<4xf32> to vector<2xf32>
581 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
582 struct BubbleDownBitCastForStridedSliceExtract
583  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
584  using OpRewritePattern::OpRewritePattern;
585 
586  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
587  PatternRewriter &rewriter) const override {
588  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
589  if (!castOp)
590  return failure();
591 
592  VectorType castSrcType = castOp.getSourceVectorType();
593  VectorType castDstType = castOp.getResultVectorType();
594  assert(castSrcType.getRank() == castDstType.getRank());
595 
596  int64_t castSrcLastDim = castSrcType.getShape().back();
597  int64_t castDstLastDim = castDstType.getShape().back();
598  // Require casting to more elements for now; other cases to be implemented.
599  if (castSrcLastDim > castDstLastDim)
600  return failure();
601 
602  // Only accept all one strides for now.
603  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
604  [](const APInt &val) { return !val.isOne(); }))
605  return failure();
606 
607  unsigned rank = extractOp.getSourceVectorType().getRank();
608  assert(castDstLastDim % castSrcLastDim == 0);
609  int64_t expandRatio = castDstLastDim / castSrcLastDim;
610 
611  // If we have fewer offsets than the rank, then implicitly we
612  // are selecting the full range for the last bitcasted dimension; other
613  // dimensions aren't affected. Otherwise, we need to scale down the last
614  // dimension's offset given we are extracting from fewer elements now.
615  ArrayAttr newOffsets = extractOp.getOffsets();
616  if (newOffsets.size() == rank) {
617  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
618  if (offsets.back() % expandRatio != 0)
619  return failure();
620  offsets.back() = offsets.back() / expandRatio;
621  newOffsets = rewriter.getI64ArrayAttr(offsets);
622  }
623 
624  // Similarly for sizes.
625  ArrayAttr newSizes = extractOp.getSizes();
626  if (newSizes.size() == rank) {
627  SmallVector<int64_t> sizes = getIntValueVector(newSizes);
628  if (sizes.back() % expandRatio != 0)
629  return failure();
630  sizes.back() = sizes.back() / expandRatio;
631  newSizes = rewriter.getI64ArrayAttr(sizes);
632  }
633 
634  SmallVector<int64_t> dims =
635  llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
636  dims.back() = dims.back() / expandRatio;
637  VectorType newExtractType =
638  VectorType::get(dims, castSrcType.getElementType());
639 
640  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
641  extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
642  newSizes, extractOp.getStrides());
643 
644  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
645  extractOp, extractOp.getType(), newExtractOp);
646 
647  return success();
648  }
649 };
650 
651 // Shuffles vector.bitcast op before vector.insert op.
652 //
653 // This transforms IR like:
654 // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
655 // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
656 // Into:
657 // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
658 // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
659 // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
660 //
661 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
662  using OpRewritePattern::OpRewritePattern;
663 
664  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
665  PatternRewriter &rewriter) const override {
666  VectorType castSrcType = bitcastOp.getSourceVectorType();
667  VectorType castDstType = bitcastOp.getResultVectorType();
668 
669  // 0-D and scalable vectors are not supported yet.
670  if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
671  castDstType.isScalable())
672  return failure();
673 
674  int64_t castSrcLastDim = castSrcType.getShape().back();
675  int64_t castDstLastDim = castDstType.getShape().back();
676  bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
677  int64_t ratio;
678  if (isNumElemsShrink) {
679  assert(castSrcLastDim % castDstLastDim == 0);
680  ratio = castSrcLastDim / castDstLastDim;
681  } else {
682  assert(castDstLastDim % castSrcLastDim == 0);
683  ratio = castDstLastDim / castSrcLastDim;
684  }
685 
686  auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
687  if (!insertOp)
688  return failure();
689 
690  // Only vector sources are supported for now.
691  auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
692  if (!insertSrcType)
693  return failure();
694 
695  // Bitcast the source.
696  SmallVector<int64_t> srcDims(insertSrcType.getShape());
697  srcDims.back() =
698  isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
699  VectorType newCastSrcType =
700  VectorType::get(srcDims, castDstType.getElementType());
701  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
702  bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
703 
704  SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
705  dstDims.back() =
706  isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
707  VectorType newCastDstType =
708  VectorType::get(dstDims, castDstType.getElementType());
709 
710  // Bitcast the destination.
711  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
712  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
713 
714  // Generate new insert.
715  rewriter.replaceOpWithNewOp<vector::InsertOp>(
716  bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
717  return success();
718  }
719 };
720 
721 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
722 //
723 // This transforms IR like:
724 // %0 = vector.insert_strided_slice %src, %dst {
725 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
726 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
727 // Into:
728 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
729 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
730 // %2 = vector.insert_strided_slice %0, %1 {
731 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
732 struct BubbleUpBitCastForStridedSliceInsert
733  : public OpRewritePattern<vector::BitCastOp> {
734  using OpRewritePattern::OpRewritePattern;
735 
736  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
737  PatternRewriter &rewriter) const override {
738  VectorType castSrcType = bitcastOp.getSourceVectorType();
739  VectorType castDstType = bitcastOp.getResultVectorType();
740  assert(castSrcType.getRank() == castDstType.getRank());
741  // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
742  if (castSrcType.getRank() == 0)
743  return failure();
744 
745  int64_t castSrcLastDim = castSrcType.getShape().back();
746  int64_t castDstLastDim = castDstType.getShape().back();
747  // Require casting to fewer elements for now; other cases to be implemented.
748  if (castSrcLastDim < castDstLastDim)
749  return failure();
750 
751  assert(castSrcLastDim % castDstLastDim == 0);
752  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
753 
754  auto insertOp =
755  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
756  if (!insertOp)
757  return failure();
758 
759  // Only accept all one strides for now.
760  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
761  [](const APInt &val) { return !val.isOne(); }))
762  return failure();
763 
764  unsigned rank = insertOp.getSourceVectorType().getRank();
765  // Require insert op to have the same rank for the source and destination
766  // vector; other cases to be implemented.
767  if (rank != insertOp.getDestVectorType().getRank())
768  return failure();
769 
770  // Requires that shape of insert op src is castable to dstType.
771  unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
772  unsigned destinationWidth =
773  castDstType.getElementType().getIntOrFloatBitWidth();
774  unsigned numElements = destinationWidth / sourceWidth;
775  if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
776  return failure();
777 
778  ArrayAttr newOffsets = insertOp.getOffsets();
779  assert(newOffsets.size() == rank);
780  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
781  if (offsets.back() % shrinkRatio != 0)
782  return failure();
783  offsets.back() = offsets.back() / shrinkRatio;
784  newOffsets = rewriter.getI64ArrayAttr(offsets);
785 
786  SmallVector<int64_t> srcDims =
787  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
788  srcDims.back() = srcDims.back() / shrinkRatio;
789  VectorType newCastSrcType =
790  VectorType::get(srcDims, castDstType.getElementType());
791 
792  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
793  bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
794 
795  SmallVector<int64_t> dstDims =
796  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
797  dstDims.back() = dstDims.back() / shrinkRatio;
798  VectorType newCastDstType =
799  VectorType::get(dstDims, castDstType.getElementType());
800 
801  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
802  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
803 
804  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
805  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
806  insertOp.getStrides());
807 
808  return success();
809  }
810 };
811 
812 // Breaks down vector.bitcast op
813 //
814 // This transforms IR like:
815 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
816 // Into:
817 // %cst = vector.splat %c0_f32 : vector<4xf32>
818 // %1 = vector.extract_strided_slice %0 {
819 // offsets = [0], sizes = [4], strides = [1]
820 // } : vector<8xf16> to vector<4xf16>
821 // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
822 // %4 = vector.insert_strided_slice %2, %cst {
823 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
824 // %5 = vector.extract_strided_slice %0 {
825 // offsets = [4], sizes = [4], strides = [1]
826 // } : vector<8xf16> to vector<4xf16>
827 // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
828 // %7 = vector.insert_strided_slice %6, %cst {
829 // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
830 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
831  using OpRewritePattern::OpRewritePattern;
832 
833 public:
834  BreakDownVectorBitCast(MLIRContext *context,
835  std::function<bool(vector::BitCastOp)> controlFn,
836  PatternBenefit benefit)
837  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
838 
839  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
840  PatternRewriter &rewriter) const override {
841 
842  if (controlFn && !controlFn(bitcastOp))
843  return failure();
844 
845  VectorType castSrcType = bitcastOp.getSourceVectorType();
846  VectorType castDstType = bitcastOp.getResultVectorType();
847  assert(castSrcType.getRank() == castDstType.getRank());
848 
849  // This transformation builds on top of
850  // vector.{extract|insert}_strided_slice, which do not support
851  // extracting/inserting "scalable sub-vectors". Bail out.
852  if (castSrcType.isScalable())
853  return rewriter.notifyMatchFailure(bitcastOp,
854  "Scalable vectors are not supported");
855 
856  // Only support rank 1 case for now.
857  if (castSrcType.getRank() != 1)
858  return failure();
859 
860  int64_t castSrcLastDim = castSrcType.getShape().back();
861  int64_t castDstLastDim = castDstType.getShape().back();
862  // Require casting to fewer elements for now; other cases to be implemented.
863  if (castSrcLastDim < castDstLastDim)
864  return failure();
865 
866  assert(castSrcLastDim % castDstLastDim == 0);
867  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
868  // Nothing to do if it is already bitcasting to a single element.
869  if (castSrcLastDim == shrinkRatio)
870  return failure();
871 
872  Location loc = bitcastOp.getLoc();
873  Type elemType = castDstType.getElementType();
874  assert(elemType.isSignlessIntOrIndexOrFloat());
875 
876  Value zero = rewriter.create<arith::ConstantOp>(
877  loc, elemType, rewriter.getZeroAttr(elemType));
878  Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
879 
880  SmallVector<int64_t> sliceShape = {castDstLastDim};
881  SmallVector<int64_t> strides = {1};
882  VectorType newCastDstType =
883  VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
884  castDstType.getElementType());
885 
886  for (int i = 0, e = shrinkRatio; i < e; ++i) {
887  Value extracted = rewriter.create<ExtractStridedSliceOp>(
888  loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
889  sliceShape, strides);
890  Value bitcast =
891  rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
892  res = rewriter.create<InsertStridedSliceOp>(
893  loc, bitcast, res,
894  ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
895  }
896  rewriter.replaceOp(bitcastOp, res);
897  return success();
898  }
899 
900 private:
901  std::function<bool(BitCastOp)> controlFn;
902 };
903 
904 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
905 ///
906 /// Example:
907 /// ```
908 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
909 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
910 /// %r = arith.addi %a, %b : vector<1x4xindex>
911 /// ```
912 /// Gets converted to:
913 /// ```
914 /// %r = arith.addi %arg1, %arg2 : index
915 /// %b = vector.broadcast %r : index to vector<1x4xindex>
916 /// ```
917 ///
918 /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
919 /// ops.
920 struct ReorderElementwiseOpsOnBroadcast final
921  : public OpTraitRewritePattern<OpTrait::Elementwise> {
922  using OpTraitRewritePattern::OpTraitRewritePattern;
923  LogicalResult matchAndRewrite(Operation *op,
924  PatternRewriter &rewriter) const override {
925  if (op->getNumResults() != 1)
926  return failure();
927  if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
928  return failure();
929  if (!OpTrait::hasElementwiseMappableTraits(op))
930  return rewriter.notifyMatchFailure(
931  op, "Op doesn't have ElementwiseMappableTraits");
932  if (op->getNumOperands() == 0)
933  return failure();
934  if (op->getResults()[0].getType() != op->getOperand(0).getType())
935  return rewriter.notifyMatchFailure(op,
936  "result and operand type mismatch");
937  if (isa<vector::FMAOp>(op)) {
938  return rewriter.notifyMatchFailure(
939  op,
940  "Op only accepts vector types - not supported as broadcast source "
941  "might be a scalar");
942  }
943 
944  // Get the type of the lhs operand
945  auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
946  if (!lhsBcastOrSplat ||
947  !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
948  return failure();
949  auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
950 
951  // Make sure that all operands are broadcast from identical types:
952  // * scalar (`vector.broadcast` + `vector.splat`), or
953  // * vector (`vector.broadcast`).
954  // Otherwise the re-ordering wouldn't be safe.
955  if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
956  auto bcast = val.getDefiningOp<vector::BroadcastOp>();
957  if (bcast)
958  return (bcast.getOperand().getType() == lhsBcastOrSplatType);
959  auto splat = val.getDefiningOp<vector::SplatOp>();
960  if (splat)
961  return (splat.getOperand().getType() == lhsBcastOrSplatType);
962  return false;
963  })) {
964  return failure();
965  }
966 
967  // Collect the source values before broadcasting
968  SmallVector<Value> srcValues;
969  srcValues.reserve(op->getNumOperands());
970  for (Value operand : op->getOperands()) {
971  srcValues.push_back(operand.getDefiningOp()->getOperand(0));
972  }
973 
974  // Create the "elementwise" Op
975  Operation *elementwiseOp =
976  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
977  lhsBcastOrSplatType, op->getAttrs());
978 
979  // Replace the original Op with the elementwise Op
980  auto vectorType = op->getResultTypes()[0];
981  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
982  op, vectorType, elementwiseOp->getResults());
983 
984  return success();
985  }
986 };
987 
988 /// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
989 /// This may result in cleaner code when extracting a single value
990 /// from a multi-element vector, and it also helps canonicalize 1-element
991 /// vectors to scalars.
992 ///
993 /// Example:
994 /// ```
995 /// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
996 /// %1 = vector.extract %0[1] : f32 from vector<4xf32>
997 /// ```
998 /// Gets converted to:
999 /// ```
1000 /// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1001 /// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1002 /// %2 = arith.addf %0, %1 : f32
1003 /// ```
1004 class ExtractOpFromElementwise final
1005  : public OpRewritePattern<vector::ExtractOp> {
1006 public:
1007  using OpRewritePattern::OpRewritePattern;
1008 
1009  LogicalResult matchAndRewrite(vector::ExtractOp op,
1010  PatternRewriter &rewriter) const override {
1011  Operation *eltwise = op.getVector().getDefiningOp();
1012 
1013  // TODO: vector::FMAOp is not ElementwiseMappable even if it claims to be,
1014  // as it doesn't support scalars.
1015  if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1016  isa<vector::FMAOp>(eltwise))
1017  return rewriter.notifyMatchFailure(op, "not an elementwise op");
1018 
1019  if (eltwise->getNumResults() != 1)
1020  return rewriter.notifyMatchFailure(op, "expected single result");
1021 
1022  if (!eltwise->hasOneUse())
1023  return rewriter.notifyMatchFailure(op, "expected single op use");
1024 
1025  if (!llvm::all_equal(eltwise->getOperandTypes()))
1026  return rewriter.notifyMatchFailure(op, "operand types are different");
1027 
1028  Type dstType = op.getType();
1029 
1030  OpBuilder::InsertionGuard g(rewriter);
1031  rewriter.setInsertionPoint(eltwise);
1032 
1033  IRMapping mapping;
1034  Location loc = eltwise->getLoc();
1035  SmallVector<OpFoldResult> pos = op.getMixedPosition();
1036  for (Value arg : eltwise->getOperands()) {
1037  Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
1038  mapping.map(arg, newArg);
1039  }
1040 
1041  Operation *newEltwise = rewriter.clone(*eltwise, mapping);
1042  newEltwise->getResult(0).setType(dstType);
1043 
1044  rewriter.replaceOp(op, newEltwise);
1045  rewriter.eraseOp(eltwise);
1046  return success();
1047  }
1048 };
1049 
1050 /// Check if the element type is suitable for vector.load/store sinking.
1051 /// Element type must be index or byte-aligned integer or floating-point type.
1052 static bool isSupportedMemSinkElementType(Type type) {
1053  if (isa<IndexType>(type))
1054  return true;
1055 
1056  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
1057 }
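// For example, index, i8, i32, f16, and f32 are supported, while i1 and i4
// (not byte-aligned) are rejected.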
1058 
1059 /// Pattern to rewrite `vector.extract(vector.load)` -> vector/memref.load.
1060 /// Only index and byte-aligned integer and floating-point element types are
1061 /// supported for now.
1062 ///
1063 /// Example:
1064 /// ```
1065 /// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
1066 /// vector.extract %0[1] : f32 from vector<4xf32>
1067 /// ```
1068 /// Gets converted to:
1069 /// ```
1070 /// %c1 = arith.constant 1 : index
1071 /// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
1072 /// %1 = memref.load %arg0[%0] : memref<?xf32>
1073 /// ```
1074 class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
1075 public:
1076  using OpRewritePattern::OpRewritePattern;
1077 
1078  LogicalResult matchAndRewrite(vector::ExtractOp op,
1079  PatternRewriter &rewriter) const override {
1080  auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
1081  if (!loadOp)
1082  return rewriter.notifyMatchFailure(op, "expected a load op");
1083 
1084  // Checking for single use so we won't duplicate load ops.
1085  if (!loadOp->hasOneUse())
1086  return rewriter.notifyMatchFailure(op, "expected single op use");
1087 
1088  VectorType loadVecType = loadOp.getVectorType();
1089  if (loadVecType.isScalable())
1090  return rewriter.notifyMatchFailure(op,
1091  "scalable vectors are not supported");
1092 
1093  MemRefType memType = loadOp.getMemRefType();
1094 
1095  // Non-byte-aligned types are tricky and may require special handling,
1096  // ignore them for now.
1097  if (!isSupportedMemSinkElementType(memType.getElementType()))
1098  return rewriter.notifyMatchFailure(op, "unsupported element type");
1099 
1100  int64_t rankOffset = memType.getRank() - loadVecType.getRank();
1101  if (rankOffset < 0)
1102  return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
1103 
1104  auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
1105  int64_t finalRank = 0;
1106  if (extractVecType)
1107  finalRank = extractVecType.getRank();
1108 
1109  SmallVector<Value> indices = loadOp.getIndices();
1110  SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
1111 
1112  // There may be memory stores between the load and the extract op, so we
1113  // need to make sure that the new load op is inserted at the same place as
1114  // the original load op.
1115  OpBuilder::InsertionGuard g(rewriter);
1116  rewriter.setInsertionPoint(loadOp);
1117  Location loc = loadOp.getLoc();
1118  ArithIndexingBuilder idxBuilderf(rewriter, loc);
1119  for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
1120  OpFoldResult pos = extractPos[i - rankOffset];
1121  if (isConstantIntValue(pos, 0))
1122  continue;
1123 
1124  Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
1125  indices[i] = idxBuilderf.add(indices[i], offset);
1126  }
1127 
1128  Value base = loadOp.getBase();
1129  if (extractVecType) {
1130  rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
1131  indices);
1132  } else {
1133  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
1134  }
1135  // We checked for single use so we can safely erase the load op.
1136  rewriter.eraseOp(loadOp);
1137  return success();
1138  }
1139 };
1140 
1141 /// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
1142 ///
1143 /// Example:
1144 /// ```
1145 /// %0 = vector.splat %arg2 : vector<1xf32>
1146 /// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
1147 /// ```
1148 /// Gets converted to:
1149 /// ```
1150 /// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
1151 /// ```
1152 class StoreOpFromSplatOrBroadcast final
1153  : public OpRewritePattern<vector::StoreOp> {
1154 public:
1155  using OpRewritePattern::OpRewritePattern;
1156 
1157  LogicalResult matchAndRewrite(vector::StoreOp op,
1158  PatternRewriter &rewriter) const override {
1159  VectorType vecType = op.getVectorType();
1160  if (vecType.isScalable())
1161  return rewriter.notifyMatchFailure(op,
1162  "scalable vectors are not supported");
1163 
1164  if (isa<VectorType>(op.getMemRefType().getElementType()))
1165  return rewriter.notifyMatchFailure(
1166  op, "memrefs of vectors are not supported");
1167 
1168  if (vecType.getNumElements() != 1)
1169  return rewriter.notifyMatchFailure(
1170  op, "only 1-element vectors are supported");
1171 
1172  Operation *splat = op.getValueToStore().getDefiningOp();
1173  if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1174  return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
1175 
1176  // Checking for single use so we can remove splat.
1177  if (!splat->hasOneUse())
1178  return rewriter.notifyMatchFailure(op, "expected single op use");
1179 
1180  Value source = splat->getOperand(0);
1181  Value base = op.getBase();
1182  ValueRange indices = op.getIndices();
1183 
1184  if (isa<VectorType>(source.getType())) {
1185  rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
1186  } else {
1187  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
1188  }
1189  rewriter.eraseOp(splat);
1190  return success();
1191  }
1192 };
1193 
1194 // Helper that returns a vector comparison that constructs a mask:
1195 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1196 //
1197 // If `dim == 0` then the result will be a 0-D vector.
1198 //
1199 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1200 // much more compact, IR for this operation, but LLVM eventually
1201 // generates more elaborate instructions for this intrinsic since it
1202 // is very conservative on the boundary conditions.
1203 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1204  bool force32BitVectorIndices, int64_t dim,
1205  Value b, Value *off = nullptr) {
1206  auto loc = op->getLoc();
1207  // If we can assume all indices fit in 32-bit, we perform the vector
1208  // comparison in 32-bit to get a higher degree of SIMD parallelism.
1209  // Otherwise we perform the vector comparison using 64-bit indices.
1210  Type idxType =
1211  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1212  DenseIntElementsAttr indicesAttr;
1213  if (dim == 0 && force32BitVectorIndices) {
1214  indicesAttr = DenseIntElementsAttr::get(
1215  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1216  } else if (dim == 0) {
1217  indicesAttr = DenseIntElementsAttr::get(
1218  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1219  } else if (force32BitVectorIndices) {
1220  indicesAttr = rewriter.getI32VectorAttr(
1221  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1222  } else {
1223  indicesAttr = rewriter.getI64VectorAttr(
1224  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1225  }
1226  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
1227  // Add in an offset if requested.
1228  if (off) {
1229  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1230  Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1231  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
1232  }
1233  // Construct the vector comparison.
1234  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1235  Value bounds =
1236  rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1237  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1238  bounds);
1239 }
1240 
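/// Materializes an explicit mask for a 1-D vector transfer op that has a
/// potentially out-of-bounds dimension, and then marks the access in-bounds.
/// Illustrative sketch (the IR below is an assumption for exposition, not
/// taken from upstream docs):
/// ```
///   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [false]}
///       : memref<?xf32>, vector<8xf32>
/// ```
/// conceptually becomes
/// ```
///   %d    = memref.dim %mem, %c0 : memref<?xf32>
///   %b    = arith.subi %d, %i : index
///   %mask = vector.create_mask %b : vector<8xi1>
///   %v    = vector.transfer_read %mem[%i], %pad, %mask {in_bounds = [true]}
///       : memref<?xf32>, vector<8xf32>
/// ```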
1241 template <typename ConcreteOp>
1242 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1243 public:
1244  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1245  PatternBenefit benefit = 1)
1246  : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1247  force32BitVectorIndices(enableIndexOpt) {}
1248 
1249  LogicalResult matchAndRewrite(ConcreteOp xferOp,
1250  PatternRewriter &rewriter) const override {
1251  if (!xferOp.hasOutOfBoundsDim())
1252  return failure();
1253 
1254  if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1255  return failure();
1256 
1257  Location loc = xferOp->getLoc();
1258  VectorType vtp = xferOp.getVectorType();
1259 
1260  // Create the in-bounds mask with all elements between [0 .. dim - offset)
1261  // set and [dim - offset .. vector_length) unset.
1262  //
1263  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1264  // dimensions here.
1265  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1266  Value off = xferOp.getIndices()[lastIndex];
1267  Value dim =
1268  vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1269  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1270  Value mask = rewriter.create<vector::CreateMaskOp>(
1271  loc,
1272  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1273  vtp.getScalableDims()),
1274  b);
1275  if (xferOp.getMask()) {
1276  // Intersect the in-bounds with the mask specified as an op parameter.
1277  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
1278  }
1279 
1280  rewriter.modifyOpInPlace(xferOp, [&]() {
1281  xferOp.getMaskMutable().assign(mask);
1282  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1283  });
1284 
1285  return success();
1286  }
1287 
1288 private:
1289  const bool force32BitVectorIndices;
1290 };
1291 
1292 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
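/// Illustrative sketch (the IR below is an assumption for exposition): a 1-D
/// `%m = vector.create_mask %n : vector<4xi1>` lowers to a comparison of a
/// constant index vector against the splatted bound, roughly:
/// ```
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %bnd = vector.splat %n_i32 : vector<4xi32>
///   %m   = arith.cmpi slt, %idx, %bnd : vector<4xi32>
/// ```
/// 32-bit indices are used when `force32BitVectorIndices` is set; otherwise
/// 64-bit indices are used.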
1293 class VectorCreateMaskOpConversion
1294  : public OpRewritePattern<vector::CreateMaskOp> {
1295 public:
1296  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1297  bool enableIndexOpt,
1298  PatternBenefit benefit = 1)
1299  : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1300  force32BitVectorIndices(enableIndexOpt) {}
1301 
1302  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1303  PatternRewriter &rewriter) const override {
1304  auto dstType = op.getType();
1305  if (cast<VectorType>(dstType).isScalable())
1306  return failure();
1307  int64_t rank = dstType.getRank();
1308  if (rank > 1)
1309  return failure();
1310  rewriter.replaceOp(
1311  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1312  rank == 0 ? 0 : dstType.getDimSize(0),
1313  op.getOperand(0)));
1314  return success();
1315  }
1316 
1317 private:
1318  const bool force32BitVectorIndices;
1319 };
1320 
1321 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1322 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1323  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1324  // TODO: Support non-dense constant.
1325  if (!denseAttr)
1326  return false;
1327 
1328  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1329  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1330 }
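// For example, this returns true for `arith.constant dense<true> :
// vector<1xi1>` with `value == true`, and false for any non-splat or
// non-dense constant.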
1331 
1332 /// Folds a select operation between an all-true and all-false vector. For now,
1333 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1334 ///
1335 /// %true = arith.constant dense<true> : vector<1xi1>
1336 /// %false = arith.constant dense<false> : vector<1xi1>
1337 /// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1338 /// =>
1339 /// %result = vector.broadcast %cond : i1 to vector<1xi1>
1340 ///
1341 /// InstCombine seems to handle vectors with multiple elements but not the
1342 /// single element ones.
1343 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1344  using OpRewritePattern::OpRewritePattern;
1345 
1346  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1347  PatternRewriter &rewriter) const override {
1348  auto vecType = dyn_cast<VectorType>(selectOp.getType());
1349  if (!vecType || !vecType.getElementType().isInteger(1))
1350  return failure();
1351 
1352  // Only scalar conditions can be folded.
1353  Value cond = selectOp.getCondition();
1354  if (isa<VectorType>(cond.getType()))
1355  return failure();
1356 
1357  // TODO: Support n-D and scalable vectors.
1358  if (vecType.getRank() != 1 || vecType.isScalable())
1359  return failure();
1360 
1361  // TODO: Support vectors with multiple elements.
1362  if (vecType.getShape()[0] != 1)
1363  return failure();
1364 
1365  auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1366  if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1367  return failure();
1368 
1369  auto falseConst =
1370  selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1371  if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1372  return failure();
1373 
1374  // Replace select with its condition broadcasted to single element vector.
1375  auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1376  auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1377  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1378  return success();
1379  }
1380 };
1381 
1382 /// Returns the number of dims that can be folded away from transfer ops. It
1383 /// returns a failure if it cannot determine the number of dims to be folded.
1384 ///
1385 /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1386 /// `vectorType` is vector<16x16x1x1xf32>
1387 /// (the two innermost dims can be dropped by memref.subview ops)
1388 ///
1389 /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1390 /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1391 /// (only the innermost unit dim of `srcType` can be dropped)
1392 ///
1393 /// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
1394 /// `vectorType` is vector<16x16x1x[1]xf32>
1395 /// (the innermost dim in `vectorType` is not a unit dim; it is a "scalable
1396 /// unit" dim)
1397 static FailureOr<size_t>
1398 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1399  SmallVector<int64_t> srcStrides;
1400  int64_t srcOffset;
1401  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1402  return failure();
1403 
1404  auto isUnitDim = [](VectorType type, int dim) {
1405  return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1406  };
1407 
1408  // According to vector.transfer_read/write semantics, the vector can be a
1409  // slice. Thus, we have to offset the check index with `rankDiff` in
1410  // `srcStrides` and source dim sizes.
1411  size_t result = 0;
1412  int rankDiff = srcType.getRank() - vectorType.getRank();
1413  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1414  // Check that the inner dim size is 1 for both memref type and vector slice.
1415  // It can be folded only if they are 1 and the stride is 1.
1416  int dim = vectorType.getRank() - i - 1;
1417  if (srcStrides[dim + rankDiff] != 1 ||
1418  srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1419  break;
1420  result++;
1421  }
1422  return result;
1423 }
1424 
1425 /// Drop innermost contiguous unit dimensions from transfer_read operand.
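/// E.g. (an illustrative sketch mirroring the transfer_write example below),
///
///   %v = vector.transfer_read %arg0[%c0, %arg2, %c0, %c0, %c0], %pad
///     {in_bounds = [true, true, true, true, true]}
///     : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
///
/// is rewritten into a rank-reducing memref.subview that drops the two inner
/// unit dims, a transfer_read of vector<1x16x16xf32> from that subview, and a
/// vector.shape_cast back to vector<1x16x16x1x1xf32>.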
1426 class DropInnerMostUnitDimsTransferRead
1427  : public OpRewritePattern<vector::TransferReadOp> {
1428  using OpRewritePattern::OpRewritePattern;
1429 
1430  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1431  PatternRewriter &rewriter) const override {
1432  // TODO: support 0-d corner case.
1433  if (readOp.getTransferRank() == 0)
1434  return failure();
1435 
1436  // TODO: support mask.
1437  if (readOp.getMask())
1438  return failure();
1439 
1440  auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1441  if (!srcType)
1442  return failure();
1443 
1444  if (!readOp.getPermutationMap().isMinorIdentity())
1445  return failure();
1446 
1447  auto targetType = readOp.getVectorType();
1448  if (targetType.getRank() <= 1)
1449  return failure();
1450 
1451  FailureOr<size_t> maybeDimsToDrop =
1452  getTransferFoldableInnerUnitDims(srcType, targetType);
1453  if (failed(maybeDimsToDrop))
1454  return failure();
1455 
1456  size_t dimsToDrop = maybeDimsToDrop.value();
1457  if (dimsToDrop == 0)
1458  return failure();
1459 
1460  auto inBounds = readOp.getInBoundsValues();
1461  auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1462  if (llvm::is_contained(droppedInBounds, false))
1463  return failure();
1464 
1465  auto resultTargetVecType =
1466  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1467  targetType.getElementType(),
1468  targetType.getScalableDims().drop_back(dimsToDrop));
1469 
1470  auto loc = readOp.getLoc();
 1471  SmallVector<OpFoldResult> sizes =
 1472  memref::getMixedSizes(rewriter, loc, readOp.getSource());
1473  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1474  rewriter.getIndexAttr(0));
1475  SmallVector<OpFoldResult> strides(srcType.getRank(),
1476  rewriter.getIndexAttr(1));
1477  MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1478  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1479  strides);
1480  ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1481  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1482  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1483  loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1484  auto permMap = getTransferMinorIdentityMap(
1485  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1486  Value result = rewriter.create<vector::TransferReadOp>(
1487  loc, resultTargetVecType, rankedReducedView,
1488  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1489  readOp.getPadding(),
1490  // TODO: support mask.
1491  /*mask=*/Value(), inBoundsAttr);
1492  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1493  result);
1494  return success();
1495  }
1496 };
1497 
 1498 /// Drop the innermost contiguous unit dimensions from a transfer_write operand.
1499 /// E.g.,
1500 /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1501 /// {in_bounds = [true, true, true, true, true]}
1502 /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1503 ///
1504 /// will be replaced with
1505 ///
1506 /// %subview = memref.subview %arg0
1507 /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1508 /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1509 /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1510 /// to vector<1x16x16xf32>
1511 /// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1512 /// {in_bounds = [true, true, true]}
1513 /// : vector<1x16x16xf32>, memref<1x512x16xf32>
1514 ///
1515 /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1516 class DropInnerMostUnitDimsTransferWrite
1517  : public OpRewritePattern<vector::TransferWriteOp> {
 1518  using OpRewritePattern::OpRewritePattern;
 1519 
1520  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1521  PatternRewriter &rewriter) const override {
1522  // TODO: support 0-d corner case.
1523  if (writeOp.getTransferRank() == 0)
1524  return failure();
1525 
1526  // TODO: support mask.
1527  if (writeOp.getMask())
1528  return failure();
1529 
1530  auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1531  if (!srcType)
1532  return failure();
1533 
1534  if (!writeOp.getPermutationMap().isMinorIdentity())
1535  return failure();
1536 
1537  auto targetType = writeOp.getVectorType();
1538  if (targetType.getRank() <= 1)
1539  return failure();
1540 
1541  FailureOr<size_t> maybeDimsToDrop =
1542  getTransferFoldableInnerUnitDims(srcType, targetType);
1543  if (failed(maybeDimsToDrop))
1544  return failure();
1545 
1546  size_t dimsToDrop = maybeDimsToDrop.value();
1547  if (dimsToDrop == 0)
1548  return failure();
1549 
1550  auto inBounds = writeOp.getInBoundsValues();
1551  auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1552  if (llvm::is_contained(droppedInBounds, false))
1553  return failure();
1554 
1555  auto resultTargetVecType =
1556  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1557  targetType.getElementType(),
1558  targetType.getScalableDims().drop_back(dimsToDrop));
1559 
1560  Location loc = writeOp.getLoc();
 1561  SmallVector<OpFoldResult> sizes =
 1562  memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1563  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1564  rewriter.getIndexAttr(0));
1565  SmallVector<OpFoldResult> strides(srcType.getRank(),
1566  rewriter.getIndexAttr(1));
1567  MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1568  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1569  strides);
1570  ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1571  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1572 
1573  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1574  loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1575  auto permMap = getTransferMinorIdentityMap(
1576  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1577 
1578  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1579  loc, resultTargetVecType, writeOp.getVector());
1580  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1581  writeOp, shapeCast, rankedReducedView,
1582  writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1583  // TODO: support mask.
1584  /*mask=*/Value(), inBoundsAttr);
1585  return success();
1586  }
1587 };
1588 
1589 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1590 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1591 /// with the RHS transposed) lowering.
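/// For instance (an illustrative sketch using named dims), a contraction with
///   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                    affine_map<(m, n, k) -> (k, n)>,
///                    affine_map<(m, n, k) -> (m, n)>]
/// gets its RHS transposed with vector.transpose [1, 0] and is rewritten to
/// the canonical MMT form
///   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                    affine_map<(m, n, k) -> (n, k)>,
///                    affine_map<(m, n, k) -> (m, n)>]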
1592 struct CanonicalizeContractMatmulToMMT final
1593  : OpRewritePattern<vector::ContractionOp> {
 1594  using OpRewritePattern::OpRewritePattern;
 1595 
1596  using FilterConstraintType =
1597  std::function<LogicalResult(vector::ContractionOp op)>;
1598 
1599  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1600  FilterConstraintType constraint)
1601  : OpRewritePattern<vector::ContractionOp>(context, benefit),
1602  filter(std::move(constraint)) {}
1603 
1604  LogicalResult matchAndRewrite(vector::ContractionOp op,
1605  PatternRewriter &rewriter) const override {
1606  if (failed(filter(op)))
1607  return failure();
1608 
1609  Location loc = op.getLoc();
1610  Value lhs = op.getLhs();
1611  Value rhs = op.getRhs();
1612  Value res = op.getAcc();
1613 
 1614  // Set up the parallel/reduction structure in the right form.
1615  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1616  auto infer = [&](MapList m) {
1617  return AffineMap::inferFromExprList(m, op.getContext());
1618  };
1619  AffineExpr m;
1620  AffineExpr n;
1621  AffineExpr k;
1622  bindDims(rewriter.getContext(), m, n, k);
1623  static constexpr std::array<int64_t, 2> perm = {1, 0};
1624  auto iteratorTypes = op.getIteratorTypes().getValue();
1625  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1626  if (iteratorTypes.size() != 3 ||
1627  !vector::isParallelIterator(iteratorTypes[0]) ||
1628  !vector::isParallelIterator(iteratorTypes[1]) ||
1629  !vector::isReductionIterator(iteratorTypes[2]))
1630  return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1631 
1632  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1633  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1634  if (maps == canonicalForm)
1635  return rewriter.notifyMatchFailure(op, "already in the canonical form");
1636 
1637  // Create a vector transpose making sure to emit zero/sign-extend at the
1638  // end.
1639  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1640  if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1641  Value trans =
1642  rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1643  VectorType newType =
1644  cast<VectorType>(trans.getType())
1645  .clone(cast<VectorType>(mat.getType()).getElementType());
1646  return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1647  }
1648  if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1649  Value trans =
1650  rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1651  VectorType newType =
1652  VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1653  cast<VectorType>(mat.getType()).getElementType());
1654  return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1655  }
1656  return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1657  };
1658 
1659  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1660  rhs = createTranspose(rhs);
1661  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1662  lhs = createTranspose(lhs);
1663  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1664  rhs = createTranspose(rhs);
1665  lhs = createTranspose(lhs);
1666  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1667  std::swap(rhs, lhs);
1668  rhs = createTranspose(rhs);
1669  lhs = createTranspose(lhs);
1670  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1671  std::swap(rhs, lhs);
1672  rhs = createTranspose(rhs);
1673  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1674  std::swap(lhs, rhs);
1675  lhs = createTranspose(lhs);
1676  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1677  std::swap(lhs, rhs);
1678  } else {
1679  return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1680  }
1681  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1682  op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1683  op.getIteratorTypes());
1684  return success();
1685  };
1686 
1687 private:
1688  FilterConstraintType filter;
1689 };
1690 
1691 /// Pattern to fold arithmetic extensions on floating point data types into
1692 /// vector contraction operations. linalg.matmul introduces arithmetic
 1693 /// extensions on its operands. Please see the MLIR snippet below for details.
1694 /// ```mlir
1695 /// "linalg.matmul"(%lhs, %rhs, %acc) ({
1696 /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1697 /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1698 /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1699 /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1700 /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1701 /// "linalg.yield"(%acc) : (f32) -> ()
1702 /// })
1703 /// ```
1704 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1705 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1706 /// This pattern folds the arithmetic extensions into the vector contraction and
1707 /// enables the usage of native mixed precision Tensor Core instructions.
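///
/// At the vector level, this roughly amounts to rewriting (shapes below are
/// illustrative only):
/// ```mlir
/// %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
/// %rhs_f32 = arith.extf %rhs : vector<8x16xf16> to vector<8x16xf32>
/// %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///   : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
/// ```
/// into the same vector.contract taking the narrower %lhs and %rhs directly,
/// with the f32 accumulator unchanged.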
1708 template <typename ExtOp>
1709 struct FoldArithExtIntoContractionOp
1710  : public OpRewritePattern<vector::ContractionOp> {
 1711  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
 1712 
1713  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1714  PatternRewriter &rewriter) const override {
1715 
1716  auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1717  auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1718 
1719  if (!lhsDefOp || !rhsDefOp) {
1720  return rewriter.notifyMatchFailure(contractOp,
1721  "no defining op on contract operands");
1722  }
1723 
1724  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1725  contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1726  contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1727  contractOp.getIteratorTypesAttr());
1728 
1729  return success();
1730  }
1731 };
1732 
 1733 /// Pattern to fold a chained reduction into a series of vector additions and
 1734 /// a final reduction. This form should require fewer subgroup operations.
1735 ///
1736 /// ```mlir
1737 /// %a = vector.reduction <add> %x, %acc
1738 /// %b = vector.reduction <add> %y, %a
1739 /// ==>
1740 /// %a = arith.addf %x, %y
1741 /// %b = vector.reduction <add> %a, %acc
1742 /// ```
1743 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 1744  using OpRewritePattern::OpRewritePattern;
 1745 
1746  LogicalResult matchAndRewrite(vector::ReductionOp op,
1747  PatternRewriter &rewriter) const override {
1748  // TODO: Handle other combining kinds.
1749  if (op.getKind() != vector::CombiningKind::ADD)
1750  return failure();
1751 
1752  // Accumulator is optional.
1753  Value acc = op.getAcc();
1754  if (!acc)
1755  return failure();
1756 
1757  if (!acc.getType().isIntOrFloat())
1758  return failure();
1759 
1760  auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1761  if (!parentReduction)
1762  return failure();
1763 
1764  Location loc = op.getLoc();
1765  Value vAdd;
1766  if (isa<IntegerType>(acc.getType())) {
1767  vAdd = rewriter.createOrFold<arith::AddIOp>(
1768  loc, parentReduction.getVector(), op.getVector());
1769  } else {
1770  vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1771  op.getVector());
1772  }
1773  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1774  parentReduction.getAcc());
1775  return success();
1776  }
1777 };
1778 
 1779 // Helper function dropping unit, non-scalable dimensions from a VectorType,
 1780 // keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
1781 // dimensions are not dropped. Folding such dimensions would require "shifting"
1782 // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1783 // vector<[4]xf32>). This could be implemented in the future.
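// For example, vector<1x4x[1]x1xf32> becomes vector<4x[1]xf32>, while an
// all-unit-dim type such as vector<1x1xf32> becomes vector<1xf32>.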
1784 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1785  auto inVecShape = inVecTy.getShape();
1786  SmallVector<int64_t> newShape;
1787  SmallVector<bool> newScalableDims;
1788  for (auto [dim, isScalable] :
1789  llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1790  if (dim == 1 && !isScalable)
1791  continue;
1792 
1793  newShape.push_back(dim);
1794  newScalableDims.push_back(isScalable);
1795  }
1796  // All dims have been dropped, return vector<1xeType>.
1797  if (newShape.empty()) {
1798  newShape.push_back(1);
1799  newScalableDims.push_back(false);
1800  }
1801 
1802  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1803 }
1804 
1805 /// For vectors with at least one unit dim, replaces:
1806 /// elementwise(a, b)
1807 /// with:
1808 /// sc_a = shape_cast(a)
1809 /// sc_b = shape_cast(b)
1810 /// res = elementwise(sc_a, sc_b)
1811 /// return shape_cast(res)
1812 /// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1813 /// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
 1814 /// required to have rank > 1.
1815 ///
1816 /// Ex:
1817 /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1818 /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1819 ///
1820 /// gets converted to:
1821 ///
1822 /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1823 /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1824 /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1825 /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1826 /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1827 ///
1828 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1829 /// `%cast`.
1830 struct DropUnitDimFromElementwiseOps final
1831  : public OpTraitRewritePattern<OpTrait::Elementwise> {
 1832  using OpTraitRewritePattern::OpTraitRewritePattern;
 1833  LogicalResult matchAndRewrite(Operation *op,
1834  PatternRewriter &rewriter) const override {
1835  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1836  return failure();
1837 
1838  auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1839  if (!resultVectorType)
1840  return failure();
1841 
1842  // Check the operand pre-conditions. For `Elementwise` ops all operands are
1843  // guaranteed to have identical shapes (with some exceptions such as
1844  // `arith.select`) and it suffices to only check one of them.
1845  auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1846  if (!sourceVectorType)
1847  return failure();
1848  if (sourceVectorType.getRank() < 2)
1849  return failure();
1850 
1851  SmallVector<Value> newOperands;
1852  auto loc = op->getLoc();
1853  for (auto operand : op->getOperands()) {
1854  auto opVectorType = cast<VectorType>(operand.getType());
1855  auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1856  if (newVType == opVectorType)
1857  return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1858 
1859  auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1860  newOperands.push_back(opSC);
1861  }
1862 
1863  VectorType newResultVectorType =
1864  dropNonScalableUnitDimFromType(resultVectorType);
1865  // Create an updated elementwise Op without unit dim.
1866  Operation *elementwiseOp =
1867  rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1868  newResultVectorType, op->getAttrs());
1869 
1870  // Restore the unit dim by applying vector.shape_cast to the result.
1871  rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1872  elementwiseOp->getResult(0));
1873 
1874  return success();
1875  }
1876 };
1877 
1878 /// A pattern to drop unit dims from vector.transpose.
1879 ///
1880 /// Example:
1881 ///
1882 /// BEFORE:
1883 /// ```mlir
1884 /// %transpose = vector.transpose %vector, [3, 0, 1, 2]
1885 /// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1886 /// ```
1887 ///
1888 /// AFTER:
1889 /// ```mlir
1890 /// %dropDims = vector.shape_cast %vector
1891 /// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1892 /// %transpose = vector.transpose %0, [1, 0]
1893 /// : vector<4x[4]xf32> to vector<[4]x4xf32>
1894 /// %restoreDims = vector.shape_cast %transpose
1895 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1896 /// ```
1897 struct DropUnitDimsFromTransposeOp final
1898  : OpRewritePattern<vector::TransposeOp> {
 1899  using OpRewritePattern::OpRewritePattern;
 1900 
1901  LogicalResult matchAndRewrite(vector::TransposeOp op,
1902  PatternRewriter &rewriter) const override {
1903  VectorType sourceType = op.getSourceVectorType();
1904  VectorType sourceTypeWithoutUnitDims =
1905  dropNonScalableUnitDimFromType(sourceType);
1906 
1907  if (sourceType == sourceTypeWithoutUnitDims)
1908  return failure();
1909 
1910  // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1911  auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
1912  SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
1913  int64_t droppedDims = 0;
1914  for (auto [i, dim] : llvm::enumerate(sourceDims)) {
1915  droppedDimsBefore[i] = droppedDims;
1916  if (dim == std::make_tuple(1, false))
1917  ++droppedDims;
1918  }
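    // E.g., for the vector<1x1x4x[4]xf32> example above, `droppedDimsBefore`
    // is [0, 1, 2, 2].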
1919 
1920  // Drop unit dims from transpose permutation.
1921  ArrayRef<int64_t> perm = op.getPermutation();
1922  SmallVector<int64_t> newPerm;
1923  for (int64_t idx : perm) {
1924  if (sourceDims[idx] == std::make_tuple(1, false))
1925  continue;
1926  newPerm.push_back(idx - droppedDimsBefore[idx]);
1927  }
1928 
 1929  // Fixup for `newPerm`. `sourceTypeWithoutUnitDims` could be a vector<1xT>
 1930  // type when all of the dimensions are unit dimensions. In this case,
 1931  // `newPerm` should be [0].
1932  if (newPerm.empty()) {
1933  newPerm.push_back(0);
1934  }
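    // E.g., transposing vector<1x1xf32> with permutation [1, 0] reduces to a
    // transpose of vector<1xf32> with permutation [0].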
1935 
1936  Location loc = op.getLoc();
1937  // Drop the unit dims via shape_cast.
1938  auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
1939  loc, sourceTypeWithoutUnitDims, op.getVector());
1940  // Create the new transpose.
1941  auto transposeWithoutUnitDims =
1942  rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1943  // Restore the unit dims via shape cast.
1944  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
1945  op, op.getResultVectorType(), transposeWithoutUnitDims);
1946 
1947  return success();
1948  }
1949 };
1950 
1951 /// A pattern to drop unit dims from the iter_args of an scf.for.
1952 ///
1953 /// Example:
1954 ///
1955 /// BEFORE:
1956 /// ```mlir
1957 /// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
1958 /// ...
1959 /// scf.yield %
1960 /// }
1961 /// ```
1962 ///
1963 /// AFTER:
1964 /// ```mlir
1965 /// %drop = vector.shape_cast %init
1966 /// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
1967 /// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
1968 /// %new_iter = vector.shape_cast %iter
1969 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1970 /// ...
1971 /// }
1972 /// %res = vector.shape_cast %new_loop
1973 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1974 /// ```
1975 struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
 1976  using OpRewritePattern::OpRewritePattern;
 1977 
1978  LogicalResult matchAndRewrite(scf::ForOp forOp,
1979  PatternRewriter &rewriter) const override {
1980  /// Find the first iter_arg with droppable unit dims. Further applications
1981  /// of this pattern will apply to later arguments.
1982  for (OpOperand &operand : forOp.getInitArgsMutable()) {
1983  auto vectorType = dyn_cast<VectorType>(operand.get().getType());
1984  if (!vectorType)
1985  continue;
1986 
1987  VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
1988  if (vectorType == newVectorType)
1989  continue;
1990 
1991  // Create a new ForOp with that iter operand replaced.
1992  auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
1993  return b.create<vector::ShapeCastOp>(loc, type, source);
1994  };
1995 
1996  Value replacement =
1997  castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
1998  rewriter.replaceOp(forOp,
1999  replaceAndCastForOpIterArg(rewriter, forOp, operand,
2000  replacement, castFn));
2001  return success();
2002  }
2003  return failure();
2004  }
2005 };
2006 
2007 /// Pattern to eliminate redundant zero-constants added to reduction operands.
2008 /// It's enough for there to be one initial zero value, so we can eliminate the
2009 /// extra ones that feed into `vector.reduction <add>`. These get created by the
2010 /// `ChainedReduction` pattern.
2011 ///
2012 /// ```mlir
2013 /// %a = arith.addf %x, %zero
2014 /// %b = arith.addf %a, %y
2015 /// %c = vector.reduction <add> %b, %acc
2016 /// ==>
2017 /// %b = arith.addf %a, %y
2018 /// %c = vector.reduction <add> %b, %acc
2019 /// ```
2020 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
 2021  using OpRewritePattern::OpRewritePattern;
 2022 
2023  LogicalResult matchAndRewrite(vector::ReductionOp op,
2024  PatternRewriter &rewriter) const override {
2025  // TODO: Handle other reduction kinds and their identity values.
2026  if (op.getKind() != vector::CombiningKind::ADD)
2027  return failure();
2028 
2029  Type elemType = op.getSourceVectorType().getElementType();
2030  // The integer case should be handled by `arith.addi` folders, only check
2031  // for floats here.
2032  if (!isa<FloatType>(elemType))
2033  return failure();
2034 
2035  auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2036  if (!vAdd)
2037  return failure();
2038  auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2039  if (!addLhs)
2040  return failure();
2041 
2042  if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
2043  return failure();
2044 
2045  auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
2046  vAdd.getRhs());
2047  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
2048  op.getAcc());
2049  return success();
2050  }
2051 };
2052 
2053 /// Example:
2054 /// ```
2055 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
2056 /// ```
2057 /// is transformed into:
2058 /// ```
2059 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
2060 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
2061 /// %a = arith.addf %y, %z : f32
2062 /// ```
2063 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
2064  BreakDownVectorReduction(MLIRContext *context,
2065  unsigned maxNumElementsToExtract,
2066  PatternBenefit benefit)
2067  : OpRewritePattern(context, benefit),
2068  maxNumElementsToExtract(maxNumElementsToExtract) {}
2069 
2070  LogicalResult matchAndRewrite(vector::ReductionOp op,
2071  PatternRewriter &rewriter) const override {
2072  VectorType type = op.getSourceVectorType();
2073  if (type.isScalable() || op.isMasked())
2074  return failure();
2075  assert(type.getRank() == 1 && "Expected a 1-d vector");
2076 
2077  int64_t numElems = type.getNumElements();
2078  if (numElems > maxNumElementsToExtract) {
2079  return rewriter.notifyMatchFailure(
2080  op, llvm::formatv("has too many vector elements ({0}) to break down "
2081  "(max allowed: {1})",
2082  numElems, maxNumElementsToExtract));
2083  }
2084 
2085  Location loc = op.getLoc();
2086  SmallVector<Value> extracted(numElems, nullptr);
2087  for (auto [idx, extractedElem] : llvm::enumerate(extracted))
2088  extractedElem = rewriter.create<vector::ExtractOp>(
2089  loc, op.getVector(), static_cast<int64_t>(idx));
2090 
2091  Value res = extracted.front();
2092  for (auto extractedElem : llvm::drop_begin(extracted))
2093  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
2094  extractedElem, op.getFastmathAttr());
2095  if (Value acc = op.getAcc())
2096  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
2097  op.getFastmathAttr());
2098 
2099  rewriter.replaceOp(op, res);
2100  return success();
2101  }
2102 
2103 private:
2104  unsigned maxNumElementsToExtract = 0;
2105 };
2106 
2107 /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
2108 /// B)`.
2109 /// Example:
 2110 /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
 2111 /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
 2112 /// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
 2113 /// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
2114 ///
2115 /// Becomes :
2116 ///
2117 /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
2118 ///
2119 /// Supports only 1D-to-2D broadcasts. The following cases are not supported.
2120 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
2121 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
2122 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
2123 template <typename MulOpType>
2124 struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
 2125  using OpRewritePattern<MulOpType>::OpRewritePattern;
 2126  // Returns whether a vector.broadcast matches the requirements for an outerproduct
 2127  // pattern, i.e., a 1D-to-2D broadcastOp without a broadcasted unit dimension.
2128  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
 2129  // Fail if it is not a 1D-to-2D broadcast, to avoid generating
 2130  // shape_casts/broadcasts which do not belong in this pattern.
2131  if (!broadcastOp.computeBroadcastedUnitDims().empty())
2132  return false;
 2133  // Avoid broadcasts like f32 or vector<f32> -> ResType.
2134  auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2135  return srcType && srcType.getRank() != 2;
2136  }
2137 
2138  LogicalResult matchAndRewrite(MulOpType mulOp,
2139  PatternRewriter &rewriter) const override {
 2140  auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2141  if (!resType)
2142  return failure();
2143  if (resType.getRank() != 2)
2144  return failure();
2145  /// If operandA can be written as tr(broadcast(A)) and operandB as
2146  /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2147  /// vector.outerproduct(A, B). Returns failure() otherwise.
2148  auto matchOuterProduct =
2149  [&](Value operandA,
2150  Value operandB) -> FailureOr<vector::OuterProductOp> {
2151  auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2152  if (!transposedLhs)
2153  return failure();
2154  // Fail unless this is a true 2-D matrix transpose.
2155  ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2156  if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2157  return failure();
2158 
2159  auto broadcastedLhs =
2160  transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2161  if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2162  return failure();
2163 
2164  auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2165  if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2166  return failure();
2167 
2168  return rewriter.create<vector::OuterProductOp>(
2169  mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2170  broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2171  };
2172 
2173  Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2174  auto maybeOuterP = matchOuterProduct(lhs, rhs);
 2175  // Handle commutativity: the transposed operand becomes the outerproduct LHS.
2176  if (failed(maybeOuterP))
2177  maybeOuterP = matchOuterProduct(rhs, lhs);
2178  if (failed(maybeOuterP))
2179  return failure();
2180  rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2181  return success();
2182  }
2183 };
2184 
2185 } // namespace
2186 
 2187 void mlir::vector::populateFoldArithExtensionPatterns(
 2188  RewritePatternSet &patterns) {
 2189  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2190  FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2191  patterns.getContext());
2192 }
2193 
2194 void mlir::vector::populateVectorMaskMaterializationPatterns(
2195  RewritePatternSet &patterns, bool force32BitVectorIndices,
2196  PatternBenefit benefit) {
2197  patterns.add<VectorCreateMaskOpConversion,
2198  MaterializeTransferMask<vector::TransferReadOp>,
2199  MaterializeTransferMask<vector::TransferWriteOp>>(
2200  patterns.getContext(), force32BitVectorIndices, benefit);
2201  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2202 }
2203 
2204 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
 2205  RewritePatternSet &patterns, PatternBenefit benefit) {
 2206  // TODO: Consider either:
2207  // * including DropInnerMostUnitDimsTransferRead and
2208  // DropInnerMostUnitDimsTransferWrite, or
2209  // * better naming to distinguish this and
2210  // populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
2211  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2212  DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
2213 }
2214 
2215 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
 2216  RewritePatternSet &patterns, PatternBenefit benefit) {
 2217  patterns.add<BubbleDownVectorBitCastForExtract,
2218  BubbleDownBitCastForStridedSliceExtract,
2219  BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2220  patterns.getContext(), benefit);
2221 }
2222 
2223 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
 2224  RewritePatternSet &patterns,
 2225  std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2226  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2227  std::move(controlFn), benefit);
2228 }
2229 
 2230 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
 2231  RewritePatternSet &patterns,
 2232  std::function<LogicalResult(vector::ContractionOp)> constraint,
2233  PatternBenefit benefit) {
2234  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2235  std::move(constraint));
2236 }
2237 
 2238 void mlir::vector::populateVectorReductionToContractPatterns(
 2239  RewritePatternSet &patterns, PatternBenefit benefit) {
 2240  patterns.add<MultiReduceToContract, CombineContractBroadcast,
2241  CombineContractABTranspose, CombineContractResultTranspose>(
2242  patterns.getContext(), benefit);
2243 }
2244 
 2245 void mlir::vector::
 2246  populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 2247  RewritePatternSet &patterns, PatternBenefit benefit) {
 2248  patterns.add<DropInnerMostUnitDimsTransferRead,
2249  DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
2250  benefit);
2251 }
2252 
 2253 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
 2254  PatternBenefit benefit) {
2255  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2256  ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2257  patterns.getContext(), benefit);
2258 }
2259 
2260 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
2261  PatternBenefit benefit) {
2262  // TODO: Consider converting these patterns to canonicalizations.
2263  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
2264  patterns.getContext(), benefit);
2265 }
2266 
2267 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
 2268  RewritePatternSet &patterns, PatternBenefit benefit) {
 2269  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2270  patterns.add<ReduceRedundantZero>(patterns.getContext(),
2271  PatternBenefit(benefit.getBenefit() + 1));
2272 }
2273 
2274 void mlir::vector::populateBreakDownVectorReductionPatterns(
2275  RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2276  PatternBenefit benefit) {
2277  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2278  maxNumElementsToExtract, benefit);
2279 }
2280 
 2281 void mlir::vector::populateElementwiseToVectorOpsPatterns(
 2282  RewritePatternSet &patterns) {
 2283  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2284  FoldArithToVectorOuterProduct<arith::MulIOp>>(
2285  patterns.getContext());
2286 }
2287 
2288 //===----------------------------------------------------------------------===//
2289 // TableGen'd enum attribute definitions
2290 //===----------------------------------------------------------------------===//
2291 
2292 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"