1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cstdint>
16 #include <functional>
17 #include <optional>
18 #include <type_traits>
19 
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Arith/Utils/Utils.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
33 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/Location.h"
36 #include "mlir/IR/Matchers.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
41 
42 #include "llvm/ADT/DenseSet.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/raw_ostream.h"
48 
49 #define DEBUG_TYPE "vector-to-vector"
50 
51 using namespace mlir;
52 using namespace mlir::vector;
53 
54 template <typename IntType>
55 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
56  return llvm::to_vector<4>(llvm::map_range(
57  arrayAttr.getAsRange<IntegerAttr>(),
58  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
59 }
60 
61 // Helper to find an index in an affine map.
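// E.g. (illustrative) for the map (d0, d1, d2) -> (d2, d1): index 2 is found at
// result 0, index 1 at result 1, and index 0 is absent (std::nullopt).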
62 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
63  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
64  int64_t idx = map.getDimPosition(i);
65  if (idx == index)
66  return i;
67  }
68  return std::nullopt;
69 }
70 
71 namespace {
72 
73 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
74 //
75 // Example:
76 //
77 // The following MLIR with cancelling ShapeCastOps:
78 //
79 // %0 = source : vector<5x4x2xf32>
80 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
81 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
82 // %3 = user %2 : vector<5x4x2xf32>
83 //
84 // Should canonicalize to the following:
85 //
86 // %0 = source : vector<5x4x2xf32>
87 // %1 = user %0 : vector<5x4x2xf32>
88 //
89 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
90  using OpRewritePattern::OpRewritePattern;
91 
92  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
93  PatternRewriter &rewriter) const override {
94  // Check if 'shapeCastOp' has vector source/result type.
95  auto sourceVectorType =
96  shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
97  auto resultVectorType =
98  shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
99  if (!sourceVectorType || !resultVectorType)
100  return failure();
101 
102  // Check if shape cast op source operand is also a shape cast op.
103  auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
104  shapeCastOp.getSource().getDefiningOp());
105  if (!sourceShapeCastOp)
106  return failure();
107  auto operandSourceVectorType =
108  sourceShapeCastOp.getSource().getType().cast<VectorType>();
109  auto operandResultVectorType = sourceShapeCastOp.getType();
110 
111  // Check if shape cast operations invert each other.
112  if (operandSourceVectorType != resultVectorType ||
113  operandResultVectorType != sourceVectorType)
114  return failure();
115 
116  rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
117  return success();
118  }
119 };
120 
121 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
122 /// Ex:
123 /// ```
124 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
125 /// %1 = vector.multi_reduction add, %0 [2]
126 /// : vector<8x32x16xf32> to vector<8x32xf32>
127 /// ```
128 /// Gets converted to:
129 /// ```
130 /// %1 = vector.contract {indexing_maps = [
131 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
132 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
133 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
134 /// iterator_types = ["parallel", "parallel", "reduction"],
135 /// kind = add} %arg0, %arg1, %cst_f0
136 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
137 /// ```
138 struct MultiReduceToContract
139  : public OpRewritePattern<vector::MultiDimReductionOp> {
140  using OpRewritePattern::OpRewritePattern;
141 
142  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
143  PatternRewriter &rewriter) const override {
144  if (reduceOp.getKind() != vector::CombiningKind::ADD)
145  return failure();
146  Operation *mulOp = reduceOp.getSource().getDefiningOp();
147  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
148  return failure();
149  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
150  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
151  SmallVector<AffineExpr> exprs;
152  SmallVector<vector::IteratorType> iteratorTypes;
153  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
154  if (!isReduceDim.value()) {
155  iteratorTypes.push_back(vector::IteratorType::parallel);
156  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
157  } else {
158  iteratorTypes.push_back(vector::IteratorType::reduction);
159  }
160  }
161  auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
162  /*symCount=*/0, exprs, reduceOp.getContext());
163  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
164  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
165  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
166  rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
167  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
168  return IteratorTypeAttr::get(rewriter.getContext(), t);
169  }))));
170  return success();
171  }
172 };
173 
174 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
175 /// Ex:
176 /// ```
177 /// %0 = vector.transpose %arg0, [2, 0, 1]
178 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
179 /// %1 = vector.contract {indexing_maps = [
180 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
181 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
182 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
183 /// iterator_types = ["parallel", "parallel", "reduction"],
184 /// kind = add} %0, %arg1, %cst_f0
185 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
186 /// ```
187 /// Gets converted to:
188 /// ```
189 /// %1 = vector.contract {indexing_maps = [
190 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
191 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
192 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
193 /// iterator_types = ["parallel", "parallel", "reduction"],
194 /// kind = add} %arg0, %arg1, %cst_f0
195 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
196 /// ```
197 struct CombineContractABTranspose final
198  : public OpRewritePattern<vector::ContractionOp> {
199  using OpRewritePattern::OpRewritePattern;
200 
201  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
202  PatternRewriter &rewriter) const override {
203  SmallVector<AffineMap, 4> maps =
204      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
205  Value lhs = contractOp.getLhs();
206  Value rhs = contractOp.getRhs();
207  size_t index = 0;
208  bool changed = false;
209  for (Value *operand : {&lhs, &rhs}) {
210  AffineMap &map = maps[index++];
211  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
212  if (!transposeOp)
213  continue;
214  AffineMap permutationMap = AffineMap::getPermutationMap(
215  extractVector<unsigned>(transposeOp.getTransp()),
216  contractOp.getContext());
217  map = inversePermutation(permutationMap).compose(map);
218  *operand = transposeOp.getVector();
219  changed = true;
220  }
221  if (!changed)
222  return failure();
223  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
224  contractOp, lhs, rhs, contractOp.getAcc(),
225  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
226  return success();
227  }
228 };
229 
230 /// Merges accumulator and result transposes into contract.
231 ///
232 /// For example:
233 /// ```mlir
234 /// %accT = vector.transpose %acc, [0, 2, 1]
235 /// : vector<2x8x4xf32> to vector<2x4x8xf32>
236 /// %contract = vector.contract {
237 /// indexing_maps = [
238 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
239 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
240 /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
241 /// ],
242 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
243 /// kind = #vector.kind<add>
244 /// } %lhs, %rhs, %accT
245 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
246 /// %0 = vector.transpose %contract, [0, 2, 1]
247 /// : vector<2x4x8xf32> to vector<2x8x4xf32>
248 /// ```
249 /// Becomes:
250 /// ```mlir
251 /// %0 = vector.contract {
252 /// indexing_maps = [
253 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
254 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
255 /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
256 /// ],
257 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
258 /// kind = #vector.kind<add>
259 /// } %lhs, %rhs, %acc
260 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
261 /// ```
262 struct CombineContractResultTranspose final
263  : public OpRewritePattern<vector::TransposeOp> {
264  using OpRewritePattern::OpRewritePattern;
265 
266  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
267  PatternRewriter &rewriter) const override {
268  auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
269  if (!contractOp || !contractOp->hasOneUse())
270  return failure();
271 
272  auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
273  if (!accTOp)
274  return failure();
275 
276  MLIRContext *context = contractOp.getContext();
277  auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
278  AffineMap contractMap = maps.back();
279 
280  // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
281  // To index into A in contract, we need revert(f)(g(C)) -> A.
282  auto accTMap = AffineMap::getPermutationMap(
283  extractVector<unsigned>(accTOp.getTransp()), context);
284 
285  // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
286  // To index into E in contract, we need h(g(C)) -> E.
287  auto resTMap = AffineMap::getPermutationMap(
288  extractVector<unsigned>(resTOp.getTransp()), context);
289  auto combinedResMap = resTMap.compose(contractMap);
290 
291  // The accumulator and result share the same indexing map, so the merge is
292  // only valid when the result transpose undoes the accumulator transpose,
293  // i.e. resTMap must equal inversePermutation(accTMap).
294  if (inversePermutation(accTMap) != resTMap)
295  return failure();
296  maps.back() = combinedResMap;
297 
298  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
299  resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
300  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
301  return success();
302  }
303 };
304 
305 /// Merge BroadcastOp into ContractionOp user.
306 /// Ex:
307 /// ```
308 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
309 /// %1 = vector.contract {indexing_maps = [
310 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
311 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
312 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
313 /// iterator_types = ["parallel", "parallel", "reduction"],
314 /// kind = add} %0, %arg1, %cst_f0
315 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
316 /// ```
317 /// Gets converted to:
318 /// ```
319 /// %1 = vector.contract {indexing_maps = [
320 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
321 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
322 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
323 /// iterator_types = ["parallel", "parallel", "reduction"],
324 /// kind = add} %arg0, %arg1, %cst_f0
325 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
326 /// ```
327 struct CombineContractBroadcast
328  : public OpRewritePattern<vector::ContractionOp> {
329  using OpRewritePattern::OpRewritePattern;
330 
331  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
332  PatternRewriter &rewriter) const override {
333  SmallVector<AffineMap, 4> maps =
334      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
335  Value lhs = contractOp.getLhs();
336  Value rhs = contractOp.getRhs();
337  size_t index = 0;
338  bool changed = false;
339  for (Value *operand : {&lhs, &rhs}) {
340  AffineMap &map = maps[index++];
341  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
342  if (!broadcast)
343  continue;
344  // vector.contract can only take vectors as operands.
345  auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
346  if (!srcType ||
347  srcType.getRank() == broadcast.getResultVectorType().getRank())
348  continue;
349  int64_t rankDiff =
350  broadcast.getResultVectorType().getRank() - srcType.getRank();
351  bool innerDimBroadcast = false;
352  SmallVector<AffineExpr> originalDims;
353  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
354  if (dim.value() != broadcast.getResultVectorType().getDimSize(
355  rankDiff + dim.index())) {
356  innerDimBroadcast = true;
357  break;
358  }
359  originalDims.push_back(
360  rewriter.getAffineDimExpr(dim.index() + rankDiff));
361  }
362  // Contract doesn't support inner dimension broadcast. Once this is
363  // relaxed we can remove this case.
364  if (innerDimBroadcast)
365  continue;
366 
367  // It would be incorrect to fold a broadcast onto a reduction dimension
368  // of non-unit size.
369  bool nonUnitDimReductionBroadcast = false;
370  for (int64_t i = 0; i < rankDiff; ++i) {
371  if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
372  isReductionIterator(contractOp.getIteratorTypes()
373  .getValue()[map.getDimPosition(i)])) {
374  nonUnitDimReductionBroadcast = true;
375  break;
376  }
377  }
378  if (nonUnitDimReductionBroadcast)
379  continue;
380 
381  AffineMap broadcastMap =
382  AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
383  originalDims, contractOp.getContext());
384  map = broadcastMap.compose(map);
385  *operand = broadcast.getSource();
386  changed = true;
387  }
388 
389  if (!changed)
390  return failure();
391 
392  // Determine which dims are unused, now that the maps have been composed
393  // with the broadcast maps.
394  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
395  // Compress unused dims.
396  for (auto &m : maps)
397  m = compressDims(m, unusedDimsBitVector);
398  // Compute the combined iterators.
399  SmallVector<Attribute> iterators;
400  for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
401  if (!unusedDimsBitVector.test(i))
402  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
403  }
404  // Check that compressing unused dims isn't removing all reduction dimension
405  // pairs. For example, if the vector.contract had only one reduction
406  // iterator and that was a unit-dimension created by a broadcast,
407  // then we should bail here, otherwise we would create a contract without
408  // a reduction dimension pair.
409  bool hasReductionIteratorApplyingOnBothSides = false;
410  for (unsigned i = 0; i < iterators.size(); ++i) {
411  if (!isReductionIterator(iterators[i]))
412  continue;
413  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
414  hasReductionIteratorApplyingOnBothSides = true;
415  break;
416  }
417  }
418  if (!hasReductionIteratorApplyingOnBothSides)
419  return failure();
420 
421  // If the compressed maps have a dimension that is not used by either LHS or
422  // RHS then the ContractionOp verifier would fail.
423  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
424  return failure();
425  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
426  contractOp, lhs, rhs, contractOp.getAcc(),
427  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
428  return success();
429  }
430 };
431 
432 /// Reorders cast(broadcast) to broadcast(cast). This brings broadcast ops and
433 /// contraction ops closer together, so that the CombineContractBroadcast
434 /// pattern can kick in when cast ops sit between these operations.
435 /// Ex:
436 /// ```
437 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
438 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
439 /// ```
440 /// Gets converted to:
441 /// ```
442 /// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
443 /// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
444 /// ```
445 struct ReorderCastOpsOnBroadcast
446  : public OpInterfaceRewritePattern<CastOpInterface> {
447  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
448 
449  LogicalResult matchAndRewrite(CastOpInterface op,
450  PatternRewriter &rewriter) const override {
451  if (op->getNumOperands() != 1)
452  return failure();
453  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
454  if (!bcastOp)
455  return failure();
456 
457  Type castResTy = getElementTypeOrSelf(op->getResult(0));
458  if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
459  castResTy = VectorType::get(vecTy.getShape(), castResTy);
460  auto *castOp =
461  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
462  bcastOp.getSource(), castResTy, op->getAttrs());
463  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
464  op, op->getResult(0).getType(), castOp->getResult(0));
465  return success();
466  }
467 };
468 
469 /// Reorders elementwise(transpose) to transpose(elementwise). This brings
470 /// transpose ops and contraction ops closer together, so that the
471 /// CombineContractABTranspose pattern can kick in when elementwise ops sit
472 /// between these operations. Ex:
473 /// ```
474 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
475 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
476 /// %r = arith.addf %at, %bt : vector<2x4xf32>
477 /// ```
478 /// Gets converted to:
479 /// ```
480 /// %0 = arith.addf %a, %b : vector<4x2xf32>
481 /// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
482 /// ```
483 struct ReorderElementwiseOpsOnTranspose final
484  : public OpTraitRewritePattern<OpTrait::Elementwise> {
485  using OpTraitRewritePattern::OpTraitRewritePattern;
486  LogicalResult matchAndRewrite(Operation *op,
487  PatternRewriter &rewriter) const override {
488  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
489  return failure();
490 
491  // Make sure all operands are transpose/constant ops and collect their
492  // transposition maps.
493  SmallVector<ArrayAttr> transposeMaps;
494  transposeMaps.reserve(op->getNumOperands());
495  // Record the initial type before transposition. We'll use its shape later.
496  // Any type will do here as we will check all transpose maps are the same.
497  VectorType srcType;
498  for (Value operand : op->getOperands()) {
499  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
500  if (transposeOp) {
501  transposeMaps.push_back(transposeOp.getTransp());
502  srcType = transposeOp.getSourceVectorType();
503  } else if (!matchPattern(operand, m_Constant())) {
504  return failure();
505  }
506  }
507  if (transposeMaps.empty())
508  return failure();
509  // This is an elementwise op, so all transposed operands should have the
510  // same type. We need to additionally check that all transposes use the
511  // same map.
512  if (!llvm::all_equal(transposeMaps))
513  return rewriter.notifyMatchFailure(op, "different transpose map");
514 
515  SmallVector<Value> srcValues;
516  srcValues.reserve(op->getNumOperands());
517 
518  // If there are constant operands, we need to insert inverse transposes for
519  // them. Calculate the inverse order first.
520  auto order = extractVector<unsigned>(transposeMaps.front());
521  SmallVector<int64_t> invOrder(order.size());
522  for (int i = 0, e = order.size(); i < e; ++i)
523  invOrder[order[i]] = i;
524 
525  for (Value operand : op->getOperands()) {
526  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
527  if (transposeOp) {
528  srcValues.push_back(transposeOp.getVector());
529  } else {
530  // This is a constant. Create a reverse transpose op for it.
531  auto vectorType = VectorType::get(
532  srcType.getShape(),
533  operand.getType().cast<VectorType>().getElementType());
534  srcValues.push_back(rewriter.create<vector::TransposeOp>(
535  operand.getLoc(), vectorType, operand,
536  rewriter.getI64ArrayAttr(invOrder)));
537  }
538  }
539 
540  auto vectorType = VectorType::get(
541  srcType.getShape(),
542  op->getResultTypes()[0].cast<VectorType>().getElementType());
543  Operation *elementwiseOp =
544  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
545  vectorType, op->getAttrs());
546  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
547  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
548  transposeMaps.front());
549  return success();
550  }
551 };
552 
553 // Returns the values in `arrayAttr` as an integer vector.
554 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
555  return llvm::to_vector<4>(
556  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
557  [](IntegerAttr attr) { return attr.getInt(); }));
558 }
559 
560 // Shuffles vector.bitcast op after vector.extract op.
561 //
562 // This transforms IR like:
563 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
564 // %1 = vector.extract %0[3] : vector<8xf16>
565 // Into:
566 // %0 = vector.extract %src[1] : vector<4xf32>
567 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
568 // %2 = vector.extract %1[1] : vector<2xf16>
569 struct BubbleDownVectorBitCastForExtract
570  : public OpRewritePattern<vector::ExtractOp> {
571  using OpRewritePattern::OpRewritePattern;
572 
573  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
574  PatternRewriter &rewriter) const override {
575  // Only support extracting scalars for now.
576  if (extractOp.getSourceVectorType().getRank() != 1)
577  return failure();
578 
579  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
580  if (!castOp)
581  return failure();
582 
583  VectorType castSrcType = castOp.getSourceVectorType();
584  VectorType castDstType = castOp.getResultVectorType();
585  assert(castSrcType.getRank() == castDstType.getRank());
586 
587  // Fail to match if we only have one element in the cast op source.
588  // This is to avoid infinite loop given that this pattern can generate
589  // such cases.
590  if (castSrcType.getNumElements() == 1)
591  return failure();
592 
593  // Only support casting to a larger number of elements for now.
594  // E.g., vector<4xf32> -> vector<8xf16>.
595  if (castSrcType.getNumElements() > castDstType.getNumElements())
596  return failure();
597 
598  unsigned expandRatio =
599  castDstType.getNumElements() / castSrcType.getNumElements();
600 
601  auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
602  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
603  };
604 
605  uint64_t index = getFirstIntValue(extractOp.getPosition());
606 
607  // Get the single scalar (as a vector) in the source value that packs the
608  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
609  VectorType oneScalarType =
610  VectorType::get({1}, castSrcType.getElementType());
611  Value packedValue = rewriter.create<vector::ExtractOp>(
612  extractOp.getLoc(), oneScalarType, castOp.getSource(),
613  rewriter.getI64ArrayAttr(index / expandRatio));
614 
615  // Cast it to a vector with the desired scalar's type.
616  // E.g. vector<1xf32> -> vector<2xf16>
617  VectorType packedType =
618  VectorType::get({expandRatio}, castDstType.getElementType());
619  Value castedValue = rewriter.create<vector::BitCastOp>(
620  extractOp.getLoc(), packedType, packedValue);
621 
622  // Finally extract the desired scalar.
623  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
624  extractOp, extractOp.getType(), castedValue,
625  rewriter.getI64ArrayAttr(index % expandRatio));
626 
627  return success();
628  }
629 };
630 
631 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
632 //
633 // This transforms IR like:
634 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
635 // %0 = vector.extract_strided_slice %cast {
636 // offsets = [4], sizes = [4], strides = [1]
637 // } : vector<8xf16> to vector<4xf16>
638 // Into:
639 // %0 = vector.extract_strided_slice %src {
640 // offsets = [2], sizes = [2], strides = [1]
641 // } : vector<4xf32> to vector<2xf32>
642 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
643 struct BubbleDownBitCastForStridedSliceExtract
644  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
645  using OpRewritePattern::OpRewritePattern;
646 
647  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
648  PatternRewriter &rewriter) const override {
649  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
650  if (!castOp)
651  return failure();
652 
653  VectorType castSrcType = castOp.getSourceVectorType();
654  VectorType castDstType = castOp.getResultVectorType();
655  assert(castSrcType.getRank() == castDstType.getRank());
656 
657  int64_t castSrcLastDim = castSrcType.getShape().back();
658  int64_t castDstLastDim = castDstType.getShape().back();
659  // Require casting to more elements for now; other cases to be implemented.
660  if (castSrcLastDim > castDstLastDim)
661  return failure();
662 
663  // Only accept all-one strides for now.
664  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
665  [](const APInt &val) { return !val.isOne(); }))
666  return failure();
667 
668  unsigned rank = extractOp.getSourceVectorType().getRank();
669  assert(castDstLastDim % castSrcLastDim == 0);
670  int64_t expandRatio = castDstLastDim / castSrcLastDim;
671 
672  // If we have fewer offsets than the rank, then implicitly we are
673  // selecting the full range for the last bitcasted dimension; other
674  // dimensions aren't affected. Otherwise, we need to scale down the last
675  // dimension's offset given we are extracting from fewer elements now.
676  ArrayAttr newOffsets = extractOp.getOffsets();
677  if (newOffsets.size() == rank) {
678  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
679  if (offsets.back() % expandRatio != 0)
680  return failure();
681  offsets.back() = offsets.back() / expandRatio;
682  newOffsets = rewriter.getI64ArrayAttr(offsets);
683  }
684 
685  // Similarly for sizes.
686  ArrayAttr newSizes = extractOp.getSizes();
687  if (newSizes.size() == rank) {
688  SmallVector<int64_t> sizes = getIntValueVector(newSizes);
689  if (sizes.back() % expandRatio != 0)
690  return failure();
691  sizes.back() = sizes.back() / expandRatio;
692  newSizes = rewriter.getI64ArrayAttr(sizes);
693  }
694 
695  SmallVector<int64_t> dims =
696  llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
697  dims.back() = dims.back() / expandRatio;
698  VectorType newExtractType =
699  VectorType::get(dims, castSrcType.getElementType());
700 
701  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
702  extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
703  newSizes, extractOp.getStrides());
704 
705  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
706  extractOp, extractOp.getType(), newExtractOp);
707 
708  return success();
709  }
710 };
711 
712 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
713 //
714 // This transforms IR like:
715 // %0 = vector.insert_strided_slice %src, %dst {
716 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
717 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
718 // Into:
719 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
720 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
721 // %2 = vector.insert_strided_slice %0, %1 {
722 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
723 struct BubbleUpBitCastForStridedSliceInsert
724  : public OpRewritePattern<vector::BitCastOp> {
725  using OpRewritePattern::OpRewritePattern;
726 
727  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
728  PatternRewriter &rewriter) const override {
729  VectorType castSrcType = bitcastOp.getSourceVectorType();
730  VectorType castDstType = bitcastOp.getResultVectorType();
731  assert(castSrcType.getRank() == castDstType.getRank());
732  // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
733  if (castSrcType.getRank() == 0)
734  return failure();
735 
736  int64_t castSrcLastDim = castSrcType.getShape().back();
737  int64_t castDstLastDim = castDstType.getShape().back();
738  // Require casting to fewer elements for now; other cases to be implemented.
739  if (castSrcLastDim < castDstLastDim)
740  return failure();
741 
742  assert(castSrcLastDim % castDstLastDim == 0);
743  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
744 
745  auto insertOp =
746  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
747  if (!insertOp)
748  return failure();
749 
750  // Only accept all-one strides for now.
751  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
752  [](const APInt &val) { return !val.isOne(); }))
753  return failure();
754 
755  unsigned rank = insertOp.getSourceVectorType().getRank();
756  // Require insert op to have the same rank for the source and destination
757  // vector; other cases to be implemented.
758  if (rank != insertOp.getDestVectorType().getRank())
759  return failure();
760 
761  // Require that the shape of the insert op source is castable to dstType.
762  unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
763  unsigned destinationWidth =
764  castDstType.getElementType().getIntOrFloatBitWidth();
765  unsigned numElements = destinationWidth / sourceWidth;
766  if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
767  return failure();
768 
769  ArrayAttr newOffsets = insertOp.getOffsets();
770  assert(newOffsets.size() == rank);
771  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
772  if (offsets.back() % shrinkRatio != 0)
773  return failure();
774  offsets.back() = offsets.back() / shrinkRatio;
775  newOffsets = rewriter.getI64ArrayAttr(offsets);
776 
777  SmallVector<int64_t> srcDims =
778  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
779  srcDims.back() = srcDims.back() / shrinkRatio;
780  VectorType newCastSrcType =
781  VectorType::get(srcDims, castDstType.getElementType());
782 
783  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
784  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
785 
786  SmallVector<int64_t> dstDims =
787  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
788  dstDims.back() = dstDims.back() / shrinkRatio;
789  VectorType newCastDstType =
790  VectorType::get(dstDims, castDstType.getElementType());
791 
792  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
793  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
794 
795  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
796  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
797  insertOp.getStrides());
798 
799  return success();
800  }
801 };
802 
803 // Helper that returns a vector comparison that constructs a mask:
804 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
805 //
806 // If `dim == 0` then the result will be a 0-D vector.
807 //
808 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
809 // much more compact, IR for this operation, but LLVM eventually
810 // generates more elaborate instructions for this intrinsic since it
811 // is very conservative on the boundary conditions.
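// For example (illustrative values): with dim = 4, *off = 2, and b = 5 this
// computes [0,1,2,3] + [2,2,2,2] = [2,3,4,5] < [5,5,5,5] = [1,1,1,0].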
812 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
813  bool force32BitVectorIndices, int64_t dim,
814  Value b, Value *off = nullptr) {
815  auto loc = op->getLoc();
816  // If we can assume all indices fit in 32-bit, we perform the vector
817  // comparison in 32-bit to get a higher degree of SIMD parallelism.
818  // Otherwise we perform the vector comparison using 64-bit indices.
819  Type idxType =
820  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
821  DenseIntElementsAttr indicesAttr;
822  if (dim == 0 && force32BitVectorIndices) {
823  indicesAttr = DenseIntElementsAttr::get(
824  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
825  } else if (dim == 0) {
826  indicesAttr = DenseIntElementsAttr::get(
827  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
828  } else if (force32BitVectorIndices) {
829  indicesAttr = rewriter.getI32VectorAttr(
830  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
831  } else {
832  indicesAttr = rewriter.getI64VectorAttr(
833  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
834  }
835  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
836  // Add in an offset if requested.
837  if (off) {
838  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
839  Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
840  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
841  }
842  // Construct the vector comparison.
843  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
844  Value bounds =
845  rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
846  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
847  bounds);
848 }
849 
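/// Adds an explicit mask to a `vector.transfer_read`/`vector.transfer_write`
/// that has an out-of-bounds dimension. A rough, illustrative sketch (shapes
/// and value names assumed):
///   %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<4xf32>
/// becomes
///   %d = memref.dim %src, %c0 : memref<?xf32>
///   %b = arith.subi %d, %i : index
///   %mask = vector.create_mask %b : vector<4xi1>
///   %v = vector.transfer_read %src[%i], %pad, %mask {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>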
850 template <typename ConcreteOp>
851 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
852 public:
853  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
854  PatternBenefit benefit = 1)
855  : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
856  force32BitVectorIndices(enableIndexOpt) {}
857 
858  LogicalResult matchAndRewrite(ConcreteOp xferOp,
859  PatternRewriter &rewriter) const override {
860  if (!xferOp.hasOutOfBoundsDim())
861  return failure();
862 
863  if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
864  return failure();
865 
866  Location loc = xferOp->getLoc();
867  VectorType vtp = xferOp.getVectorType();
868 
869  // Create the in-bounds mask with all elements between [0 .. dim - offset)
870  // set and [dim - offset .. vector_length) unset.
871  //
872  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
873  // dimensions here.
874  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
875  Value off = xferOp.getIndices()[lastIndex];
876  Value dim =
877  vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
878  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
879  Value mask = rewriter.create<vector::CreateMaskOp>(
880  loc,
881  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
882  vtp.getNumScalableDims()),
883  b);
884  if (xferOp.getMask()) {
885  // Intersect the in-bounds with the mask specified as an op parameter.
886  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
887  }
888 
889  rewriter.updateRootInPlace(xferOp, [&]() {
890  xferOp.getMaskMutable().assign(mask);
891  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
892  });
893 
894  return success();
895  }
896 
897 private:
898  const bool force32BitVectorIndices;
899 };
900 
901 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
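/// E.g. (an illustrative sketch, with force32BitVectorIndices = true):
///   %m = vector.create_mask %b : vector<4xi1>
/// becomes roughly
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %bi  = arith.index_cast %b : index to i32
///   %bv  = vector.splat %bi : vector<4xi32>
///   %m   = arith.cmpi slt, %idx, %bv : vector<4xi32>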
902 class VectorCreateMaskOpConversion
903  : public OpRewritePattern<vector::CreateMaskOp> {
904 public:
905  explicit VectorCreateMaskOpConversion(MLIRContext *context,
906  bool enableIndexOpt,
907  PatternBenefit benefit = 1)
908  : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
909  force32BitVectorIndices(enableIndexOpt) {}
910 
911  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
912  PatternRewriter &rewriter) const override {
913  auto dstType = op.getType();
914  if (dstType.cast<VectorType>().isScalable())
915  return failure();
916  int64_t rank = dstType.getRank();
917  if (rank > 1)
918  return failure();
919  rewriter.replaceOp(
920  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
921  rank == 0 ? 0 : dstType.getDimSize(0),
922  op.getOperand(0)));
923  return success();
924  }
925 
926 private:
927  const bool force32BitVectorIndices;
928 };
929 
930 // Drop innermost contiguous unit dimensions from the transfer_read operand.
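// A sketch of the rewrite (illustrative types, not taken from a test):
//   %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<16x1xf32>, vector<4x1xf32>
// becomes
//   %s = memref.subview %src[0, 0] [16, 1] [1, 1]
//       : memref<16x1xf32> to memref<16xf32>
//   %r = vector.transfer_read %s[%c0], %pad {in_bounds = [true]}
//       : memref<16xf32>, vector<4xf32>
//   %v = vector.shape_cast %r : vector<4xf32> to vector<4x1xf32>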
931 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
932  using OpRewritePattern::OpRewritePattern;
933 
934  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
935  PatternRewriter &rewriter) const override {
936  // TODO: support 0-d corner case.
937  if (readOp.getTransferRank() == 0)
938  return failure();
939 
940  // TODO: support mask.
941  if (readOp.getMask())
942  return failure();
943 
944  auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
945  if (!srcType || !srcType.hasStaticShape())
946  return failure();
947 
948  if (!readOp.getPermutationMap().isMinorIdentity())
949  return failure();
950 
951  auto targetType = readOp.getVectorType();
952  if (targetType.getRank() <= 1)
953  return failure();
954 
955  SmallVector<int64_t> srcStrides;
956  int64_t srcOffset;
957  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
958  return failure();
959 
960  size_t dimsToDrop = 0;
961  for (size_t i = 1; i < srcStrides.size(); ++i) {
962  int dim = srcType.getRank() - i - 1;
963  if (srcStrides[dim] == 1) {
964  dimsToDrop++;
965  } else {
966  break;
967  }
968  }
969  if (dimsToDrop == 0)
970  return failure();
971 
972  auto resultTargetVecType =
973  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
974  targetType.getElementType());
975 
976  MemRefType resultMemrefType;
977  MemRefLayoutAttrInterface layout = srcType.getLayout();
978  if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
979  resultMemrefType = MemRefType::get(
980  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
981  nullptr, srcType.getMemorySpace());
982  } else {
983  MemRefLayoutAttrInterface updatedLayout;
984  if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
985  auto strides =
986  llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
987  updatedLayout = StridedLayoutAttr::get(strided.getContext(),
988  strided.getOffset(), strides);
989  } else {
990  AffineMap map = srcType.getLayout().getAffineMap();
991  int numSymbols = map.getNumSymbols();
992  for (size_t i = 0; i < dimsToDrop; ++i) {
993  int dim = srcType.getRank() - i - 1;
994  map = map.replace(rewriter.getAffineDimExpr(dim),
995  rewriter.getAffineConstantExpr(0),
996  map.getNumDims() - 1, numSymbols);
997  }
998  }
999  resultMemrefType = MemRefType::get(
1000  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
1001  updatedLayout, srcType.getMemorySpace());
1002  }
1003 
1004  auto loc = readOp.getLoc();
1005  SmallVector<int64_t> offsets(srcType.getRank(), 0);
1006  SmallVector<int64_t> strides(srcType.getRank(), 1);
1007 
1008  ArrayAttr inBoundsAttr =
1009  readOp.getInBounds()
1010  ? rewriter.getArrayAttr(
1011  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1012  : ArrayAttr();
1013  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1014  loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
1015  strides);
1016  auto permMap = getTransferMinorIdentityMap(
1017  rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
1018  Value result = rewriter.create<vector::TransferReadOp>(
1019  loc, resultTargetVecType, rankedReducedView,
1020  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1021  readOp.getPadding(),
1022  // TODO: support mask.
1023  /*mask=*/Value(), inBoundsAttr);
1024  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1025  result);
1026  return success();
1027  }
1028 };
1029 
1030 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1031 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1032 /// with the RHS transposed) lowering.
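/// For example (an illustrative case), a plain row-major matmul with
///   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                    affine_map<(m, n, k) -> (k, n)>,
///                    affine_map<(m, n, k) -> (m, n)>]
/// is rewritten so that the RHS is fed through a vector.transpose and the maps
/// become the canonical MMT form
///   indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                    affine_map<(m, n, k) -> (n, k)>,
///                    affine_map<(m, n, k) -> (m, n)>].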
1033 struct CanonicalizeContractMatmulToMMT final
1034  : OpRewritePattern<vector::ContractionOp> {
1035  using OpRewritePattern::OpRewritePattern;
1036 
1037  using FilterConstraintType =
1038  std::function<LogicalResult(vector::ContractionOp op)>;
1039 
1040  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1041  FilterConstraintType constraint)
1042  : OpRewritePattern<vector::ContractionOp>(context, benefit),
1043  filter(std::move(constraint)) {}
1044 
1045  LogicalResult matchAndRewrite(vector::ContractionOp op,
1046  PatternRewriter &rewriter) const override {
1047  // TODO: Remove native masks from contraction op?
1048  if (!op.getMasks().empty())
1049  return failure();
1050 
1051  if (failed(filter(op)))
1052  return failure();
1053 
1054  Location loc = op.getLoc();
1055  Value lhs = op.getLhs();
1056  Value rhs = op.getRhs();
1057  Value res = op.getAcc();
1058 
1059  // Set up the parallel/reduction structure in the right form.
1060  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1061  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1062  AffineExpr m;
1063  AffineExpr n;
1064  AffineExpr k;
1065  bindDims(rewriter.getContext(), m, n, k);
1066  static constexpr std::array<int64_t, 2> perm = {1, 0};
1067  auto iteratorTypes = op.getIteratorTypes().getValue();
1068  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1069  if (iteratorTypes.size() != 3 ||
1070  !vector::isParallelIterator(iteratorTypes[0]) ||
1071  !vector::isParallelIterator(iteratorTypes[1]) ||
1072  !vector::isReductionIterator(iteratorTypes[2]))
1073  return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1074 
1075  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1076  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1077  if (maps == canonicalForm)
1078  return rewriter.notifyMatchFailure(op, "already in the canonical form");
1079 
1080  // Create a vector transpose making sure to emit zero/sign-extend at the
1081  // end.
1082  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1083  if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1084  Value trans =
1085  rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1086  return rewriter.create<arith::ExtSIOp>(loc, mat.getType(), trans);
1087  }
1088  if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1089  Value trans =
1090  rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1091  return rewriter.create<arith::ExtUIOp>(loc, mat.getType(), trans);
1092  }
1093  return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1094  };
1095 
1096  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1097  rhs = createTranspose(rhs);
1098  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1099  lhs = createTranspose(lhs);
1100  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1101  rhs = createTranspose(rhs);
1102  lhs = createTranspose(lhs);
1103  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1104  std::swap(rhs, lhs);
1105  rhs = createTranspose(rhs);
1106  lhs = createTranspose(lhs);
1107  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1108  std::swap(rhs, lhs);
1109  rhs = createTranspose(rhs);
1110  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1111  std::swap(lhs, rhs);
1112  lhs = createTranspose(lhs);
1113  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1114  std::swap(lhs, rhs);
1115  } else {
1116  return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1117  }
1118  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1119  op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1120  op.getIteratorTypes());
1121  return success();
1122  };
1123 
1124 private:
1125  FilterConstraintType filter;
1126 };
1127 
1128 } // namespace
1129 
1130 void mlir::vector::populateVectorMaskMaterializationPatterns(
1131     RewritePatternSet &patterns, bool force32BitVectorIndices,
1132  PatternBenefit benefit) {
1133  patterns.add<VectorCreateMaskOpConversion,
1134  MaterializeTransferMask<vector::TransferReadOp>,
1135  MaterializeTransferMask<vector::TransferWriteOp>>(
1136  patterns.getContext(), force32BitVectorIndices, benefit);
1137 }
1138 
1139 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1140                                                     PatternBenefit benefit) {
1141  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
1142 }
1143 
1144 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
1145     RewritePatternSet &patterns, PatternBenefit benefit) {
1146  patterns.add<BubbleDownVectorBitCastForExtract,
1147  BubbleDownBitCastForStridedSliceExtract,
1148  BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
1149  benefit);
1150 }
1151 
1152 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
1153     RewritePatternSet &patterns,
1154  std::function<LogicalResult(vector::ContractionOp)> constraint,
1155  PatternBenefit benefit) {
1156  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
1157  std::move(constraint));
1158 }
1159 
1160 void mlir::vector::populateVectorReductionToContractPatterns(
1161     RewritePatternSet &patterns, PatternBenefit benefit) {
1162  patterns.add<MultiReduceToContract, CombineContractBroadcast,
1163  CombineContractABTranspose, CombineContractResultTranspose,
1164  ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
1165  patterns.getContext(), benefit);
1166 }
1167 
1168 void mlir::vector::
1169     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
1170         RewritePatternSet &patterns, PatternBenefit benefit) {
1171  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
1172 }
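
// A typical way these pattern-population entry points are used from a pass
// (an illustrative sketch, not part of this file):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorMaskMaterializationPatterns(
//       patterns, /*force32BitVectorIndices=*/true);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));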
1173 
1174 //===----------------------------------------------------------------------===//
1175 // TableGen'd enum attribute definitions
1176 //===----------------------------------------------------------------------===//
1177 
1178 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"