1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cassert>
16 #include <cstdint>
17 #include <functional>
18 #include <optional>
19 #include <type_traits>
20 
21 #include "mlir/Dialect/Affine/IR/AffineOps.h"
22 #include "mlir/Dialect/Arith/IR/Arith.h"
23 #include "mlir/Dialect/Arith/Utils/Utils.h"
24 #include "mlir/Dialect/Linalg/IR/Linalg.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Utils/IndexingUtils.h"
29 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
30 #include "mlir/Dialect/Vector/IR/VectorOps.h"
31 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
32 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
33 #include "mlir/IR/BuiltinAttributeInterfaces.h"
34 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/ImplicitLocOpBuilder.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Interfaces/VectorInterfaces.h"
41 
42 #include "llvm/ADT/DenseSet.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include "llvm/Support/raw_ostream.h"
49 
50 #define DEBUG_TYPE "vector-to-vector"
51 
52 using namespace mlir;
53 using namespace mlir::vector;
54 
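// Helper that returns the integer values held in `arrayAttr` as a
// SmallVector<IntType>. For example, an ArrayAttr of i64 attributes
// [0, 2, 4] becomes SmallVector<int32_t>{0, 2, 4} when IntType is int32_t.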
55 template <typename IntType>
56 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
57  return llvm::to_vector<4>(llvm::map_range(
58  arrayAttr.getAsRange<IntegerAttr>(),
59  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
60 }
61 
62 // Helper that returns the result position of dimension `index` in `map`, if any.
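// For example (illustrative): given map = (d0, d1, d2) -> (d2, d0),
// getResultIndex(map, 0) returns 1, and getResultIndex(map, 1) returns
// std::nullopt because d1 does not appear among the results.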
63 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
64  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
65  int64_t idx = map.getDimPosition(i);
66  if (idx == index)
67  return i;
68  }
69  return std::nullopt;
70 }
71 
72 namespace {
73 
74 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
75 //
76 // Example:
77 //
78 // The following MLIR with cancelling ShapeCastOps:
79 //
80 // %0 = source : vector<5x4x2xf32>
81 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
82 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
83 // %3 = user %2 : vector<5x4x2xf32>
84 //
85 // Should canonicalize to the following:
86 //
87 // %0 = source : vector<5x4x2xf32>
88 // %1 = user %0 : vector<5x4x2xf32>
89 //
90 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
91  using OpRewritePattern::OpRewritePattern;
92 
93  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
94  PatternRewriter &rewriter) const override {
95  // Check if 'shapeCastOp' has vector source/result type.
96  auto sourceVectorType =
97  dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
98  auto resultVectorType =
99  dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
100  if (!sourceVectorType || !resultVectorType)
101  return failure();
102 
103  // Check if shape cast op source operand is also a shape cast op.
104  auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
105  shapeCastOp.getSource().getDefiningOp());
106  if (!sourceShapeCastOp)
107  return failure();
108  auto operandSourceVectorType =
109  cast<VectorType>(sourceShapeCastOp.getSource().getType());
110  auto operandResultVectorType = sourceShapeCastOp.getType();
111 
112  // Check if shape cast operations invert each other.
113  if (operandSourceVectorType != resultVectorType ||
114  operandResultVectorType != sourceVectorType)
115  return failure();
116 
117  rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
118  return success();
119  }
120 };
121 
122 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
123 /// Ex:
124 /// ```
125 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
126 /// %1 = vector.multi_reduction add, %0 [1]
127 /// : vector<8x32x16xf32> to vector<8x16xf32>
128 /// ```
129 /// Gets converted to:
130 /// ```
131 /// %1 = vector.contract {indexing_maps = [
132 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
133 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134 /// affine_map<(d0, d1, d2) -> (d0, d2)>],
135 /// iterator_types = ["parallel", "reduction", "parallel"],
136 /// kind = add} %arg0, %arg1, %cst_f0
137 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
138 /// ```
139 struct MultiReduceToContract
140  : public OpRewritePattern<vector::MultiDimReductionOp> {
141  using OpRewritePattern::OpRewritePattern;
142 
143  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
144  PatternRewriter &rewriter) const override {
145  if (reduceOp.getKind() != vector::CombiningKind::ADD)
146  return failure();
147  Operation *mulOp = reduceOp.getSource().getDefiningOp();
148  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
149  return failure();
150  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
151  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
152  SmallVector<AffineExpr> exprs;
153  SmallVector<vector::IteratorType> iteratorTypes;
154  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
155  if (!isReduceDim.value()) {
156  iteratorTypes.push_back(vector::IteratorType::parallel);
157  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
158  } else {
159  iteratorTypes.push_back(vector::IteratorType::reduction);
160  }
161  }
162  auto dstMap =
163  AffineMap::get(/*dimCount=*/reductionMask.size(),
164  /*symbolCount=*/0, exprs, reduceOp.getContext());
165  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
166  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
167  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
168  rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
169  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
170  return IteratorTypeAttr::get(rewriter.getContext(), t);
171  }))));
172  return success();
173  }
174 };
175 
176 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
177 /// Ex:
178 /// ```
179 /// %0 = vector.transpose %arg0, [2, 0, 1]
180 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
181 /// %1 = vector.contract {indexing_maps = [
182 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
183 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
184 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
185 /// iterator_types = ["parallel", "parallel", "reduction"],
186 /// kind = add} %0, %arg1, %cst_f0
187 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
188 /// ```
189 /// Gets converted to:
190 /// ```
191 /// %1 = vector.contract {indexing_maps = [
192 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
193 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
194 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
195 /// iterator_types = ["parallel", "parallel", "reduction"],
196 /// kind = add} %arg0, %arg1, %cst_f0
197 /// : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
198 /// ```
199 struct CombineContractABTranspose final
200  : public OpRewritePattern<vector::ContractionOp> {
201  using OpRewritePattern::OpRewritePattern;
202 
203  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
204  PatternRewriter &rewriter) const override {
205  SmallVector<AffineMap> maps =
206      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
207  Value lhs = contractOp.getLhs();
208  Value rhs = contractOp.getRhs();
209  size_t index = 0;
210  bool changed = false;
211  for (Value *operand : {&lhs, &rhs}) {
212  AffineMap &map = maps[index++];
213  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
214  if (!transposeOp)
215  continue;
216  AffineMap permutationMap = AffineMap::getPermutationMap(
217  transposeOp.getPermutation(), contractOp.getContext());
218  map = inversePermutation(permutationMap).compose(map);
219  *operand = transposeOp.getVector();
220  changed = true;
221  }
222  if (!changed)
223  return failure();
224  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
225  contractOp, lhs, rhs, contractOp.getAcc(),
226  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
227  return success();
228  }
229 };
230 
231 /// Merges accumulator and result transposes into contract.
232 ///
233 /// For example:
234 /// ```mlir
235 /// %accT = vector.transpose %acc, [0, 2, 1]
236 /// : vector<2x8x4xf32> to vector<2x4x8xf32>
237 /// %contract = vector.contract {
238 /// indexing_maps = [
239 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
240 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
241 /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
242 /// ],
243 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
244 /// kind = #vector.kind<add>
245 /// } %lhs, %rhs, %accT
246 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
247 /// %0 = vector.transpose %contract, [0, 2, 1]
248 /// : vector<2x4x8xf32> to vector<2x8x4xf32>
249 /// ```
250 /// Becomes:
251 /// ```mlir
252 /// %0 = vector.contract {
253 /// indexing_maps = [
254 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
255 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
256 /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
257 /// ],
258 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
259 /// kind = #vector.kind<add>
260 /// } %lhs, %rhs, %acc
261 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
262 /// ```
263 struct CombineContractResultTranspose final
264  : public OpRewritePattern<vector::TransposeOp> {
265  using OpRewritePattern::OpRewritePattern;
266 
267  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
268  PatternRewriter &rewriter) const override {
269  auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
270  if (!contractOp || !contractOp->hasOneUse())
271  return failure();
272 
273  auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
274  if (!accTOp)
275  return failure();
276 
277  MLIRContext *context = contractOp.getContext();
278  auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
279  AffineMap contractMap = maps.back();
280 
281  // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
282  // To index into A in contract, we need revert(f)(g(C)) -> A.
283  auto accTMap =
284  AffineMap::getPermutationMap(accTOp.getPermutation(), context);
285 
286  // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
287  // To index into E in contract, we need h(g(C)) -> E.
288  auto resTMap =
289  AffineMap::getPermutationMap(resTOp.getPermutation(), context);
290  auto combinedResMap = resTMap.compose(contractMap);
291 
292  // The accumulator and result share the same indexing map, so they can only
293  // be merged when resTMap is the inverse of accTMap, i.e. when combinedResMap
294  // equals inversePermutation(accTMap).compose(contractMap).
295  if (inversePermutation(accTMap) != resTMap)
296  return failure();
297  maps.back() = combinedResMap;
298 
299  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
300  resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
301  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
302  return success();
303  }
304 };
305 
306 /// Merge BroadcastOp into ContractionOp user.
307 /// Ex:
308 /// ```
309 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
310 /// %1 = vector.contract {indexing_maps = [
311 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
312 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
313 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
314 /// iterator_types = ["parallel", "parallel", "reduction"],
315 /// kind = add} %0, %arg1, %cst_f0
316 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
317 /// ```
318 /// Gets converted to:
319 /// ```
320 /// %1 = vector.contract {indexing_maps = [
321 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
322 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
323 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
324 /// iterator_types = ["parallel", "parallel", "reduction"],
325 /// kind = add} %arg0, %arg1, %cst_f0
326 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
327 /// ```
328 struct CombineContractBroadcast
329  : public OpRewritePattern<vector::ContractionOp> {
330  using OpRewritePattern::OpRewritePattern;
331 
332  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
333  PatternRewriter &rewriter) const override {
334  SmallVector<AffineMap> maps =
335      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
336  Value lhs = contractOp.getLhs();
337  Value rhs = contractOp.getRhs();
338  size_t index = 0;
339  bool changed = false;
340  for (Value *operand : {&lhs, &rhs}) {
341  AffineMap &map = maps[index++];
342  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
343  if (!broadcast)
344  continue;
345  // ContractionOp can only take vectors as operands.
346  auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
347  if (!srcType ||
348  srcType.getRank() == broadcast.getResultVectorType().getRank())
349  continue;
350  int64_t rankDiff =
351  broadcast.getResultVectorType().getRank() - srcType.getRank();
352  bool innerDimBroadcast = false;
353  SmallVector<AffineExpr> originalDims;
354  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
355  if (dim.value() != broadcast.getResultVectorType().getDimSize(
356  rankDiff + dim.index())) {
357  innerDimBroadcast = true;
358  break;
359  }
360  originalDims.push_back(
361  rewriter.getAffineDimExpr(dim.index() + rankDiff));
362  }
363  // Contract doesn't support inner dimension broadcast. Once this is
364  // relaxed we can remove this case.
365  if (innerDimBroadcast)
366  continue;
367 
368  // It would be incorrect to fold a broadcast onto a reduction dimension
369  // of non-unit size.
370  bool nonUnitDimReductionBroadcast = false;
371  for (int64_t i = 0; i < rankDiff; ++i) {
372  if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
373  isReductionIterator(contractOp.getIteratorTypes()
374  .getValue()[map.getDimPosition(i)])) {
375  nonUnitDimReductionBroadcast = true;
376  break;
377  }
378  }
379  if (nonUnitDimReductionBroadcast)
380  continue;
381 
382  AffineMap broadcastMap =
383  AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
384  originalDims, contractOp.getContext());
385  map = broadcastMap.compose(map);
386  *operand = broadcast.getSource();
387  changed = true;
388  }
389 
390  if (!changed)
391  return failure();
392 
393  // Determine which dims are unused, now that the maps have been composed
394  // with the broadcast maps.
395  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
396  // Compress unused dims.
397  for (auto &m : maps)
398  m = compressDims(m, unusedDimsBitVector);
399  // Compute the combined iterators.
400  SmallVector<Attribute> iterators;
401  for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
402  if (!unusedDimsBitVector.test(i))
403  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
404  }
405  // Check that compressing unused dims isn't removing all reduction dimension
406  // pairs. For example, if the vector.contract had only one reduction
407  // iterator and that was a unit-dimension created by a broadcast,
408  // then we should bail here, otherwise we would create a contract without
409  // a reduction dimension pair.
410  bool hasReductionIteratorApplyingOnBothSides = false;
411  for (unsigned i = 0; i < iterators.size(); ++i) {
412  if (!isReductionIterator(iterators[i]))
413  continue;
414  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
415  hasReductionIteratorApplyingOnBothSides = true;
416  break;
417  }
418  }
419  if (!hasReductionIteratorApplyingOnBothSides)
420  return failure();
421 
422  // If the compressed maps have a dimension that is not used by either LHS or
423  // RHS then the ContractionOp verifier would fail.
424  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
425  return failure();
426  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
427  contractOp, lhs, rhs, contractOp.getAcc(),
428  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
429  return success();
430  }
431 };
432 
433 /// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops closer
434 /// to contraction ops, which lets the CombineContractBroadcast pattern kick in
435 /// when casting ops sit between these operations.
436 /// Ex:
437 /// ```
438 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
439 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
440 /// ```
441 /// Gets converted to:
442 /// ```
443 /// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
444 /// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
445 /// ```
446 struct ReorderCastOpsOnBroadcast
447  : public OpInterfaceRewritePattern<CastOpInterface> {
448  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
449 
450  LogicalResult matchAndRewrite(CastOpInterface op,
451  PatternRewriter &rewriter) const override {
452  if (op->getNumOperands() != 1)
453  return failure();
454  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
455  if (!bcastOp)
456  return failure();
457 
458  Type castResTy = getElementTypeOrSelf(op->getResult(0));
459  if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
460  castResTy = vecTy.clone(castResTy);
461  auto *castOp =
462  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
463  bcastOp.getSource(), castResTy, op->getAttrs());
464  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
465  op, op->getResult(0).getType(), castOp->getResult(0));
466  return success();
467  }
468 };
469 
470 /// Reorders elementwise(transpose) to transpose(elementwise). This moves
471 /// transpose ops closer to contraction ops, which lets the
472 /// CombineContractABTranspose pattern kick in when elementwise ops sit between
473 /// these operations. Ex:
474 /// ```
475 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
476 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
477 /// %r = arith.addf %at, %bt : vector<2x4xf32>
478 /// ```
479 /// Gets converted to:
480 /// ```
481 /// %0 = arith.addf %a, %b : vector<4x2xf32>
482 /// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
483 /// ```
484 struct ReorderElementwiseOpsOnTranspose final
485  : public OpTraitRewritePattern<OpTrait::Elementwise> {
486  using OpTraitRewritePattern::OpTraitRewritePattern;
487  LogicalResult matchAndRewrite(Operation *op,
488  PatternRewriter &rewriter) const override {
489  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
490  return failure();
491 
492  // Make sure all operands are transpose/constant ops and collect their
493  // transposition maps.
494  SmallVector<ArrayRef<int64_t>> transposeMaps;
495  transposeMaps.reserve(op->getNumOperands());
496  // Record the initial type before transposition. We'll use its shape later.
497  // Any type will do here as we will check all transpose maps are the same.
498  VectorType srcType;
499  for (Value operand : op->getOperands()) {
500  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
501  if (transposeOp) {
502  transposeMaps.push_back(transposeOp.getPermutation());
503  srcType = transposeOp.getSourceVectorType();
504  } else if (!matchPattern(operand, m_Constant())) {
505  return failure();
506  }
507  }
508  if (transposeMaps.empty())
509  return failure();
510  // This is an elementwise op, so all transposed operands should have the
511  // same type. We additionally need to check that all transposes use the
512  // same map.
513  if (!llvm::all_equal(transposeMaps))
514  return rewriter.notifyMatchFailure(op, "different transpose map");
515 
516  SmallVector<Value> srcValues;
517  srcValues.reserve(op->getNumOperands());
518 
519  // If there are constant operands, we need to insert inverse transposes for
520  // them. Calculate the inverse order first.
521  auto order = transposeMaps.front();
522  SmallVector<int64_t> invOrder(order.size());
523  for (int i = 0, e = order.size(); i < e; ++i)
524  invOrder[order[i]] = i;
525 
526  for (Value operand : op->getOperands()) {
527  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
528  if (transposeOp) {
529  srcValues.push_back(transposeOp.getVector());
530  } else {
531  // This is a constant. Create a reverse transpose op for it.
532  auto vectorType =
533  srcType.clone(cast<VectorType>(operand.getType()).getElementType());
534  srcValues.push_back(rewriter.create<vector::TransposeOp>(
535  operand.getLoc(), vectorType, operand, invOrder));
536  }
537  }
538 
539  auto vectorType = srcType.clone(
540  cast<VectorType>(op->getResultTypes()[0]).getElementType());
541  Operation *elementwiseOp =
542  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
543  vectorType, op->getAttrs());
544  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
545  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
546  transposeMaps.front());
547  return success();
548  }
549 };
550 
551 // Returns the values in `arrayAttr` as an integer vector.
552 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
553  return llvm::to_vector<4>(
554  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
555  [](IntegerAttr attr) { return attr.getInt(); }));
556 }
557 
558 // Shuffles vector.bitcast op after vector.extract op.
559 //
560 // This transforms IR like:
561 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
562 // %1 = vector.extract %0[3] : f16 from vector<8xf16>
563 // Into:
564 // %0 = vector.extract %src[1] : f32 from vector<4xf32>
565 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
566 // %2 = vector.extract %1[1] : f16 from vector<2xf16>
567 struct BubbleDownVectorBitCastForExtract
568  : public OpRewritePattern<vector::ExtractOp> {
569  using OpRewritePattern::OpRewritePattern;
570 
571  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
572  PatternRewriter &rewriter) const override {
573  // Only support extracting scalars for now.
574  if (extractOp.getSourceVectorType().getRank() != 1)
575  return failure();
576 
577  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
578  if (!castOp)
579  return failure();
580 
581  VectorType castSrcType = castOp.getSourceVectorType();
582  VectorType castDstType = castOp.getResultVectorType();
583  assert(castSrcType.getRank() == castDstType.getRank());
584 
585  // Fail to match if we only have one element in the cast op source.
586  // This is to avoid infinite loop given that this pattern can generate
587  // such cases.
588  if (castSrcType.getNumElements() == 1)
589  return failure();
590 
591  // Only support casting to a larger number of elements for now.
592  // E.g., vector<4xf32> -> vector<8xf16>.
593  if (castSrcType.getNumElements() > castDstType.getNumElements())
594  return failure();
595 
596  unsigned expandRatio =
597  castDstType.getNumElements() / castSrcType.getNumElements();
598 
599  auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
600  assert(values[0].is<Attribute>() && "Unexpected non-constant index");
601  return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
602  };
603 
604  uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
605 
606  // Get the single scalar (as a vector) in the source value that packs the
607  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
608  Location loc = extractOp.getLoc();
609  Value packedValue = rewriter.create<vector::ExtractOp>(
610  loc, castOp.getSource(), index / expandRatio);
611  Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
612  Value zero = rewriter.create<arith::ConstantOp>(
613  loc, packedVecType, rewriter.getZeroAttr(packedVecType));
614  packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
615  /*position=*/0);
616 
617  // Cast it to a vector with the desired scalar's type.
618  // E.g. f32 -> vector<2xf16>
619  VectorType packedType =
620  VectorType::get({expandRatio}, castDstType.getElementType());
621  Value castedValue =
622  rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
623 
624  // Finally extract the desired scalar.
625  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
626  index % expandRatio);
627  return success();
628  }
629 };
630 
631 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
632 //
633 // This transforms IR like:
634 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
635 // %0 = vector.extract_strided_slice %cast {
636 // offsets = [4], sizes = [4], strides = [1]
637 // } : vector<8xf16> to vector<4xf16>
638 // Into:
639 // %0 = vector.extract_strided_slice %src {
640 // offsets = [2], sizes = [2], strides = [1]
641 // } : vector<4xf32> to vector<2xf32>
642 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
643 struct BubbleDownBitCastForStridedSliceExtract
644  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
645  using OpRewritePattern::OpRewritePattern;
646 
647  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
648  PatternRewriter &rewriter) const override {
649  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
650  if (!castOp)
651  return failure();
652 
653  VectorType castSrcType = castOp.getSourceVectorType();
654  VectorType castDstType = castOp.getResultVectorType();
655  assert(castSrcType.getRank() == castDstType.getRank());
656 
657  int64_t castSrcLastDim = castSrcType.getShape().back();
658  int64_t castDstLastDim = castDstType.getShape().back();
659  // Require casting to more elements for now; other cases to be implemented.
660  if (castSrcLastDim > castDstLastDim)
661  return failure();
662 
663  // Only accept all one strides for now.
664  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
665  [](const APInt &val) { return !val.isOne(); }))
666  return failure();
667 
668  unsigned rank = extractOp.getSourceVectorType().getRank();
669  assert(castDstLastDim % castSrcLastDim == 0);
670  int64_t expandRatio = castDstLastDim / castSrcLastDim;
671 
672  // If we have fewer offsets than the rank, then implicitly we are selecting
673  // the full range for the last bitcasted dimension; other dimensions aren't
674  // affected. Otherwise, we need to scale down the last dimension's offset
675  // given we are extracting from fewer elements now.
676  ArrayAttr newOffsets = extractOp.getOffsets();
677  if (newOffsets.size() == rank) {
678  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
679  if (offsets.back() % expandRatio != 0)
680  return failure();
681  offsets.back() = offsets.back() / expandRatio;
682  newOffsets = rewriter.getI64ArrayAttr(offsets);
683  }
684 
685  // Similarly for sizes.
686  ArrayAttr newSizes = extractOp.getSizes();
687  if (newSizes.size() == rank) {
688  SmallVector<int64_t> sizes = getIntValueVector(newSizes);
689  if (sizes.back() % expandRatio != 0)
690  return failure();
691  sizes.back() = sizes.back() / expandRatio;
692  newSizes = rewriter.getI64ArrayAttr(sizes);
693  }
694 
695  SmallVector<int64_t> dims =
696  llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
697  dims.back() = dims.back() / expandRatio;
698  VectorType newExtractType =
699  VectorType::get(dims, castSrcType.getElementType());
700 
701  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
702  extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
703  newSizes, extractOp.getStrides());
704 
705  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
706  extractOp, extractOp.getType(), newExtractOp);
707 
708  return success();
709  }
710 };
711 
712 // Shuffles vector.bitcast op before vector.insert op.
713 //
714 // This transforms IR like:
715 // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
716 // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
717 // Into:
718 // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
719 // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
720 // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
721 //
722 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
723  using OpRewritePattern::OpRewritePattern;
724 
725  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
726  PatternRewriter &rewriter) const override {
727  VectorType castSrcType = bitcastOp.getSourceVectorType();
728  VectorType castDstType = bitcastOp.getResultVectorType();
729 
730  // 0-D and scalable vectors are not supported yet.
731  if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
732  castDstType.isScalable())
733  return failure();
734 
735  int64_t castSrcLastDim = castSrcType.getShape().back();
736  int64_t castDstLastDim = castDstType.getShape().back();
737  bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
738  int64_t ratio;
739  if (isNumElemsShrink) {
740  assert(castSrcLastDim % castDstLastDim == 0);
741  ratio = castSrcLastDim / castDstLastDim;
742  } else {
743  assert(castDstLastDim % castSrcLastDim == 0);
744  ratio = castDstLastDim / castSrcLastDim;
745  }
746 
747  auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
748  if (!insertOp)
749  return failure();
750 
751  // Only vector sources are supported for now.
752  auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
753  if (!insertSrcType)
754  return failure();
755 
756  // Bitcast the source.
757  SmallVector<int64_t> srcDims(insertSrcType.getShape());
758  srcDims.back() =
759  isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
760  VectorType newCastSrcType =
761  VectorType::get(srcDims, castDstType.getElementType());
762  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
763  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
764 
765  SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
766  dstDims.back() =
767  isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
768  VectorType newCastDstType =
769  VectorType::get(dstDims, castDstType.getElementType());
770 
771  // Bitcast the destination.
772  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
773  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
774 
775  // Generate new insert.
776  rewriter.replaceOpWithNewOp<vector::InsertOp>(
777  bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
778  return success();
779  }
780 };
781 
782 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
783 //
784 // This transforms IR like:
785 // %0 = vector.insert_strided_slice %src, %dst {
786 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
787 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
788 // Into:
789 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
790 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
791 // %2 = vector.insert_strided_slice %0, %1 {
792 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
793 struct BubbleUpBitCastForStridedSliceInsert
794  : public OpRewritePattern<vector::BitCastOp> {
795  using OpRewritePattern::OpRewritePattern;
796 
797  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
798  PatternRewriter &rewriter) const override {
799  VectorType castSrcType = bitcastOp.getSourceVectorType();
800  VectorType castDstType = bitcastOp.getResultVectorType();
801  assert(castSrcType.getRank() == castDstType.getRank());
802  // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
803  if (castSrcType.getRank() == 0)
804  return failure();
805 
806  int64_t castSrcLastDim = castSrcType.getShape().back();
807  int64_t castDstLastDim = castDstType.getShape().back();
808  // Require casting to fewer elements for now; other cases to be implemented.
809  if (castSrcLastDim < castDstLastDim)
810  return failure();
811 
812  assert(castSrcLastDim % castDstLastDim == 0);
813  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
814 
815  auto insertOp =
816  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
817  if (!insertOp)
818  return failure();
819 
820  // Only accept all one strides for now.
821  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
822  [](const APInt &val) { return !val.isOne(); }))
823  return failure();
824 
825  unsigned rank = insertOp.getSourceVectorType().getRank();
826  // Require insert op to have the same rank for the source and destination
827  // vector; other cases to be implemented.
828  if (rank != insertOp.getDestVectorType().getRank())
829  return failure();
830 
831  // Require that the shape of the insert op source is castable to dstType.
832  unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
833  unsigned destinationWidth =
834  castDstType.getElementType().getIntOrFloatBitWidth();
835  unsigned numElements = destinationWidth / sourceWidth;
836  if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
837  return failure();
838 
839  ArrayAttr newOffsets = insertOp.getOffsets();
840  assert(newOffsets.size() == rank);
841  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
842  if (offsets.back() % shrinkRatio != 0)
843  return failure();
844  offsets.back() = offsets.back() / shrinkRatio;
845  newOffsets = rewriter.getI64ArrayAttr(offsets);
846 
847  SmallVector<int64_t> srcDims =
848  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
849  srcDims.back() = srcDims.back() / shrinkRatio;
850  VectorType newCastSrcType =
851  VectorType::get(srcDims, castDstType.getElementType());
852 
853  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
854  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
855 
856  SmallVector<int64_t> dstDims =
857  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
858  dstDims.back() = dstDims.back() / shrinkRatio;
859  VectorType newCastDstType =
860  VectorType::get(dstDims, castDstType.getElementType());
861 
862  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
863  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
864 
865  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
866  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
867  insertOp.getStrides());
868 
869  return success();
870  }
871 };
872 
873 // Breaks down vector.bitcast op
874 //
875 // This transforms IR like:
876 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
877 // Into:
878 // %cst = vector.splat %c0_f32 : vector<4xf32>
879 // %1 = vector.extract_strided_slice %0 {
880 // offsets = [0], sizes = [4], strides = [1]
881 // } : vector<8xf16> to vector<4xf16>
882 // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
883 // %4 = vector.insert_strided_slice %2, %cst {
884 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
885 // %5 = vector.extract_strided_slice %0 {
886 // offsets = [4], sizes = [4], strides = [1]
887 // } : vector<8xf16> to vector<4xf16>
888 // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
889 // %7 = vector.insert_strided_slice %6, %4 {
890 // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
891 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
892  using OpRewritePattern::OpRewritePattern;
893 
894 public:
895  BreakDownVectorBitCast(MLIRContext *context,
896  std::function<bool(vector::BitCastOp)> controlFn,
897  PatternBenefit benefit)
898  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
899 
900  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
901  PatternRewriter &rewriter) const override {
902 
903  if (controlFn && !controlFn(bitcastOp))
904  return failure();
905 
906  VectorType castSrcType = bitcastOp.getSourceVectorType();
907  VectorType castDstType = bitcastOp.getResultVectorType();
908  assert(castSrcType.getRank() == castDstType.getRank());
909 
910  // Only support rank 1 case for now.
911  if (castSrcType.getRank() != 1)
912  return failure();
913 
914  int64_t castSrcLastDim = castSrcType.getShape().back();
915  int64_t castDstLastDim = castDstType.getShape().back();
916  // Require casting to fewer elements for now; other cases to be implemented.
917  if (castSrcLastDim < castDstLastDim)
918  return failure();
919 
920  assert(castSrcLastDim % castDstLastDim == 0);
921  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
922  // Nothing to do if it is already bitcasting to a single element.
923  if (castSrcLastDim == shrinkRatio)
924  return failure();
925 
926  Location loc = bitcastOp.getLoc();
927  Type elemType = castDstType.getElementType();
928  assert(elemType.isSignlessIntOrIndexOrFloat());
929 
930  Value zero = rewriter.create<arith::ConstantOp>(
931  loc, elemType, rewriter.getZeroAttr(elemType));
932  Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
933 
934  SmallVector<int64_t> sliceShape{castDstLastDim};
935  SmallVector<int64_t> strides{1};
936  VectorType newCastDstType =
937  VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
938  castDstType.getElementType());
939 
940  for (int i = 0, e = shrinkRatio; i < e; ++i) {
941  Value extracted = rewriter.create<ExtractStridedSliceOp>(
942  loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
943  sliceShape, strides);
944  Value bitcast =
945  rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
946  res = rewriter.create<InsertStridedSliceOp>(
947  loc, bitcast, res,
948  ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
949  }
950  rewriter.replaceOp(bitcastOp, res);
951  return success();
952  }
953 
954 private:
955  std::function<bool(BitCastOp)> controlFn;
956 };
957 
958 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
959 /// ```
960 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
961 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
962 /// %r = arith.addi %a, %b : vector<1x4xindex>
963 /// ```
964 /// Gets converted to:
965 /// ```
966 /// %r = arith.addi %arg1, %arg2 : index
967 /// %b = vector.broadcast %r : index to vector<1x4xindex>
968 /// ```
969 ///
970 /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
971 /// ops.
972 struct ReorderElementwiseOpsOnBroadcast final
973  : public OpTraitRewritePattern<OpTrait::Elementwise> {
974  using OpTraitRewritePattern::OpTraitRewritePattern;
975  LogicalResult matchAndRewrite(Operation *op,
976  PatternRewriter &rewriter) const override {
977  if (op->getNumResults() != 1)
978  return failure();
979  if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
980  return failure();
981  if (!OpTrait::hasElementwiseMappableTraits(op))
982  return rewriter.notifyMatchFailure(
983  op, "Op doesn't have ElementwiseMappableTraits");
984  if (op->getNumOperands() == 0)
985  return failure();
986  if (op->getResults()[0].getType() != op->getOperand(0).getType())
987  return rewriter.notifyMatchFailure(op,
988  "result and operand type mismatch");
989  if (isa<vector::FMAOp>(op)) {
990  return rewriter.notifyMatchFailure(
991  op,
992  "Op only accepts vector types - not supported as broadcast source "
993  "might be a scalar");
994  }
995 
996  // Get the type of the lhs operand
997  auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
998  if (!lhsBcastOrSplat ||
999  !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
1000  return failure();
1001  auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1002 
1003  // Make sure that all operands are broadcast from identical types:
1004  // * scalar (`vector.broadcast` + `vector.splat`), or
1005  // * vector (`vector.broadcast`).
1006  // Otherwise the re-ordering wouldn't be safe.
1007  if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1008  auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1009  if (bcast)
1010  return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1011  auto splat = val.getDefiningOp<vector::SplatOp>();
1012  if (splat)
1013  return (splat.getOperand().getType() == lhsBcastOrSplatType);
1014  return false;
1015  })) {
1016  return failure();
1017  }
1018 
1019  // Collect the source values before broadcasting
1020  SmallVector<Value> srcValues;
1021  srcValues.reserve(op->getNumOperands());
1022  for (Value operand : op->getOperands()) {
1023  srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1024  }
1025 
1026  // Create the "elementwise" Op
1027  Operation *elementwiseOp =
1028  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1029  lhsBcastOrSplatType, op->getAttrs());
1030 
1031  // Replace the original Op with the elementwise Op
1032  auto vectorType = op->getResultTypes()[0];
1033  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1034  op, vectorType, elementwiseOp->getResults());
1035 
1036  return success();
1037  }
1038 };
1039 
1040 // Helper that returns a vector comparison that constructs a mask:
1041 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1042 //
1043 // If `dim == 0` then the result will be a 0-D vector.
1044 //
1045 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1046 // much more compact, IR for this operation, but LLVM eventually
1047 // generates more elaborate instructions for this intrinsic since it
1048 // is very conservative on the boundary conditions.
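//
// Illustrative sketch (not verbatim output): for dim = 4 with 32-bit indices,
// an offset %off and a bound %b, the emitted IR is roughly:
//   %idx    = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %offVec = vector.splat %off : vector<4xi32>
//   %sum    = arith.addi %offVec, %idx : vector<4xi32>
//   %bVec   = vector.splat %b : vector<4xi32>
//   %mask   = arith.cmpi slt, %sum, %bVec : vector<4xi32>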
1049 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1050  bool force32BitVectorIndices, int64_t dim,
1051  Value b, Value *off = nullptr) {
1052  auto loc = op->getLoc();
1053  // If we can assume all indices fit in 32-bit, we perform the vector
1054  // comparison in 32-bit to get a higher degree of SIMD parallelism.
1055  // Otherwise we perform the vector comparison using 64-bit indices.
1056  Type idxType =
1057  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1058  DenseIntElementsAttr indicesAttr;
1059  if (dim == 0 && force32BitVectorIndices) {
1060  indicesAttr = DenseIntElementsAttr::get(
1061      VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1062  } else if (dim == 0) {
1063  indicesAttr = DenseIntElementsAttr::get(
1064      VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1065  } else if (force32BitVectorIndices) {
1066  indicesAttr = rewriter.getI32VectorAttr(
1067  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1068  } else {
1069  indicesAttr = rewriter.getI64VectorAttr(
1070  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1071  }
1072  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
1073  // Add in an offset if requested.
1074  if (off) {
1075  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1076  Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1077  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
1078  }
1079  // Construct the vector comparison.
1080  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1081  Value bounds =
1082  rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1083  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1084  bounds);
1085 }
1086 
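/// Materializes an explicit mask for a vector.transfer_read/write op that has
/// an out-of-bounds trailing dimension. A sketch of the effect (illustrative,
/// for a 1-D transfer_read of vector<8xf32> from memref<?xf32> at index %i):
///
///   %dim  = memref.dim %src, %c0 : memref<?xf32>
///   %rem  = arith.subi %dim, %i : index
///   %mask = vector.create_mask %rem : vector<8xi1>
///   %v    = vector.transfer_read %src[%i], %pad, %mask {in_bounds = [true]}
///           : memref<?xf32>, vector<8xf32>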
1087 template <typename ConcreteOp>
1088 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1089 public:
1090  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1091  PatternBenefit benefit = 1)
1092  : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1093  force32BitVectorIndices(enableIndexOpt) {}
1094 
1095  LogicalResult matchAndRewrite(ConcreteOp xferOp,
1096  PatternRewriter &rewriter) const override {
1097  if (!xferOp.hasOutOfBoundsDim())
1098  return failure();
1099 
1100  if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1101  return failure();
1102 
1103  Location loc = xferOp->getLoc();
1104  VectorType vtp = xferOp.getVectorType();
1105 
1106  // Create the in-bounds mask with all elements between [0 .. dim - offset)
1107  // set and [dim - offset .. vector_length) unset.
1108  //
1109  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1110  // dimensions here.
1111  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1112  Value off = xferOp.getIndices()[lastIndex];
1113  Value dim =
1114  vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1115  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1116  Value mask = rewriter.create<vector::CreateMaskOp>(
1117  loc,
1118  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1119  vtp.getScalableDims()),
1120  b);
1121  if (xferOp.getMask()) {
1122  // Intersect the in-bounds with the mask specified as an op parameter.
1123  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
1124  }
1125 
1126  rewriter.modifyOpInPlace(xferOp, [&]() {
1127  xferOp.getMaskMutable().assign(mask);
1128  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1129  });
1130 
1131  return success();
1132  }
1133 
1134 private:
1135  const bool force32BitVectorIndices;
1136 };
1137 
1138 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
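/// For example (illustrative, with 32-bit indices enabled), the 1-D mask
///
///   %m = vector.create_mask %size : vector<4xi1>
///
/// lowers to a compare of a constant index vector against the splatted bound:
///
///   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %sz  = arith.index_cast %size : index to i32
///   %bnd = vector.splat %sz : vector<4xi32>
///   %m   = arith.cmpi slt, %cst, %bnd : vector<4xi32>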
1139 class VectorCreateMaskOpConversion
1140  : public OpRewritePattern<vector::CreateMaskOp> {
1141 public:
1142  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1143  bool enableIndexOpt,
1144  PatternBenefit benefit = 1)
1145  : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1146  force32BitVectorIndices(enableIndexOpt) {}
1147 
1148  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1149  PatternRewriter &rewriter) const override {
1150  auto dstType = op.getType();
1151  if (cast<VectorType>(dstType).isScalable())
1152  return failure();
1153  int64_t rank = dstType.getRank();
1154  if (rank > 1)
1155  return failure();
1156  rewriter.replaceOp(
1157  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1158  rank == 0 ? 0 : dstType.getDimSize(0),
1159  op.getOperand(0)));
1160  return success();
1161  }
1162 
1163 private:
1164  const bool force32BitVectorIndices;
1165 };
1166 
1167 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1168 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1169  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1170  // TODO: Support non-dense constant.
1171  if (!denseAttr)
1172  return false;
1173 
1174  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1175  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1176 }
1177 
1178 /// Folds a select operation between an all-true and all-false vector. For now,
1179 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1180 ///
1181 /// %true = arith.constant dense<true> : vector<1xi1>
1182 /// %false = arith.constant dense<false> : vector<1xi1>
1183 /// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1184 /// =>
1185 /// %result = vector.broadcast %cond : i1 to vector<1xi1>
1186 ///
1187 /// InstCombine seems to handle vectors with multiple elements but not the
1188 /// single element ones.
1189 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1190  using OpRewritePattern::OpRewritePattern;
1191 
1192  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1193  PatternRewriter &rewriter) const override {
1194  auto vecType = dyn_cast<VectorType>(selectOp.getType());
1195  if (!vecType || !vecType.getElementType().isInteger(1))
1196  return failure();
1197 
1198  // Only scalar conditions can be folded.
1199  Value cond = selectOp.getCondition();
1200  if (isa<VectorType>(cond.getType()))
1201  return failure();
1202 
1203  // TODO: Support n-D and scalable vectors.
1204  if (vecType.getRank() != 1 || vecType.isScalable())
1205  return failure();
1206 
1207  // TODO: Support vectors with multiple elements.
1208  if (vecType.getShape()[0] != 1)
1209  return failure();
1210 
1211  auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1212  if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1213  return failure();
1214 
1215  auto falseConst =
1216  selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1217  if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1218  return failure();
1219 
1220  // Replace select with its condition broadcasted to single element vector.
1221  auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1222  auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1223  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1224  return success();
1225  }
1226 };
1227 
1228 /// Returns the number of dims that can be folded away from transfer ops. It
1229 /// returns a failure if it cannot determine the number of dims to be folded.
1230 ///
1231 /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1232 /// `vectorType` is vector<16x16x1x1xf32>
1233 /// (the two innermost dims can be dropped by memref.subview ops)
1234 ///
1235 /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1236 /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1237 /// (only the inner most unit dim of `srcType` can be dropped)
1238 ///
1239 /// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
1240 /// `vectorType` is vector<16x16x1x[1]xf32>
1241 /// (the innermost dim in `vectorType` is not a unit dim; it is a "scalable
1242 /// unit" dim)
1243 static FailureOr<size_t>
1244 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1245  SmallVector<int64_t> srcStrides;
1246  int64_t srcOffset;
1247  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1248  return failure();
1249 
1250  auto isUnitDim = [](VectorType type, int dim) {
1251  return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1252  };
1253 
1254  // According to vector.transfer_read/write semantics, the vector can be a
1255  // slice. Thus, we have to offset the check index with `rankDiff` in
1256  // `srcStrides` and source dim sizes.
1257  size_t result = 0;
1258  int rankDiff = srcType.getRank() - vectorType.getRank();
1259  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1260  // Check that the inner dim size is 1 for both memref type and vector slice.
1261  // It can be folded only if they are 1 and the stride is 1.
1262  int dim = vectorType.getRank() - i - 1;
1263  if (srcStrides[dim + rankDiff] != 1 ||
1264  srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1265  break;
1266  result++;
1267  }
1268  return result;
1269 }
1270 
1271 /// Drop innermost contiguous unit dimensions from transfer_read operand.
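/// E.g. (illustrative, mirroring the transfer_write example below):
///
///    vector.transfer_read %arg0[%c0, %arg2, %c0, %c0, %c0], %pad
///      {in_bounds = [true, true, true, true, true]}
///      : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
///
/// is rewritten to
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %0 = vector.transfer_read %subview[%c0, %arg2, %c0], %pad
///      {in_bounds = [true, true, true]}
///      : memref<1x512x16xf32>, vector<1x16x16xf32>
///    %1 = vector.shape_cast %0 : vector<1x16x16xf32> to vector<1x16x16x1x1xf32>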
1272 class DropInnerMostUnitDimsTransferRead
1273  : public OpRewritePattern<vector::TransferReadOp> {
1274  using OpRewritePattern::OpRewritePattern;
1275 
1276  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1277  PatternRewriter &rewriter) const override {
1278  // TODO: support 0-d corner case.
1279  if (readOp.getTransferRank() == 0)
1280  return failure();
1281 
1282  // TODO: support mask.
1283  if (readOp.getMask())
1284  return failure();
1285 
1286  auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1287  if (!srcType)
1288  return failure();
1289 
1290  if (!readOp.getPermutationMap().isMinorIdentity())
1291  return failure();
1292 
1293  auto targetType = readOp.getVectorType();
1294  if (targetType.getRank() <= 1)
1295  return failure();
1296 
1297  FailureOr<size_t> maybeDimsToDrop =
1298  getTransferFoldableInnerUnitDims(srcType, targetType);
1299  if (failed(maybeDimsToDrop))
1300  return failure();
1301 
1302  size_t dimsToDrop = maybeDimsToDrop.value();
1303  if (dimsToDrop == 0)
1304  return failure();
1305 
1306  auto inBounds = readOp.getInBoundsValues();
1307  auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1308  if (llvm::is_contained(droppedInBounds, false))
1309  return failure();
1310 
1311  auto resultTargetVecType =
1312  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1313  targetType.getElementType(),
1314  targetType.getScalableDims().drop_back(dimsToDrop));
1315 
1316  auto loc = readOp.getLoc();
1317  SmallVector<OpFoldResult> sizes =
1318      memref::getMixedSizes(rewriter, loc, readOp.getSource());
1319  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1320  rewriter.getIndexAttr(0));
1321  SmallVector<OpFoldResult> strides(srcType.getRank(),
1322  rewriter.getIndexAttr(1));
1323  auto resultMemrefType =
1324  cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1325  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1326  strides));
1327  ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1328  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1329  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1330  loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1331  auto permMap = getTransferMinorIdentityMap(
1332  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1333  Value result = rewriter.create<vector::TransferReadOp>(
1334  loc, resultTargetVecType, rankedReducedView,
1335  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1336  readOp.getPadding(),
1337  // TODO: support mask.
1338  /*mask=*/Value(), inBoundsAttr);
1339  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1340  result);
1341  return success();
1342  }
1343 };
1344 
1345 /// Drop innermost contiguous unit dimensions from transfer_write operand.
1346 /// E.g.,
1347 /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1348 /// {in_bounds = [true, true, true, true, true]}
1349 /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1350 ///
1351 /// will be replaced with
1352 ///
1353 /// %subview = memref.subview %arg0
1354 /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1355 /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1356 /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1357 /// to vector<1x16x16xf32>
1358 /// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1359 /// {in_bounds = [true, true, true]}
1360 /// : vector<1x16x16xf32>, memref<1x512x16xf32>
1361 ///
1362 /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1363 class DropInnerMostUnitDimsTransferWrite
1364  : public OpRewritePattern<vector::TransferWriteOp> {
1365  using OpRewritePattern::OpRewritePattern;
1366 
1367  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1368  PatternRewriter &rewriter) const override {
1369  // TODO: support 0-d corner case.
1370  if (writeOp.getTransferRank() == 0)
1371  return failure();
1372 
1373  // TODO: support mask.
1374  if (writeOp.getMask())
1375  return failure();
1376 
1377  auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1378  if (!srcType)
1379  return failure();
1380 
1381  if (!writeOp.getPermutationMap().isMinorIdentity())
1382  return failure();
1383 
1384  auto targetType = writeOp.getVectorType();
1385  if (targetType.getRank() <= 1)
1386  return failure();
1387 
1388  FailureOr<size_t> maybeDimsToDrop =
1389  getTransferFoldableInnerUnitDims(srcType, targetType);
1390  if (failed(maybeDimsToDrop))
1391  return failure();
1392 
1393  size_t dimsToDrop = maybeDimsToDrop.value();
1394  if (dimsToDrop == 0)
1395  return failure();
1396 
1397  auto inBounds = writeOp.getInBoundsValues();
1398  auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1399  if (llvm::is_contained(droppedInBounds, false))
1400  return failure();
1401 
1402  auto resultTargetVecType =
1403  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1404  targetType.getElementType(),
1405  targetType.getScalableDims().drop_back(dimsToDrop));
1406 
1407  Location loc = writeOp.getLoc();
1408  SmallVector<OpFoldResult> sizes =
1409      memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1410  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1411  rewriter.getIndexAttr(0));
1412  SmallVector<OpFoldResult> strides(srcType.getRank(),
1413  rewriter.getIndexAttr(1));
1414  auto resultMemrefType =
1415  cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1416  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1417  strides));
1418  ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1419  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1420 
1421  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1422  loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1423  auto permMap = getTransferMinorIdentityMap(
1424  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1425 
1426  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1427  loc, resultTargetVecType, writeOp.getVector());
1428  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1429  writeOp, shapeCast, rankedReducedView,
1430  writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1431  // TODO: support mask.
1432  /*mask=*/Value(), inBoundsAttr);
1433  return success();
1434  }
1435 };
1436 
1437 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1438 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1439 /// with the RHS transposed) lowering.
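/// For example (illustrative, with d0 = m, d1 = n, d2 = k), a row-major matmul
/// contraction with indexing maps
///
///   affine_map<(d0, d1, d2) -> (d0, d2)>,
///   affine_map<(d0, d1, d2) -> (d2, d1)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>
///
/// has its RHS fed through a `vector.transpose` so the contraction takes the
/// canonical "TNT" (A row-major, B col-major, C row-major) form
///
///   affine_map<(d0, d1, d2) -> (d0, d2)>,
///   affine_map<(d0, d1, d2) -> (d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>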
1440 struct CanonicalizeContractMatmulToMMT final
1441  : OpRewritePattern<vector::ContractionOp> {
1442  using OpRewritePattern::OpRewritePattern;
1443 
1444  using FilterConstraintType =
1445  std::function<LogicalResult(vector::ContractionOp op)>;
1446 
1447  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1448  FilterConstraintType constraint)
1449  : OpRewritePattern<vector::ContractionOp>(context, benefit),
1450  filter(std::move(constraint)) {}
1451 
1452  LogicalResult matchAndRewrite(vector::ContractionOp op,
1453  PatternRewriter &rewriter) const override {
1454  if (failed(filter(op)))
1455  return failure();
1456 
1457  Location loc = op.getLoc();
1458  Value lhs = op.getLhs();
1459  Value rhs = op.getRhs();
1460  Value res = op.getAcc();
1461 
1462  // Set up the parallel/reduction structure in right form.
1463  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1464  auto infer = [&](MapList m) {
1465  return AffineMap::inferFromExprList(m, op.getContext());
1466  };
1467  AffineExpr m;
1468  AffineExpr n;
1469  AffineExpr k;
1470  bindDims(rewriter.getContext(), m, n, k);
1471  static constexpr std::array<int64_t, 2> perm = {1, 0};
1472  auto iteratorTypes = op.getIteratorTypes().getValue();
1473  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1474  if (iteratorTypes.size() != 3 ||
1475  !vector::isParallelIterator(iteratorTypes[0]) ||
1476  !vector::isParallelIterator(iteratorTypes[1]) ||
1477  !vector::isReductionIterator(iteratorTypes[2]))
1478  return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1479 
1480  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1481  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1482  if (maps == canonicalForm)
1483  return rewriter.notifyMatchFailure(op, "already in the canonical form");
1484 
1485  // Create a vector transpose making sure to emit zero/sign-extend at the
1486  // end.
1487  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1488  if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1489  Value trans =
1490  rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1491  VectorType newType =
1492  cast<VectorType>(trans.getType())
1493  .clone(cast<VectorType>(mat.getType()).getElementType());
1494  return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1495  }
1496  if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1497  Value trans =
1498  rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1499  VectorType newType =
1500  VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1501  cast<VectorType>(mat.getType()).getElementType());
1502  return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1503  }
1504  return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1505  };
1506 
1507  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1508  rhs = createTranspose(rhs);
1509  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1510  lhs = createTranspose(lhs);
1511  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1512  rhs = createTranspose(rhs);
1513  lhs = createTranspose(lhs);
1514  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1515  std::swap(rhs, lhs);
1516  rhs = createTranspose(rhs);
1517  lhs = createTranspose(lhs);
1518  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1519  std::swap(rhs, lhs);
1520  rhs = createTranspose(rhs);
1521  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1522  std::swap(lhs, rhs);
1523  lhs = createTranspose(lhs);
1524  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1525  std::swap(lhs, rhs);
1526  } else {
1527  return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1528  }
1529  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1530  op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1531  op.getIteratorTypes());
1532  return success();
1533  };
1534 
1535 private:
1536  FilterConstraintType filter;
1537 };
1538 
1539 /// Pattern to fold arithmetic extensions on floating point data types into
1540 /// vector contraction operations. linalg.matmul introduces arithmetic
1541 /// extensions on its operands. See the MLIR snippet below for more details.
1542 /// ```mlir
1543 /// "linalg.matmul"(%lhs, %rhs, %acc) ({
1544 /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1545 /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1546 /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1547 /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1548 /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1549 /// "linalg.yield"(%acc) : (f32) -> ()
1550 /// })
1551 /// ```
1552 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1553 /// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1554 /// This pattern folds the arithmetic extensions into the vector contraction and
1555 /// enables the usage of native mixed precision Tensor Core instructions.
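///
/// For example (an illustrative vector-level sketch; the shapes and the
/// f16 -> f32 extension are arbitrary choices):
/// ```mlir
/// %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
/// %rhs_f32 = arith.extf %rhs : vector<8x4xf16> to vector<8x4xf32>
/// %res = vector.contract ... %lhs_f32, %rhs_f32, %acc
/// ==>
/// %res = vector.contract ... %lhs, %rhs, %acc
/// ```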
1556 template <typename ExtOp>
1557 struct FoldArithExtIntoContractionOp
1558  : public OpRewritePattern<vector::ContractionOp> {
1559  using OpRewritePattern::OpRewritePattern;
1560 
1561  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1562  PatternRewriter &rewriter) const override {
1563 
1564  auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1565  auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1566 
1567  if (!lhsDefOp || !rhsDefOp) {
1568  return rewriter.notifyMatchFailure(contractOp,
1569  "no defining op on contract operands");
1570  }
1571 
1572  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1573  contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1574  contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1575  contractOp.getIteratorTypesAttr());
1576 
1577  return success();
1578  }
1579 };
1580 
1581 /// Pattern to fold a chained reduction into a series of vector additions and a
1582 /// final reduction. This form should require fewer subgroup operations.
1583 ///
1584 /// ```mlir
1585 /// %a = vector.reduction <add> %x, %acc
1586 /// %b = vector.reduction <add> %y, %a
1587 /// ==>
1588 /// %a = arith.addf %x, %y
1589 /// %b = vector.reduction <add> %a, %acc
1590 /// ```
1591 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1592  using OpRewritePattern::OpRewritePattern;
1593 
1594  LogicalResult matchAndRewrite(vector::ReductionOp op,
1595  PatternRewriter &rewriter) const override {
1596  // TODO: Handle other combining kinds.
1597  if (op.getKind() != vector::CombiningKind::ADD)
1598  return failure();
1599 
1600  // Accumulator is optional.
1601  Value acc = op.getAcc();
1602  if (!acc)
1603  return failure();
1604 
1605  if (!acc.getType().isIntOrFloat())
1606  return failure();
1607 
1608  auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1609  if (!parentReduction)
1610  return failure();
1611 
1612  Location loc = op.getLoc();
1613  Value vAdd;
1614  if (isa<IntegerType>(acc.getType())) {
1615  vAdd = rewriter.createOrFold<arith::AddIOp>(
1616  loc, parentReduction.getVector(), op.getVector());
1617  } else {
1618  vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1619  op.getVector());
1620  }
1621  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1622  parentReduction.getAcc());
1623  return success();
1624  }
1625 };
1626 
1627 // Helper function dropping non-scalable unit dimensions from a VectorType,
1628 // keeping at least one dimension to avoid generating 0-D vectors. Scalable unit
1629 // dimensions are not dropped. Folding such dimensions would require "shifting"
1630 // the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1631 // vector<[4]xf32>). This could be implemented in the future.
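// For example (illustrative): vector<1x[4]x1xf32> -> vector<[4]xf32>,
// vector<[1]x1xf32> -> vector<[1]xf32>, and vector<1x1xf32> -> vector<1xf32>.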
1632 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1633  auto inVecShape = inVecTy.getShape();
1634  SmallVector<int64_t> newShape;
1635  SmallVector<bool> newScalableDims;
1636  for (auto [dim, isScalable] :
1637  llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1638  if (dim == 1 && !isScalable)
1639  continue;
1640 
1641  newShape.push_back(dim);
1642  newScalableDims.push_back(isScalable);
1643  }
1644  // All dims have been dropped, return vector<1xeType>.
1645  if (newShape.empty()) {
1646  newShape.push_back(1);
1647  newScalableDims.push_back(false);
1648  }
1649 
1650  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1651 }
1652 
1653 /// For vectors with at least one unit dim, replaces:
1654 /// elementwise(a, b)
1655 /// with:
1656 /// sc_a = shape_cast(a)
1657 /// sc_b = shape_cast(b)
1658 /// res = elementwise(sc_a, sc_b)
1659 /// return shape_cast(res)
1660 /// The newly inserted shape_cast Ops drop the unit dim (before the elementwise
1661 /// Op) and restore it afterwards. Vectors `a` and `b` are
1662 /// required to be rank > 1.
1663 ///
1664 /// Ex:
1665 /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1666 /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1667 ///
1668 /// gets converted to:
1669 ///
1670 /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1671 /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1672 /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1673 /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1674 /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1675 ///
1676 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1677 /// `%cast`.
1678 struct DropUnitDimFromElementwiseOps final
1679  : public OpTraitRewritePattern<OpTrait::Elementwise> {
1680  using OpTraitRewritePattern::OpTraitRewritePattern;
1681  LogicalResult matchAndRewrite(Operation *op,
1682  PatternRewriter &rewriter) const override {
1683  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1684  return failure();
1685 
1686  auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1687  if (!resultVectorType)
1688  return failure();
1689 
1690  // Check the operand pre-conditions. For `Elementwise` ops all operands are
1691  // guaranteed to have identical shapes (with some exceptions such as
1692  // `arith.select`) and it suffices to only check one of them.
1693  auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1694  if (!sourceVectorType)
1695  return failure();
1696  if (sourceVectorType.getRank() < 2)
1697  return failure();
1698 
1699  SmallVector<Value> newOperands;
1700  auto loc = op->getLoc();
1701  for (auto operand : op->getOperands()) {
1702  auto opVectorType = cast<VectorType>(operand.getType());
1703  auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1704  if (newVType == opVectorType)
1705  return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1706 
1707  auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1708  newOperands.push_back(opSC);
1709  }
1710 
1711  VectorType newResultVectorType =
1712  dropNonScalableUnitDimFromType(resultVectorType);
1713  // Create an updated elementwise Op without unit dim.
1714  Operation *elementwiseOp =
1715  rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1716  newResultVectorType, op->getAttrs());
1717 
1718  // Restore the unit dim by applying vector.shape_cast to the result.
1719  rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1720  elementwiseOp->getResult(0));
1721 
1722  return success();
1723  }
1724 };
1725 
1726 /// A pattern to drop unit dims from vector.transpose.
1727 ///
1728 /// Example:
1729 ///
1730 /// BEFORE:
1731 /// ```mlir
1732 /// %transpose = vector.transpose %vector, [3, 0, 1, 2]
1733 /// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1734 /// ```
1735 ///
1736 /// AFTER:
1737 /// ```mlir
1738 /// %dropDims = vector.shape_cast %vector
1739 /// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1740 /// %transpose = vector.transpose %0, [1, 0]
1741 /// : vector<4x[4]xf32> to vector<[4]x4xf32>
1742 /// %restoreDims = vector.shape_cast %transpose
1743 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1744 /// ```
1745 struct DropUnitDimsFromTransposeOp final
1746  : OpRewritePattern<vector::TransposeOp> {
1747  using OpRewritePattern::OpRewritePattern;
1748 
1749  LogicalResult matchAndRewrite(vector::TransposeOp op,
1750  PatternRewriter &rewriter) const override {
1751  VectorType sourceType = op.getSourceVectorType();
1752  VectorType sourceTypeWithoutUnitDims =
1753  dropNonScalableUnitDimFromType(sourceType);
1754 
1755  if (sourceType == sourceTypeWithoutUnitDims)
1756  return failure();
1757 
1758  // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1759  auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
1760  SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
1761  int64_t droppedDims = 0;
1762  for (auto [i, dim] : llvm::enumerate(sourceDims)) {
1763  droppedDimsBefore[i] = droppedDims;
1764  if (dim == std::make_tuple(1, false))
1765  ++droppedDims;
1766  }
1767 
1768  // Drop unit dims from transpose permutation.
1769  ArrayRef<int64_t> perm = op.getPermutation();
1770  SmallVector<int64_t> newPerm;
1771  for (int64_t idx : perm) {
1772  if (sourceDims[idx] == std::make_tuple(1, false))
1773  continue;
1774  newPerm.push_back(idx - droppedDimsBefore[idx]);
1775  }
1776 
1777  // Fixup for `newPerm`: when all dimensions are non-scalable unit dims,
1778  // `sourceTypeWithoutUnitDims` degenerates to vector<1xT> and the permutation
1779  // must be [0].
1780  if (newPerm.empty()) {
1781  newPerm.push_back(0);
1782  }
1783 
1784  Location loc = op.getLoc();
1785  // Drop the unit dims via shape_cast.
1786  auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
1787  loc, sourceTypeWithoutUnitDims, op.getVector());
1788  // Create the new transpose.
1789  auto tranposeWithoutUnitDims =
1790  rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1791  // Restore the unit dims via shape cast.
1792  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
1793  op, op.getResultVectorType(), tranposeWithoutUnitDims);
1794 
1795  return success();
1796  }
1797 };
1798 
1799 /// A pattern to drop unit dims from the iter_args of an scf.for.
1800 ///
1801 /// Example:
1802 ///
1803 /// BEFORE:
1804 /// ```mlir
1805 /// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
1806 /// ...
1807 /// scf.yield %...
1808 /// }
1809 /// ```
1810 ///
1811 /// AFTER:
1812 /// ```mlir
1813 /// %drop = vector.shape_cast %init
1814 ///   : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
1815 /// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
1816 /// %new_iter = vector.shape_cast %iter
1817 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1818 /// ...
1819 /// }
1820 /// %res = vector.shape_cast %new_loop
1821 /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1822 /// ```
1823 struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
1824  using OpRewritePattern::OpRewritePattern;
1825 
1826  LogicalResult matchAndRewrite(scf::ForOp forOp,
1827  PatternRewriter &rewriter) const override {
1828  /// Find the first iter_arg with droppable unit dims. Further applications
1829  /// of this pattern will apply to later arguments.
1830  for (OpOperand &operand : forOp.getInitArgsMutable()) {
1831  auto vectorType = dyn_cast<VectorType>(operand.get().getType());
1832  if (!vectorType)
1833  continue;
1834 
1835  VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
1836  if (vectorType == newVectorType)
1837  continue;
1838 
1839  // Create a new ForOp with that iter operand replaced.
1840  auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
1841  return b.create<vector::ShapeCastOp>(loc, type, source);
1842  };
1843 
1844  Value replacement =
1845  castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
1846  rewriter.replaceOp(forOp,
1847  replaceAndCastForOpIterArg(rewriter, forOp, operand,
1848  replacement, castFn));
1849  return success();
1850  }
1851  return failure();
1852  }
1853 };
1854 
1855 /// Pattern to eliminate redundant zero-constants added to reduction operands.
1856 /// It's enough for there to be one initial zero value, so we can eliminate the
1857 /// extra ones that feed into `vector.reduction <add>`. These get created by the
1858 /// `ChainedReduction` pattern.
1859 ///
1860 /// ```mlir
1861 /// %a = arith.addf %x, %zero
1862 /// %b = arith.addf %a, %y
1863 /// %c = vector.reduction <add> %b, %acc
1864 /// ==>
1865 /// %b = arith.addf %a, %y
1866 /// %c = vector.reduction <add> %b, %acc
1867 /// ```
1868 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1869  using OpRewritePattern::OpRewritePattern;
1870 
1871  LogicalResult matchAndRewrite(vector::ReductionOp op,
1872  PatternRewriter &rewriter) const override {
1873  // TODO: Handle other reduction kinds and their identity values.
1874  if (op.getKind() != vector::CombiningKind::ADD)
1875  return failure();
1876 
1877  Type elemType = op.getSourceVectorType().getElementType();
1878  // The integer case should be handled by `arith.addi` folders, only check
1879  // for floats here.
1880  if (!isa<FloatType>(elemType))
1881  return failure();
1882 
1883  auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1884  if (!vAdd)
1885  return failure();
1886  auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1887  if (!addLhs)
1888  return failure();
1889 
1890  if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1891  return failure();
1892 
1893  auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1894  vAdd.getRhs());
1895  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1896  op.getAcc());
1897  return success();
1898  }
1899 };
1900 
1901 /// Example:
1902 /// ```
1903 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
1904 /// ```
1905 /// is transformed into:
1906 /// ```
1907 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
1908 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
1909 /// %a = arith.addf %y, %z : f32
1910 /// ```
1911 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1912  BreakDownVectorReduction(MLIRContext *context,
1913  unsigned maxNumElementsToExtract,
1914  PatternBenefit benefit)
1915  : OpRewritePattern(context, benefit),
1916  maxNumElementsToExtract(maxNumElementsToExtract) {}
1917 
1918  LogicalResult matchAndRewrite(vector::ReductionOp op,
1919  PatternRewriter &rewriter) const override {
1920  VectorType type = op.getSourceVectorType();
1921  if (type.isScalable() || op.isMasked())
1922  return failure();
1923  assert(type.getRank() == 1 && "Expected a 1-d vector");
1924 
1925  int64_t numElems = type.getNumElements();
1926  if (numElems > maxNumElementsToExtract) {
1927  return rewriter.notifyMatchFailure(
1928  op, llvm::formatv("has too many vector elements ({0}) to break down "
1929  "(max allowed: {1})",
1930  numElems, maxNumElementsToExtract));
1931  }
1932 
1933  Location loc = op.getLoc();
1934  SmallVector<Value> extracted(numElems, nullptr);
1935  for (auto [idx, extractedElem] : llvm::enumerate(extracted))
1936  extractedElem = rewriter.create<vector::ExtractOp>(
1937  loc, op.getVector(), static_cast<int64_t>(idx));
1938 
1939  Value res = extracted.front();
1940  for (auto extractedElem : llvm::drop_begin(extracted))
1941  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
1942  extractedElem, op.getFastmathAttr());
1943  if (Value acc = op.getAcc())
1944  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
1945  op.getFastmathAttr());
1946 
1947  rewriter.replaceOp(op, res);
1948  return success();
1949  }
1950 
1951 private:
1952  unsigned maxNumElementsToExtract = 0;
1953 };
1954 
1955 /// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
1956 /// B)`.
1957 /// Example:
1958 /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
1959 /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
1960 /// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
1961 /// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
1962 ///
1963 /// Becomes:
1964 ///
1965 /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
1966 ///
1967 /// Supports only 1D-to-2D broadcasts. The following cases are not supported.
1968 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
1969 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
1970 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
1971 template <typename MulOpType>
1972 struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
1973  using OpRewritePattern<MulOpType>::OpRewritePattern;
1974  // Returns whether a vector.broadcast matches the requirements of an
1975  // outerproduct pattern, i.e. a 1D-to-2D broadcast without broadcasted unit dims.
1976  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
1977  // Fail if this is not a 1D-to-2D broadcast, to avoid generating
1978  // shape_casts/broadcasts which do not belong in this pattern.
1979  if (!broadcastOp.computeBroadcastedUnitDims().empty())
1980  return false;
1981  // Avoid broadcast like f32 or vector<f32> -> ResType
1982  auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1983  return srcType && srcType.getRank() != 2;
1984  }
1985 
1986  LogicalResult matchAndRewrite(MulOpType mulOp,
1987  PatternRewriter &rewriter) const override {
1988  auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
1989  if (!resType)
1990  return failure();
1991  if (resType.getRank() != 2)
1992  return failure();
1993  /// If operandA can be written as tr(broadcast(A)) and operandB as
1994  /// broadcast(B) where broadcasts are 1D-to-2D, create and return
1995  /// vector.outerproduct(A, B). Returns failure() otherwise.
1996  auto matchOuterProduct =
1997  [&](Value operandA,
1998  Value operandB) -> FailureOr<vector::OuterProductOp> {
1999  auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2000  if (!transposedLhs)
2001  return failure();
2002  // Fail unless this is a true 2-D matrix transpose.
2003  ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2004  if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2005  return failure();
2006 
2007  auto broadcastedLhs =
2008  transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2009  if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2010  return failure();
2011 
2012  auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2013  if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2014  return failure();
2015 
2016  return rewriter.create<vector::OuterProductOp>(
2017  mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2018  broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2019  };
2020 
2021  Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2022  auto maybeOuterP = matchOuterProduct(lhs, rhs);
2023  // Handle commutativity, the transposed op is the outerproduct LHS.
2024  if (failed(maybeOuterP))
2025  maybeOuterP = matchOuterProduct(rhs, lhs);
2026  if (failed(maybeOuterP))
2027  return failure();
2028  rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2029  return success();
2030  }
2031 };
2032 
2033 } // namespace
2034 
2035 void mlir::vector::populateFoldArithExtensionPatterns(
2036  RewritePatternSet &patterns) {
2037  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2038  FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2039  patterns.getContext());
2040 }
2041 
2042 void mlir::vector::populateVectorMaskMaterializationPatterns(
2043  RewritePatternSet &patterns, bool force32BitVectorIndices,
2044  PatternBenefit benefit) {
2045  patterns.add<VectorCreateMaskOpConversion,
2046  MaterializeTransferMask<vector::TransferReadOp>,
2047  MaterializeTransferMask<vector::TransferWriteOp>>(
2048  patterns.getContext(), force32BitVectorIndices, benefit);
2049  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2050 }
2051 
2052 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
2053  PatternBenefit benefit) {
2054  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
2055 }
2056 
2057 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2058  RewritePatternSet &patterns, PatternBenefit benefit) {
2059  // TODO: Consider either:
2060  // * including DropInnerMostUnitDimsTransferRead and
2061  // DropInnerMostUnitDimsTransferWrite, or
2062  // * better naming to distinguish this and
2063  // populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
2064  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2065  DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
2066  patterns.getContext(), benefit);
2067 }
2068 
2069 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2070  RewritePatternSet &patterns, PatternBenefit benefit) {
2071  patterns.add<BubbleDownVectorBitCastForExtract,
2072  BubbleDownBitCastForStridedSliceExtract,
2073  BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2074  patterns.getContext(), benefit);
2075 }
2076 
2077 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2078  RewritePatternSet &patterns,
2079  std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2080  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2081  std::move(controlFn), benefit);
2082 }
2083 
2084 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2085  RewritePatternSet &patterns,
2086  std::function<LogicalResult(vector::ContractionOp)> constraint,
2087  PatternBenefit benefit) {
2088  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2089  std::move(constraint));
2090 }
2091 
2092 void mlir::vector::populateVectorReductionToContractPatterns(
2093  RewritePatternSet &patterns, PatternBenefit benefit) {
2094  patterns.add<MultiReduceToContract, CombineContractBroadcast,
2095  CombineContractABTranspose, CombineContractResultTranspose>(
2096  patterns.getContext(), benefit);
2097 }
2098 
2099 void mlir::vector::
2100  populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2101  RewritePatternSet &patterns, PatternBenefit benefit) {
2102  patterns.add<DropInnerMostUnitDimsTransferRead,
2103  DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
2104  benefit);
2105 }
2106 
2107 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2108  PatternBenefit benefit) {
2109  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2110  ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
2111  benefit);
2112 }
2113 
2114 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2115  RewritePatternSet &patterns, PatternBenefit benefit) {
2116  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2117  patterns.add<ReduceRedundantZero>(patterns.getContext(),
2118  PatternBenefit(benefit.getBenefit() + 1));
2119 }
2120 
2121 void mlir::vector::populateBreakDownVectorReductionPatterns(
2122  RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2123  PatternBenefit benefit) {
2124  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2125  maxNumElementsToExtract, benefit);
2126 }
2127 
2128 void mlir::vector::populateElementwiseToVectorOpsPatterns(
2129  RewritePatternSet &patterns) {
2130  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2131  FoldArithToVectorOuterProduct<arith::MulIOp>>(
2132  patterns.getContext());
2133 }
2134 
2135 //===----------------------------------------------------------------------===//
2136 // TableGen'd enum attribute definitions
2137 //===----------------------------------------------------------------------===//
2138 
2139 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"