1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cassert>
16 #include <cstdint>
17 #include <functional>
18 #include <optional>
19 #include <type_traits>
20 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
34 #include "mlir/IR/BuiltinTypes.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
42 
43 #include "llvm/ADT/DenseSet.h"
44 #include "llvm/ADT/MapVector.h"
45 #include "llvm/ADT/STLExtras.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/FormatVariadic.h"
49 #include "llvm/Support/raw_ostream.h"
50 
51 #define DEBUG_TYPE "vector-to-vector"
52 
53 using namespace mlir;
54 using namespace mlir::vector;
55 
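// Helper that copies the integer values of `arrayAttr` into a vector of the
// requested integer type, e.g. an ArrayAttr holding [0, 2, 4] becomes
// SmallVector<int64_t>{0, 2, 4}.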
56 template <typename IntType>
57 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
58  return llvm::to_vector<4>(llvm::map_range(
59  arrayAttr.getAsRange<IntegerAttr>(),
60  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
61 }
62 
63 // Helper that returns the result position in `map` whose dim position equals `index`.
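// For example, for the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0)
// returns 1 (d0 is the second result) and getResultIndex(map, 1) returns
// std::nullopt.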
64 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
65  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
66  int64_t idx = map.getDimPosition(i);
67  if (idx == index)
68  return i;
69  }
70  return std::nullopt;
71 }
72 
73 namespace {
74 
75 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
76 //
77 // Example:
78 //
79 // The following MLIR with cancelling ShapeCastOps:
80 //
81 // %0 = source : vector<5x4x2xf32>
82 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
83 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
84 // %3 = user %2 : vector<5x4x2xf32>
85 //
86 // Should canonicalize to the following:
87 //
88 // %0 = source : vector<5x4x2xf32>
89 // %1 = user %0 : vector<5x4x2xf32>
90 //
91 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
92  using OpRewritePattern::OpRewritePattern;
93 
94  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
95  PatternRewriter &rewriter) const override {
96  // Check if 'shapeCastOp' has vector source/result type.
97  auto sourceVectorType =
98  dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
99  auto resultVectorType =
100  dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
101  if (!sourceVectorType || !resultVectorType)
102  return failure();
103 
104  // Check if shape cast op source operand is also a shape cast op.
105  auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
106  shapeCastOp.getSource().getDefiningOp());
107  if (!sourceShapeCastOp)
108  return failure();
109  auto operandSourceVectorType =
110  cast<VectorType>(sourceShapeCastOp.getSource().getType());
111  auto operandResultVectorType = sourceShapeCastOp.getType();
112 
113  // Check if shape cast operations invert each other.
114  if (operandSourceVectorType != resultVectorType ||
115  operandResultVectorType != sourceVectorType)
116  return failure();
117 
118  rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
119  return success();
120  }
121 };
122 
123 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
124 /// Ex:
125 /// ```
126 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
127 /// %1 = vector.multi_reduction add, %0 [1]
128 /// : vector<8x32x16xf32> to vector<8x16xf32>
129 /// ```
130 /// Gets converted to:
131 /// ```
132 /// %1 = vector.contract {indexing_maps = [
133 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
135 /// affine_map<(d0, d1, d2) -> (d0, d2)>],
136 /// iterator_types = ["parallel", "reduction", "parallel"],
137 /// kind = add} %arg0, %arg1, %cst_f0
138 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
139 /// ```
140 struct MultiReduceToContract
141  : public OpRewritePattern<vector::MultiDimReductionOp> {
142  using OpRewritePattern::OpRewritePattern;
143 
144  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
145  PatternRewriter &rewriter) const override {
146  if (reduceOp.getKind() != vector::CombiningKind::ADD)
147  return failure();
148  Operation *mulOp = reduceOp.getSource().getDefiningOp();
149  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
150  return failure();
151  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
152  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
153  SmallVector<AffineExpr> exprs;
154  SmallVector<vector::IteratorType> iteratorTypes;
155  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
156  if (!isReduceDim.value()) {
157  iteratorTypes.push_back(vector::IteratorType::parallel);
158  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
159  } else {
160  iteratorTypes.push_back(vector::IteratorType::reduction);
161  }
162  }
163  auto dstMap =
164  AffineMap::get(/*dimCount=*/reductionMask.size(),
165  /*symbolCount=*/0, exprs, reduceOp.getContext());
166  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
167  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
168  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
169  rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
170  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
171  return IteratorTypeAttr::get(rewriter.getContext(), t);
172  }))));
173  return success();
174  }
175 };
176 
177 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
178 /// Ex:
179 /// ```
180 /// %0 = vector.transpose %arg0, [2, 0, 1]
181 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
182 /// %1 = vector.contract {indexing_maps = [
183 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
184 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
185 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
186 /// iterator_types = ["parallel", "parallel", "reduction"],
187 /// kind = add} %0, %arg1, %cst_f0
188 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
189 /// ```
190 /// Gets converted to:
191 /// ```
192 /// %1 = vector.contract {indexing_maps = [
193 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
194 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
195 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
196 /// iterator_types = ["parallel", "parallel", "reduction"],
197 /// kind = add} %arg0, %arg1, %cst_f0
198 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
199 /// ```
200 struct CombineContractABTranspose final
201  : public OpRewritePattern<vector::ContractionOp> {
202  using OpRewritePattern::OpRewritePattern;
203 
204  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
205  PatternRewriter &rewriter) const override {
206  SmallVector<AffineMap> maps =
207  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
208  Value lhs = contractOp.getLhs();
209  Value rhs = contractOp.getRhs();
210  size_t index = 0;
211  bool changed = false;
212  for (Value *operand : {&lhs, &rhs}) {
213  AffineMap &map = maps[index++];
214  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
215  if (!transposeOp)
216  continue;
217  AffineMap permutationMap = AffineMap::getPermutationMap(
218  transposeOp.getPermutation(), contractOp.getContext());
219  map = inversePermutation(permutationMap).compose(map);
220  *operand = transposeOp.getVector();
221  changed = true;
222  }
223  if (!changed)
224  return failure();
225  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
226  contractOp, lhs, rhs, contractOp.getAcc(),
227  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
228  return success();
229  }
230 };
231 
232 /// Merges accumulator and result transposes into contract.
233 ///
234 /// For example:
235 /// ```mlir
236 /// %accT = vector.transpose %acc, [0, 2, 1]
237 /// : vector<2x8x4xf32> to vector<2x4x8xf32>
238 /// %contract = vector.contract {
239 /// indexing_maps = [
240 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
241 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
242 /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
243 /// ],
244 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
245 /// kind = #vector.kind<add>
246 /// } %lhs, %rhs, %accT
247 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
248 /// %0 = vector.transpose %contract, [0, 2, 1]
249 /// : vector<2x4x8xf32> to vector<2x8x4xf32>
250 /// ```
251 /// Becomes:
252 /// ```mlir
253 /// %0 = vector.contract {
254 /// indexing_maps = [
255 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
256 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
257 /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
258 /// ],
259 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
260 /// kind = #vector.kind<add>
261 /// } %lhs, %rhs, %acc
262 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
263 /// ```
264 struct CombineContractResultTranspose final
265  : public OpRewritePattern<vector::TransposeOp> {
266  using OpRewritePattern::OpRewritePattern;
267 
268  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
269  PatternRewriter &rewriter) const override {
270  auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
271  if (!contractOp || !contractOp->hasOneUse())
272  return failure();
273 
274  auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
275  if (!accTOp)
276  return failure();
277 
278  MLIRContext *context = contractOp.getContext();
279  auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
280  AffineMap contractMap = maps.back();
281 
282  // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
283  // To index into A in contract, we need revert(f)(g(C)) -> A.
284  auto accTMap =
285  AffineMap::getPermutationMap(accTOp.getPermutation(), context);
286 
287  // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
288  // To index into E in contract, we need h(g(C)) -> E.
289  auto resTMap =
290  AffineMap::getPermutationMap(resTOp.getPermutation(), context);
291  auto combinedResMap = resTMap.compose(contractMap);
292 
293  // The accumulator and the result share the same indexing map, so the two
294  // transposes must cancel out for the merge to be valid: combinedResMap must
295  // equal inversePermutation(accTMap).compose(contractMap), i.e. resTMap == inversePermutation(accTMap).
296  if (inversePermutation(accTMap) != resTMap)
297  return failure();
298  maps.back() = combinedResMap;
299 
300  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
301  resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
302  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
303  return success();
304  }
305 };
306 
307 /// Merge BroadcastOp into ContractionOp user.
308 /// Ex:
309 /// ```
310 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
311 /// %1 = vector.contract {indexing_maps = [
312 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
313 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
314 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
315 /// iterator_types = ["parallel", "parallel", "reduction"],
316 /// kind = add} %0, %arg1, %cst_f0
317 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
318 /// ```
319 /// Gets converted to:
320 /// ```
321 /// %1 = vector.contract {indexing_maps = [
322 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
323 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
324 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
325 /// iterator_types = ["parallel", "parallel", "reduction"],
326 /// kind = add} %arg0, %arg1, %cst_f0
327 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
328 /// ```
329 struct CombineContractBroadcast
330  : public OpRewritePattern<vector::ContractionOp> {
331  using OpRewritePattern::OpRewritePattern;
332 
333  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
334  PatternRewriter &rewriter) const override {
335  SmallVector<AffineMap> maps =
336  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
337  Value lhs = contractOp.getLhs();
338  Value rhs = contractOp.getRhs();
339  size_t index = 0;
340  bool changed = false;
341  for (Value *operand : {&lhs, &rhs}) {
342  AffineMap &map = maps[index++];
343  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
344  if (!broadcast)
345  continue;
346  // ContractionOp can only take vectors as operands.
347  auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
348  if (!srcType ||
349  srcType.getRank() == broadcast.getResultVectorType().getRank())
350  continue;
351  int64_t rankDiff =
352  broadcast.getResultVectorType().getRank() - srcType.getRank();
353  bool innerDimBroadcast = false;
354  SmallVector<AffineExpr> originalDims;
355  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
356  if (dim.value() != broadcast.getResultVectorType().getDimSize(
357  rankDiff + dim.index())) {
358  innerDimBroadcast = true;
359  break;
360  }
361  originalDims.push_back(
362  rewriter.getAffineDimExpr(dim.index() + rankDiff));
363  }
364  // Contract doesn't support inner dimension broadcast. Once this is
365  // relaxed we can remove this case.
366  if (innerDimBroadcast)
367  continue;
368 
369  // It would be incorrect to fold a broadcast onto a reduction dimension
370  // of non-unit size.
371  bool nonUnitDimReductionBroadcast = false;
372  for (int64_t i = 0; i < rankDiff; ++i) {
373  if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
374  isReductionIterator(contractOp.getIteratorTypes()
375  .getValue()[map.getDimPosition(i)])) {
376  nonUnitDimReductionBroadcast = true;
377  break;
378  }
379  }
380  if (nonUnitDimReductionBroadcast)
381  continue;
382 
383  AffineMap broadcastMap =
384  AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
385  originalDims, contractOp.getContext());
386  map = broadcastMap.compose(map);
387  *operand = broadcast.getSource();
388  changed = true;
389  }
390 
391  if (!changed)
392  return failure();
393 
394  // Determine which dims are unused, now that the maps have been composed
395  // with the broadcast maps.
396  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
397  // Compress unused dims.
398  for (auto &m : maps)
399  m = compressDims(m, unusedDimsBitVector);
400  // Compute the combined iterators.
401  SmallVector<Attribute> iterators;
402  for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
403  if (!unusedDimsBitVector.test(i))
404  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
405  }
406  // Check that compressing unused dims isn't removing all reduction dimension
407  // pairs. For example, if the vector.contract had only one reduction
408  // iterator and that was a unit-dimension created by a broadcast,
409  // then we should bail here, otherwise we would create a contract without
410  // a reduction dimension pair.
411  bool hasReductionIteratorApplyingOnBothSides = false;
412  for (unsigned i = 0; i < iterators.size(); ++i) {
413  if (!isReductionIterator(iterators[i]))
414  continue;
415  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
416  hasReductionIteratorApplyingOnBothSides = true;
417  break;
418  }
419  }
420  if (!hasReductionIteratorApplyingOnBothSides)
421  return failure();
422 
423  // If the compressed maps have a dimension that is not used by either LHS or
424  // RHS then the ContractionOp verifier would fail.
425  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
426  return failure();
427  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
428  contractOp, lhs, rhs, contractOp.getAcc(),
429  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
430  return success();
431  }
432 };
433 
434 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
435 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
436 /// casting ops are around these operations.
437 /// Ex:
438 /// ```
439 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
440 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
441 /// ```
442 /// Gets converted to:
443 /// ```
444 /// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
445 /// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
446 /// ```
447 struct ReorderCastOpsOnBroadcast
448  : public OpInterfaceRewritePattern<CastOpInterface> {
449  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
450 
451  LogicalResult matchAndRewrite(CastOpInterface op,
452  PatternRewriter &rewriter) const override {
453  if (op->getNumOperands() != 1)
454  return failure();
455  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
456  if (!bcastOp)
457  return failure();
458 
459  Type castResTy = getElementTypeOrSelf(op->getResult(0));
460  if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
461  castResTy = vecTy.clone(castResTy);
462  auto *castOp =
463  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
464  bcastOp.getSource(), castResTy, op->getAttrs());
465  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
466  op, op->getResult(0).getType(), castOp->getResult(0));
467  return success();
468  }
469 };
470 
471 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
472 /// transpose ops and contraction ops closer, which kicks in
473 /// CombineContractABTranspose pattern when elementwise ops are between these
474 /// operations. Ex:
475 /// ```
476 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
477 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
478 /// %r = arith.addf %at, %bt : vector<2x4xf32>
479 /// ```
480 /// Gets converted to:
481 /// ```
482 /// %0 = arith.addf %a, %b : vector<4x2xf32>
483 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
484 /// ```
485 struct ReorderElementwiseOpsOnTranspose final
486  : public OpTraitRewritePattern<OpTrait::Elementwise> {
487  using OpTraitRewritePattern::OpTraitRewritePattern;
488  LogicalResult matchAndRewrite(Operation *op,
489  PatternRewriter &rewriter) const override {
490  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
491  return failure();
492 
493  // Make sure all operands are transpose/constant ops and collect their
494  // transposition maps.
495  SmallVector<ArrayRef<int64_t>> transposeMaps;
496  transposeMaps.reserve(op->getNumOperands());
497  // Record the initial type before transposition. We'll use its shape later.
498  // Any type will do here as we will check all transpose maps are the same.
499  VectorType srcType;
500  for (Value operand : op->getOperands()) {
501  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
502  if (transposeOp) {
503  transposeMaps.push_back(transposeOp.getPermutation());
504  srcType = transposeOp.getSourceVectorType();
505  } else if (!matchPattern(operand, m_Constant())) {
506  return failure();
507  }
508  }
509  if (transposeMaps.empty())
510  return failure();
511  // This is an elementwise op, so all transposed operands should have the
512 // same type. We need to additionally check that all transposes use the
513  // same map.
514  if (!llvm::all_equal(transposeMaps))
515  return rewriter.notifyMatchFailure(op, "different transpose map");
516 
517  SmallVector<Value> srcValues;
518  srcValues.reserve(op->getNumOperands());
519 
520  // If there are constant operands, we need to insert inverse transposes for
521  // them. Calculate the inverse order first.
522  auto order = transposeMaps.front();
523  SmallVector<int64_t> invOrder(order.size());
524  for (int i = 0, e = order.size(); i < e; ++i)
525  invOrder[order[i]] = i;
526 
527  for (Value operand : op->getOperands()) {
528  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
529  if (transposeOp) {
530  srcValues.push_back(transposeOp.getVector());
531  } else {
532  // This is a constant. Create a reverse transpose op for it.
533  auto vectorType =
534  srcType.clone(cast<VectorType>(operand.getType()).getElementType());
535  srcValues.push_back(rewriter.create<vector::TransposeOp>(
536  operand.getLoc(), vectorType, operand, invOrder));
537  }
538  }
539 
540  auto vectorType = srcType.clone(
541  cast<VectorType>(op->getResultTypes()[0]).getElementType());
542  Operation *elementwiseOp =
543  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
544  vectorType, op->getAttrs());
545  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
546  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
547  transposeMaps.front());
548  return success();
549  }
550 };
551 
552 // Returns the values in `arrayAttr` as an integer vector.
553 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
554  return llvm::to_vector<4>(
555  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
556  [](IntegerAttr attr) { return attr.getInt(); }));
557 }
558 
559 // Shuffles vector.bitcast op after vector.extract op.
560 //
561 // This transforms IR like:
562 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
563 // %1 = vector.extract %0[3] : f16 from vector<8xf16>
564 // Into:
565 // %0 = vector.extract %src[1] : f32 from vector<4xf32>
566 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
567 // %2 = vector.extract %1[1] : f16 from vector<2xf16>
568 struct BubbleDownVectorBitCastForExtract
569  : public OpRewritePattern<vector::ExtractOp> {
570  using OpRewritePattern::OpRewritePattern;
571 
572  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
573  PatternRewriter &rewriter) const override {
574  // Only support extracting scalars for now.
575  if (extractOp.getSourceVectorType().getRank() != 1)
576  return failure();
577 
578  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
579  if (!castOp)
580  return failure();
581 
582  VectorType castSrcType = castOp.getSourceVectorType();
583  VectorType castDstType = castOp.getResultVectorType();
584  assert(castSrcType.getRank() == castDstType.getRank());
585 
586  // Fail to match if we only have one element in the cast op source.
587  // This is to avoid infinite loop given that this pattern can generate
588  // such cases.
589  if (castSrcType.getNumElements() == 1)
590  return failure();
591 
592  // Only support casting to a larger number of elements for now.
593  // E.g., vector<4xf32> -> vector<8xf16>.
594  if (castSrcType.getNumElements() > castDstType.getNumElements())
595  return failure();
596 
597  unsigned expandRatio =
598  castDstType.getNumElements() / castSrcType.getNumElements();
599 
600  auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
601  assert(values[0].is<Attribute>() && "Unexpected non-constant index");
602  return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
603  };
604 
605  uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
606 
607  // Get the single scalar (as a vector) in the source value that packs the
608  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
609  Location loc = extractOp.getLoc();
610  Value packedValue = rewriter.create<vector::ExtractOp>(
611  loc, castOp.getSource(), index / expandRatio);
612  Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
613  Value zero = rewriter.create<arith::ConstantOp>(
614  loc, packedVecType, rewriter.getZeroAttr(packedVecType));
615  packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
616  /*position=*/0);
617 
618  // Cast it to a vector with the desired scalar's type.
619  // E.g. f32 -> vector<2xf16>
620  VectorType packedType =
621  VectorType::get({expandRatio}, castDstType.getElementType());
622  Value castedValue =
623  rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
624 
625  // Finally extract the desired scalar.
626  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
627  index % expandRatio);
628  return success();
629  }
630 };
631 
632 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
633 //
634 // This transforms IR like:
635 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
636 // %0 = vector.extract_strided_slice %cast {
637 // offsets = [4], sizes = [4], strides = [1]
638 // } : vector<8xf16> to vector<4xf16>
639 // Into:
640 // %0 = vector.extract_strided_slice %src {
641 // offsets = [2], sizes = [2], strides = [1]
642 // } : vector<4xf32> to vector<2xf32>
643 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
644 struct BubbleDownBitCastForStridedSliceExtract
645  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
646  using OpRewritePattern::OpRewritePattern;
647 
648  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
649  PatternRewriter &rewriter) const override {
650  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
651  if (!castOp)
652  return failure();
653 
654  VectorType castSrcType = castOp.getSourceVectorType();
655  VectorType castDstType = castOp.getResultVectorType();
656  assert(castSrcType.getRank() == castDstType.getRank());
657 
658  int64_t castSrcLastDim = castSrcType.getShape().back();
659  int64_t castDstLastDim = castDstType.getShape().back();
660  // Require casting to more elements for now; other cases to be implemented.
661  if (castSrcLastDim > castDstLastDim)
662  return failure();
663 
664  // Only accept all one strides for now.
665  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
666  [](const APInt &val) { return !val.isOne(); }))
667  return failure();
668 
669  unsigned rank = extractOp.getSourceVectorType().getRank();
670  assert(castDstLastDim % castSrcLastDim == 0);
671  int64_t expandRatio = castDstLastDim / castSrcLastDim;
672 
673  // If we have fewer offsets than the rank, then implicitly we are selecting
674  // the full range for the last bitcasted dimension; other dimensions aren't
675  // affected. Otherwise, we need to scale down the last dimension's offset
676  // given we are extracting from fewer elements now.
677  ArrayAttr newOffsets = extractOp.getOffsets();
678  if (newOffsets.size() == rank) {
679  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
680  if (offsets.back() % expandRatio != 0)
681  return failure();
682  offsets.back() = offsets.back() / expandRatio;
683  newOffsets = rewriter.getI64ArrayAttr(offsets);
684  }
685 
686  // Similarly for sizes.
687  ArrayAttr newSizes = extractOp.getSizes();
688  if (newSizes.size() == rank) {
689  SmallVector<int64_t> sizes = getIntValueVector(newSizes);
690  if (sizes.back() % expandRatio != 0)
691  return failure();
692  sizes.back() = sizes.back() / expandRatio;
693  newSizes = rewriter.getI64ArrayAttr(sizes);
694  }
695 
696  SmallVector<int64_t> dims =
697  llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
698  dims.back() = dims.back() / expandRatio;
699  VectorType newExtractType =
700  VectorType::get(dims, castSrcType.getElementType());
701 
702  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
703  extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
704  newSizes, extractOp.getStrides());
705 
706  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
707  extractOp, extractOp.getType(), newExtractOp);
708 
709  return success();
710  }
711 };
712 
713 // Shuffles vector.bitcast op before vector.insert op.
714 //
715 // This transforms IR like:
716 // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
717 // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
718 // Into:
719 // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
720 // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
721 // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
722 //
723 struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
724  using OpRewritePattern::OpRewritePattern;
725 
726  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
727  PatternRewriter &rewriter) const override {
728  VectorType castSrcType = bitcastOp.getSourceVectorType();
729  VectorType castDstType = bitcastOp.getResultVectorType();
730 
731  // 0-D and scalable vectors are not supported yet.
732  if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
733  castDstType.isScalable())
734  return failure();
735 
736  int64_t castSrcLastDim = castSrcType.getShape().back();
737  int64_t castDstLastDim = castDstType.getShape().back();
738  bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
739  int64_t ratio;
740  if (isNumElemsShrink) {
741  assert(castSrcLastDim % castDstLastDim == 0);
742  ratio = castSrcLastDim / castDstLastDim;
743  } else {
744  assert(castDstLastDim % castSrcLastDim == 0);
745  ratio = castDstLastDim / castSrcLastDim;
746  }
747 
748  auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
749  if (!insertOp)
750  return failure();
751 
752  // Only vector sources are supported for now.
753  auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
754  if (!insertSrcType)
755  return failure();
756 
757  // Bitcast the source.
758  SmallVector<int64_t> srcDims(insertSrcType.getShape());
759  srcDims.back() =
760  isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
761  VectorType newCastSrcType =
762  VectorType::get(srcDims, castDstType.getElementType());
763  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
764  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
765 
766  SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
767  dstDims.back() =
768  isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
769  VectorType newCastDstType =
770  VectorType::get(dstDims, castDstType.getElementType());
771 
772  // Bitcast the destination.
773  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
774  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
775 
776  // Generate new insert.
777  rewriter.replaceOpWithNewOp<vector::InsertOp>(
778  bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
779  return success();
780  }
781 };
782 
783 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
784 //
785 // This transforms IR like:
786 // %0 = vector.insert_strided_slice %src, %dst {
787 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
788 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
789 // Into:
790 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
791 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
792 // %2 = vector.insert_strided_slice %0, %1 {
793 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
794 struct BubbleUpBitCastForStridedSliceInsert
795  : public OpRewritePattern<vector::BitCastOp> {
796  using OpRewritePattern::OpRewritePattern;
797 
798  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
799  PatternRewriter &rewriter) const override {
800  VectorType castSrcType = bitcastOp.getSourceVectorType();
801  VectorType castDstType = bitcastOp.getResultVectorType();
802  assert(castSrcType.getRank() == castDstType.getRank());
803  // Skip 0-D vectors, which cannot come from an InsertStridedSliceOp.
804  if (castSrcType.getRank() == 0)
805  return failure();
806 
807  int64_t castSrcLastDim = castSrcType.getShape().back();
808  int64_t castDstLastDim = castDstType.getShape().back();
809  // Require casting to fewer elements for now; other cases to be implemented.
810  if (castSrcLastDim < castDstLastDim)
811  return failure();
812 
813  assert(castSrcLastDim % castDstLastDim == 0);
814  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
815 
816  auto insertOp =
817  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
818  if (!insertOp)
819  return failure();
820 
821  // Only accept all one strides for now.
822  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
823  [](const APInt &val) { return !val.isOne(); }))
824  return failure();
825 
826  unsigned rank = insertOp.getSourceVectorType().getRank();
827  // Require insert op to have the same rank for the source and destination
828  // vector; other cases to be implemented.
829  if (rank != insertOp.getDestVectorType().getRank())
830  return failure();
831 
832  // Require that the insert op source's element count is a multiple of the
// bitwidth ratio, so that its shape is castable to the destination element type.
833  unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
834  unsigned destinationWidth =
835  castDstType.getElementType().getIntOrFloatBitWidth();
836  unsigned numElements = destinationWidth / sourceWidth;
837  if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
838  return failure();
839 
840  ArrayAttr newOffsets = insertOp.getOffsets();
841  assert(newOffsets.size() == rank);
842  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
843  if (offsets.back() % shrinkRatio != 0)
844  return failure();
845  offsets.back() = offsets.back() / shrinkRatio;
846  newOffsets = rewriter.getI64ArrayAttr(offsets);
847 
848  SmallVector<int64_t> srcDims =
849  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
850  srcDims.back() = srcDims.back() / shrinkRatio;
851  VectorType newCastSrcType =
852  VectorType::get(srcDims, castDstType.getElementType());
853 
854  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
855  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
856 
857  SmallVector<int64_t> dstDims =
858  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
859  dstDims.back() = dstDims.back() / shrinkRatio;
860  VectorType newCastDstType =
861  VectorType::get(dstDims, castDstType.getElementType());
862 
863  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
864  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
865 
866  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
867  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
868  insertOp.getStrides());
869 
870  return success();
871  }
872 };
873 
874 // Breaks down vector.bitcast op
875 //
876 // This transforms IR like:
877 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
878 // Into:
879 // %cst = vector.splat %c0_f32 : vector<4xf32>
880 // %1 = vector.extract_strided_slice %0 {
881 // offsets = [0], sizes = [4], strides = [1]
882 // } : vector<8xf16> to vector<4xf16>
883 // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
884 // %4 = vector.insert_strided_slice %2, %cst {
885 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
886 // %5 = vector.extract_strided_slice %0 {
887 // offsets = [4], sizes = [4], strides = [1]
888 // } : vector<8xf16> to vector<4xf16>
889 // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
890 // %7 = vector.insert_strided_slice %6, %cst {
891 // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
892 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
893  using OpRewritePattern::OpRewritePattern;
894 
895 public:
896  BreakDownVectorBitCast(MLIRContext *context,
897  std::function<bool(vector::BitCastOp)> controlFn,
898  PatternBenefit benefit)
899  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
900 
901  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
902  PatternRewriter &rewriter) const override {
903 
904  if (controlFn && !controlFn(bitcastOp))
905  return failure();
906 
907  VectorType castSrcType = bitcastOp.getSourceVectorType();
908  VectorType castDstType = bitcastOp.getResultVectorType();
909  assert(castSrcType.getRank() == castDstType.getRank());
910 
911  // Only support rank 1 case for now.
912  if (castSrcType.getRank() != 1)
913  return failure();
914 
915  int64_t castSrcLastDim = castSrcType.getShape().back();
916  int64_t castDstLastDim = castDstType.getShape().back();
917  // Require casting to fewer elements for now; other cases to be implemented.
918  if (castSrcLastDim < castDstLastDim)
919  return failure();
920 
921  assert(castSrcLastDim % castDstLastDim == 0);
922  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
923  // Nothing to do if it is already bitcasting to a single element.
924  if (castSrcLastDim == shrinkRatio)
925  return failure();
926 
927  Location loc = bitcastOp.getLoc();
928  Type elemType = castDstType.getElementType();
929  assert(elemType.isSignlessIntOrIndexOrFloat());
930 
931  Value zero = rewriter.create<arith::ConstantOp>(
932  loc, elemType, rewriter.getZeroAttr(elemType));
933  Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
934 
935  SmallVector<int64_t> sliceShape{castDstLastDim};
936  SmallVector<int64_t> strides{1};
937  VectorType newCastDstType =
938  VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
939  castDstType.getElementType());
940 
941  for (int i = 0, e = shrinkRatio; i < e; ++i) {
942  Value extracted = rewriter.create<ExtractStridedSliceOp>(
943  loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
944  sliceShape, strides);
945  Value bitcast =
946  rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
947  res = rewriter.create<InsertStridedSliceOp>(
948  loc, bitcast, res,
949  ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
950  }
951  rewriter.replaceOp(bitcastOp, res);
952  return success();
953  }
954 
955 private:
956  std::function<bool(BitCastOp)> controlFn;
957 };
958 
959 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
960 /// ```
961 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
962 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
963 /// %r = arith.addi %a, %b : vector<1x4xindex>
964 /// ```
965 /// Gets converted to:
966 /// ```
967 /// %r = arith.addi %arg1, %arg2 : index
968 /// %b = vector.broadcast %r : index to vector<1x4xindex>
969 /// ```
970 ///
971 /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
972 /// ops.
973 struct ReorderElementwiseOpsOnBroadcast final
974  : public OpTraitRewritePattern<OpTrait::Elementwise> {
975  using OpTraitRewritePattern::OpTraitRewritePattern;
976  LogicalResult matchAndRewrite(Operation *op,
977  PatternRewriter &rewriter) const override {
978  if (op->getNumResults() != 1)
979  return failure();
980  if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
981  return failure();
982  if (!OpTrait::hasElementwiseMappableTraits(op))
983  return failure();
984  if (op->getNumOperands() == 0 ||
985  op->getResults()[0].getType() != op->getOperand(0).getType()) {
986  return failure();
987  }
988  // Avoid operations that only accept vector types, since broadcast
989  // source might be scalar types.
990  if (isa<vector::FMAOp>(op)) {
991  return failure();
992  }
993 
994  // Get the type of the lhs operand
995  auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
996  if (!lhsBcastOrSplat ||
997  !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
998  return failure();
999  auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1000 
1001  // Make sure that all operands are broadcast from identical types:
1002  // * scalar (`vector.broadcast` + `vector.splat`), or
1003  // * vector (`vector.broadcast`).
1004  // Otherwise the re-ordering wouldn't be safe.
1005  if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1006  auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1007  if (bcast)
1008  return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1009  auto splat = val.getDefiningOp<vector::SplatOp>();
1010  if (splat)
1011  return (splat.getOperand().getType() == lhsBcastOrSplatType);
1012  return false;
1013  })) {
1014  return failure();
1015  }
1016 
1017  // Collect the source values before broadcasting
1018  SmallVector<Value> srcValues;
1019  srcValues.reserve(op->getNumOperands());
1020  for (Value operand : op->getOperands()) {
1021  srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1022  }
1023 
1024  // Create the "elementwise" Op
1025  Operation *elementwiseOp =
1026  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1027  lhsBcastOrSplatType, op->getAttrs());
1028 
1029  // Replace the original Op with the elementwise Op
1030  auto vectorType = op->getResultTypes()[0];
1031  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1032  op, vectorType, elementwiseOp->getResults());
1033 
1034  return success();
1035  }
1036 };
1037 
1038 // Helper that returns a vector comparison that constructs a mask:
1039 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1040 //
1041 // If `dim == 0` then the result will be a 0-D vector.
1042 //
1043 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1044 // much more compact, IR for this operation, but LLVM eventually
1045 // generates more elaborate instructions for this intrinsic since it
1046 // is very conservative on the boundary conditions.
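//
// For instance, for `dim == 4` with 32-bit indices, an offset `%off` and a
// bound `%b` (both cast to i32 first), the generated IR is roughly the
// following (illustrative sketch, SSA names invented):
//   %idx     = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %offs    = vector.splat %off : vector<4xi32>
//   %shifted = arith.addi %offs, %idx : vector<4xi32>
//   %bounds  = vector.splat %b : vector<4xi32>
//   %mask    = arith.cmpi slt, %shifted, %bounds : vector<4xi32>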
1047 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1048  bool force32BitVectorIndices, int64_t dim,
1049  Value b, Value *off = nullptr) {
1050  auto loc = op->getLoc();
1051  // If we can assume all indices fit in 32-bit, we perform the vector
1052  // comparison in 32-bit to get a higher degree of SIMD parallelism.
1053  // Otherwise we perform the vector comparison using 64-bit indices.
1054  Type idxType =
1055  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1056  DenseIntElementsAttr indicesAttr;
1057  if (dim == 0 && force32BitVectorIndices) {
1058  indicesAttr = DenseIntElementsAttr::get(
1059  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1060  } else if (dim == 0) {
1061  indicesAttr = DenseIntElementsAttr::get(
1062  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1063  } else if (force32BitVectorIndices) {
1064  indicesAttr = rewriter.getI32VectorAttr(
1065  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1066  } else {
1067  indicesAttr = rewriter.getI64VectorAttr(
1068  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1069  }
1070  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
1071  // Add in an offset if requested.
1072  if (off) {
1073  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1074  Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1075  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
1076  }
1077  // Construct the vector comparison.
1078  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1079  Value bounds =
1080  rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1081  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1082  bounds);
1083 }
1084 
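/// Materializes an explicit mask for a rank-1 vector transfer op that has an
/// out-of-bounds dimension. Roughly (illustrative sketch, not verbatim output
/// of the pattern):
/// ```
/// %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes
/// ```
/// %d = memref.dim %src, %c0 : memref<?xf32>
/// %n = arith.subi %d, %i : index
/// %m = vector.create_mask %n : vector<8xi1>
/// %v = vector.transfer_read %src[%i], %pad, %m {in_bounds = [true]}
///        : memref<?xf32>, vector<8xf32>
/// ```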
1085 template <typename ConcreteOp>
1086 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1087 public:
1088  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1089  PatternBenefit benefit = 1)
1090  : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1091  force32BitVectorIndices(enableIndexOpt) {}
1092 
1093  LogicalResult matchAndRewrite(ConcreteOp xferOp,
1094  PatternRewriter &rewriter) const override {
1095  if (!xferOp.hasOutOfBoundsDim())
1096  return failure();
1097 
1098  if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1099  return failure();
1100 
1101  Location loc = xferOp->getLoc();
1102  VectorType vtp = xferOp.getVectorType();
1103 
1104  // Create the in-bounds mask with all elements between [0 .. dim - offset)
1105  // set and [dim - offset .. vector_length) unset.
1106  //
1107  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1108  // dimensions here.
1109  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1110  Value off = xferOp.getIndices()[lastIndex];
1111  Value dim =
1112  vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1113  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1114  Value mask = rewriter.create<vector::CreateMaskOp>(
1115  loc,
1116  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1117  vtp.getScalableDims()),
1118  b);
1119  if (xferOp.getMask()) {
1120  // Intersect the in-bounds with the mask specified as an op parameter.
1121  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
1122  }
1123 
1124  rewriter.modifyOpInPlace(xferOp, [&]() {
1125  xferOp.getMaskMutable().assign(mask);
1126  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1127  });
1128 
1129  return success();
1130  }
1131 
1132 private:
1133  const bool force32BitVectorIndices;
1134 };
1135 
1136 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
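///
/// For a fixed-size 1-D mask, the lowering is sketched by (illustrative,
/// assuming 64-bit indices):
/// ```
/// %m = vector.create_mask %b : vector<4xi1>
/// ```
/// becoming a compare of a constant index vector against the splatted bound:
/// ```
/// %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi64>
/// %bnd = vector.splat %b_cast : vector<4xi64>   // %b cast to i64
/// %m   = arith.cmpi slt, %cst, %bnd : vector<4xi64>
/// ```
/// With `force32BitVectorIndices` the comparison is done on i32 indices
/// instead.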
1137 class VectorCreateMaskOpConversion
1138  : public OpRewritePattern<vector::CreateMaskOp> {
1139 public:
1140  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1141  bool enableIndexOpt,
1142  PatternBenefit benefit = 1)
1143  : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1144  force32BitVectorIndices(enableIndexOpt) {}
1145 
1146  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1147  PatternRewriter &rewriter) const override {
1148  auto dstType = op.getType();
1149  if (cast<VectorType>(dstType).isScalable())
1150  return failure();
1151  int64_t rank = dstType.getRank();
1152  if (rank > 1)
1153  return failure();
1154  rewriter.replaceOp(
1155  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1156  rank == 0 ? 0 : dstType.getDimSize(0),
1157  op.getOperand(0)));
1158  return success();
1159  }
1160 
1161 private:
1162  const bool force32BitVectorIndices;
1163 };
1164 
1165 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1166 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1167  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1168  // TODO: Support non-dense constant.
1169  if (!denseAttr)
1170  return false;
1171 
1172  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1173  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1174 }
1175 
1176 /// Folds a select operation between an all-true and all-false vector. For now,
1177 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1178 ///
1179 /// %true = arith.constant dense<true> : vector<1xi1>
1180 /// %false = arith.constant dense<false> : vector<1xi1>
1181 /// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1182 /// =>
1183 /// %result = vector.broadcast %cond : i1 to vector<1xi1>
1184 ///
1185 /// InstCombine seems to handle vectors with multiple elements but not the
1186 /// single element ones.
1187 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1188  using OpRewritePattern::OpRewritePattern;
1189 
1190  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1191  PatternRewriter &rewriter) const override {
1192  auto vecType = dyn_cast<VectorType>(selectOp.getType());
1193  if (!vecType || !vecType.getElementType().isInteger(1))
1194  return failure();
1195 
1196  // Only scalar conditions can be folded.
1197  Value cond = selectOp.getCondition();
1198  if (isa<VectorType>(cond.getType()))
1199  return failure();
1200 
1201  // TODO: Support n-D and scalable vectors.
1202  if (vecType.getRank() != 1 || vecType.isScalable())
1203  return failure();
1204 
1205  // TODO: Support vectors with multiple elements.
1206  if (vecType.getShape()[0] != 1)
1207  return failure();
1208 
1209  auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1210  if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1211  return failure();
1212 
1213  auto falseConst =
1214  selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1215  if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1216  return failure();
1217 
1218  // Replace select with its condition broadcasted to single element vector.
1219  auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1220  auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1221  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1222  return success();
1223  }
1224 };
1225 
1226 /// Returns the number of dims that can be folded away from transfer ops. It
1227 /// returns a failure if it cannot determine the number of dims to be folded.
1228 ///
1229 /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1230 /// `vectorType` is vector<16x16x1x1xf32>
1231 /// (the two innermost dims can be dropped by memref.subview ops).
1232 ///
1233 /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1234 /// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1235 /// (only the innermost unit dim of `srcType` can be dropped).
1236 ///
1237 /// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
1238 /// `vectorType` is vector<16x16x1x[1]xf32>
1239 /// (the innermost dim of `vectorType` is not a unit dim; it is a "scalable
1240 /// unit" dim).
1241 static FailureOr<size_t>
1242 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1243  SmallVector<int64_t> srcStrides;
1244  int64_t srcOffset;
1245  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1246  return failure();
1247 
1248  auto isUnitDim = [](VectorType type, int dim) {
1249  return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1250  };
1251 
1252  // According to vector.transfer_read/write semantics, the vector can be a
1253  // slice. Thus, we have to offset the check index with `rankDiff` in
1254  // `srcStrides` and source dim sizes.
1255  size_t result = 0;
1256  int rankDiff = srcType.getRank() - vectorType.getRank();
1257  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1258  // Check that the inner dim size is 1 for both memref type and vector slice.
1259  // It can be folded only if they are 1 and the stride is 1.
1260  int dim = vectorType.getRank() - i - 1;
1261  if (srcStrides[dim + rankDiff] != 1 ||
1262  srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1263  break;
1264  result++;
1265  }
1266  return result;
1267 }
1268 
1269 /// Drop inner most contiguous unit dimensions from transfer_read operand.
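/// E.g. (illustrative sketch, mirroring the transfer_write example below):
/// ```
/// %v = vector.transfer_read %arg0[%c0, %arg1, %c0, %c0], %pad
///        {in_bounds = [true, true, true, true]}
///        : memref<512x16x1x1xf32>, vector<16x16x1x1xf32>
/// ```
/// is rewritten to a rank-reduced subview, a narrower transfer_read, and a
/// shape_cast back to the original vector type:
/// ```
/// %subview = memref.subview %arg0[0, 0, 0, 0] [512, 16, 1, 1] [1, 1, 1, 1]
///              : memref<512x16x1x1xf32> to memref<512x16xf32>
/// %0 = vector.transfer_read %subview[%c0, %arg1], %pad
///        {in_bounds = [true, true]} : memref<512x16xf32>, vector<16x16xf32>
/// %v = vector.shape_cast %0 : vector<16x16xf32> to vector<16x16x1x1xf32>
/// ```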
1270 class DropInnerMostUnitDimsTransferRead
1271  : public OpRewritePattern<vector::TransferReadOp> {
1272  using OpRewritePattern::OpRewritePattern;
1273 
1274  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1275  PatternRewriter &rewriter) const override {
1276  // TODO: support 0-d corner case.
1277  if (readOp.getTransferRank() == 0)
1278  return failure();
1279 
1280  // TODO: support mask.
1281  if (readOp.getMask())
1282  return failure();
1283 
1284  auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1285  if (!srcType)
1286  return failure();
1287 
1288  if (!readOp.getPermutationMap().isMinorIdentity())
1289  return failure();
1290 
1291  auto targetType = readOp.getVectorType();
1292  if (targetType.getRank() <= 1)
1293  return failure();
1294 
1295  FailureOr<size_t> maybeDimsToDrop =
1296  getTransferFoldableInnerUnitDims(srcType, targetType);
1297  if (failed(maybeDimsToDrop))
1298  return failure();
1299 
1300  size_t dimsToDrop = maybeDimsToDrop.value();
1301  if (dimsToDrop == 0)
1302  return failure();
1303 
1304  // Make sure that the indices to be dropped are all equal to 0.
1305  // TODO: Deal with cases when the indices are not 0.
1306  if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIndex))
1307  return failure();
1308 
1309  auto resultTargetVecType =
1310  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1311  targetType.getElementType(),
1312  targetType.getScalableDims().drop_back(dimsToDrop));
1313 
1314  auto loc = readOp.getLoc();
1315  SmallVector<OpFoldResult> sizes =
1316  memref::getMixedSizes(rewriter, loc, readOp.getSource());
1317  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1318  rewriter.getIndexAttr(0));
1319  SmallVector<OpFoldResult> strides(srcType.getRank(),
1320  rewriter.getIndexAttr(1));
1321  auto resultMemrefType =
1322  cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1323  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1324  strides));
1325  ArrayAttr inBoundsAttr =
1326  readOp.getInBounds()
1327  ? rewriter.getArrayAttr(
1328  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1329  : ArrayAttr();
1330  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1331  loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1332  auto permMap = getTransferMinorIdentityMap(
1333  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1334  Value result = rewriter.create<vector::TransferReadOp>(
1335  loc, resultTargetVecType, rankedReducedView,
1336  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1337  readOp.getPadding(),
1338  // TODO: support mask.
1339  /*mask=*/Value(), inBoundsAttr);
1340  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1341  result);
1342  return success();
1343  }
1344 };
1345 
1346 /// Drop inner most contiguous unit dimensions from transfer_write operand.
1347 /// E.g.,
1348 /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1349 /// {in_bounds = [true, true, true, true, true]}
1350 /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1351 ///
1352 /// will be replaced with
1353 ///
1354 /// %subview = memref.subview %arg0
1355 /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1356 /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1357 /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1358 /// to vector<1x16x16xf32>
1359 /// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1360 /// {in_bounds = [true, true, true]}
1361 /// : vector<1x16x16xf32>, memref<1x512x16xf32>
1362 ///
1363 /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1364 class DropInnerMostUnitDimsTransferWrite
1365  : public OpRewritePattern<vector::TransferWriteOp> {
1366  using OpRewritePattern::OpRewritePattern;
1367 
1368  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1369  PatternRewriter &rewriter) const override {
1370  // TODO: support 0-d corner case.
1371  if (writeOp.getTransferRank() == 0)
1372  return failure();
1373 
1374  // TODO: support mask.
1375  if (writeOp.getMask())
1376  return failure();
1377 
1378  auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1379  if (!srcType)
1380  return failure();
1381 
1382  if (!writeOp.getPermutationMap().isMinorIdentity())
1383  return failure();
1384 
1385  auto targetType = writeOp.getVectorType();
1386  if (targetType.getRank() <= 1)
1387  return failure();
1388 
1389  FailureOr<size_t> maybeDimsToDrop =
1390  getTransferFoldableInnerUnitDims(srcType, targetType);
1391  if (failed(maybeDimsToDrop))
1392  return failure();
1393 
1394  size_t dimsToDrop = maybeDimsToDrop.value();
1395  if (dimsToDrop == 0)
1396  return failure();
1397 
1398  auto resultTargetVecType =
1399  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1400  targetType.getElementType(),
1401  targetType.getScalableDims().drop_back(dimsToDrop));
1402 
1403  Location loc = writeOp.getLoc();
1404  SmallVector<OpFoldResult> sizes =
1405  memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1406  SmallVector<OpFoldResult> offsets(srcType.getRank(),
1407  rewriter.getIndexAttr(0));
1408  SmallVector<OpFoldResult> strides(srcType.getRank(),
1409  rewriter.getIndexAttr(1));
1410  auto resultMemrefType =
1411  cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1412  srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1413  strides));
1414  ArrayAttr inBoundsAttr =
1415  writeOp.getInBounds()
1416  ? rewriter.getArrayAttr(
1417  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1418  : ArrayAttr();
1419 
1420  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1421  loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1422  auto permMap = getTransferMinorIdentityMap(
1423  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1424 
1425  auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1426  loc, resultTargetVecType, writeOp.getVector());
1427  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1428  writeOp, shapeCast, rankedReducedView,
1429  writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1430  // TODO: support mask.
1431  /*mask=*/Value(), inBoundsAttr);
1432  return success();
1433  }
1434 };
1435 
1436 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1437 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1438 /// with the RHS transposed) lowering.
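///
/// For example (illustrative sketch), a plain row-major matmul contraction
/// with maps (m, k), (k, n) -> (m, n):
/// ```
/// %res = vector.contract {...} %lhs, %rhs, %acc
///          : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
/// ```
/// is rewritten so that the RHS is transposed and the maps take the canonical
/// MMT form (m, k), (n, k) -> (m, n):
/// ```
/// %rhsT = vector.transpose %rhs, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
/// %res  = vector.contract {...} %lhs, %rhsT, %acc
///           : vector<4x8xf32>, vector<16x8xf32> into vector<4x16xf32>
/// ```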
1439 struct CanonicalizeContractMatmulToMMT final
1440  : OpRewritePattern<vector::ContractionOp> {
1441  using OpRewritePattern::OpRewritePattern;
1442 
1443  using FilterConstraintType =
1444  std::function<LogicalResult(vector::ContractionOp op)>;
1445 
1446  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1447  FilterConstraintType constraint)
1448  : OpRewritePattern<vector::ContractionOp>(context, benefit),
1449  filter(std::move(constraint)) {}
1450 
1451  LogicalResult matchAndRewrite(vector::ContractionOp op,
1452  PatternRewriter &rewriter) const override {
1453  if (failed(filter(op)))
1454  return failure();
1455 
1456  Location loc = op.getLoc();
1457  Value lhs = op.getLhs();
1458  Value rhs = op.getRhs();
1459  Value res = op.getAcc();
1460 
1461  // Set up the parallel/reduction structure in the right form.
1462  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1463  auto infer = [&](MapList m) {
1464  return AffineMap::inferFromExprList(m, op.getContext());
1465  };
1466  AffineExpr m;
1467  AffineExpr n;
1468  AffineExpr k;
1469  bindDims(rewriter.getContext(), m, n, k);
1470  static constexpr std::array<int64_t, 2> perm = {1, 0};
1471  auto iteratorTypes = op.getIteratorTypes().getValue();
1472  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1473  if (iteratorTypes.size() != 3 ||
1474  !vector::isParallelIterator(iteratorTypes[0]) ||
1475  !vector::isParallelIterator(iteratorTypes[1]) ||
1476  !vector::isReductionIterator(iteratorTypes[2]))
1477  return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1478 
1479  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1480  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1481  if (maps == canonicalForm)
1482  return rewriter.notifyMatchFailure(op, "already in the canonical form");
1483 
1484  // Create a vector transpose making sure to emit zero/sign-extend at the
1485  // end.
1486  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1487  if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1488  Value trans =
1489  rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1490  VectorType newType =
1491  cast<VectorType>(trans.getType())
1492  .clone(cast<VectorType>(mat.getType()).getElementType());
1493  return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1494  }
1495  if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1496  Value trans =
1497  rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1498  VectorType newType =
1499  VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1500  cast<VectorType>(mat.getType()).getElementType());
1501  return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1502  }
1503  return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1504  };
1505 
1506  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1507  rhs = createTranspose(rhs);
1508  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1509  lhs = createTranspose(lhs);
1510  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1511  rhs = createTranspose(rhs);
1512  lhs = createTranspose(lhs);
1513  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1514  std::swap(rhs, lhs);
1515  rhs = createTranspose(rhs);
1516  lhs = createTranspose(lhs);
1517  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1518  std::swap(rhs, lhs);
1519  rhs = createTranspose(rhs);
1520  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1521  std::swap(lhs, rhs);
1522  lhs = createTranspose(lhs);
1523  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1524  std::swap(lhs, rhs);
1525  } else {
1526  return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1527  }
1528  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1529  op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1530  op.getIteratorTypes());
1531  return success();
1532  };
1533 
1534 private:
1535  FilterConstraintType filter;
1536 };
1537 
1538 /// Pattern to fold arithmetic extensions on floating point data types into
1539 /// vector contraction operations. linalg.matmul introduces arithmetic
1540 /// extensions on its operands. See the MLIR snippet below for more details.
1541 /// ```mlir
1542 /// "linalg.matmul"(%lhs, %rhs, %acc) ({
1543 /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1544 /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1545 /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1546 /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1547 /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1548 /// "linalg.yield"(%acc) : (f32) -> ()
1549 /// })
1550 /// ```
1551 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1552 /// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1553 /// This pattern folds the arithmetic extensions into the vector contraction and
1554 /// enables the usage of native mixed precision Tensor Core instructions.
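/// At the vector level, the rewrite performed here is, schematically (an
/// illustrative sketch; the shapes, value names, and elided attributes are
/// made up):
///
/// ```mlir
/// %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
/// %rhs_f32 = arith.extf %rhs : vector<16x8xf16> to vector<16x8xf32>
/// %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///     : vector<4x8xf32>, vector<16x8xf32> into vector<4x16xf32>
/// ==>
/// %res = vector.contract {...} %lhs, %rhs, %acc
///     : vector<4x8xf16>, vector<16x8xf16> into vector<4x16xf32>
/// ```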
1555 struct FoldArithExtIntoContractionOp
1556  : public OpRewritePattern<vector::ContractionOp> {
1557  using OpRewritePattern::OpRewritePattern;
1558 
1559  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1560  PatternRewriter &rewriter) const override {
1561 
1562  auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
1563  auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
1564 
1565  if (!lhsDefOp || !rhsDefOp) {
1566  return rewriter.notifyMatchFailure(contractOp,
1567  "no defining op on contract operands");
1568  }
1569 
1570  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1571  contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1572  contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1573  contractOp.getIteratorTypesAttr());
1574 
1575  return success();
1576  }
1577 };
1578 
1579 /// Pattern to fold a chained reduction into a series of vector additions and a
1580 /// final reduction. This form should require fewer subgroup operations.
1581 ///
1582 /// ```mlir
1583 /// %a = vector.reduction <add> %x, %acc
1584 /// %b = vector.reduction <add> %y, %a
1585 /// ==>
1586 /// %a = arith.addf %x, %y
1587 /// %b = vector.reduction <add> %a, %acc
1588 /// ```
1589 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1590  using OpRewritePattern::OpRewritePattern;
1591 
1592  LogicalResult matchAndRewrite(vector::ReductionOp op,
1593  PatternRewriter &rewriter) const override {
1594  // TODO: Handle other combining kinds.
1595  if (op.getKind() != vector::CombiningKind::ADD)
1596  return failure();
1597 
1598  // Accumulator is optional.
1599  Value acc = op.getAcc();
1600  if (!acc)
1601  return failure();
1602 
1603  if (!acc.getType().isIntOrFloat())
1604  return failure();
1605 
1606  auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1607  if (!parentReduction)
1608  return failure();
1609 
1610  Location loc = op.getLoc();
1611  Value vAdd;
1612  if (isa<IntegerType>(acc.getType())) {
1613  vAdd = rewriter.createOrFold<arith::AddIOp>(
1614  loc, parentReduction.getVector(), op.getVector());
1615  } else {
1616  vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1617  op.getVector());
1618  }
1619  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1620  parentReduction.getAcc());
1621  return success();
1622  }
1623 };
1624 
1625 // Scalable unit dimensions are not supported. Folding such dimensions would
1626 // require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1627 // vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1628 // future.
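// For example (shapes chosen just for illustration):
//   vector<1x[4]x1xf32> -> vector<[4]xf32>    (fixed-size unit dims dropped)
//   vector<[1]x4xf32>   -> vector<[1]x4xf32>  (scalable unit dim preserved)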
1629 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1630  auto inVecShape = inVecTy.getShape();
1631  SmallVector<int64_t> newShape;
1632  SmallVector<bool> newScalableDims;
1633  for (auto [dim, isScalable] :
1634  llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1635  if (dim == 1 && !isScalable)
1636  continue;
1637 
1638  newShape.push_back(dim);
1639  newScalableDims.push_back(isScalable);
1640  }
1641 
1642  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1643 }
1644 
1645 /// For vectors with at least one unit dim, replaces:
1646 /// elementwise(a, b)
1647 /// with:
1648 /// sc_a = shape_cast(a)
1649 /// sc_b = shape_cast(b)
1650 /// res = elementwise(sc_a, sc_b)
1651 /// return shape_cast(res)
1652 /// The newly inserted shape_cast Ops drop the unit dim before the elementwise
1653 /// Op and restore it afterwards. Vectors `a` and `b` are required to have
1654 /// rank > 1.
1655 ///
1656 /// Ex:
1657 /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1658 /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1659 ///
1660 /// gets converted to:
1661 ///
1662 /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1663 /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1664 /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1665 /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1666 /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1667 ///
1668 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1669 /// `%cast`.
1670 struct DropUnitDimFromElementwiseOps final
1671  : public OpTraitRewritePattern<OpTrait::Elementwise> {
1672  using OpTraitRewritePattern::OpTraitRewritePattern;
1673  LogicalResult matchAndRewrite(Operation *op,
1674  PatternRewriter &rewriter) const override {
1675  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1676  return failure();
1677 
1678  auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1679  if (!resultVectorType)
1680  return failure();
1681 
1682  // Check the operand pre-conditions. For `Elementwise` ops all operands are
1683  // guaranteed to have identical shapes (with some exceptions such as
1684  // `arith.select`) and it suffices to only check one of them.
1685  auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1686  if (!sourceVectorType || sourceVectorType.getRank() < 2)
1687  return failure();
1688 
1689  SmallVector<Value> newOperands;
1690  auto loc = op->getLoc();
1691  for (auto operand : op->getOperands()) {
1692  auto opVectorType = cast<VectorType>(operand.getType());
1693  auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1694  if (newVType == opVectorType)
1695  return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1696 
1697  auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1698  newOperands.push_back(opSC);
1699  }
1700 
1701  VectorType newResultVectorType =
1702  dropNonScalableUnitDimFromType(resultVectorType);
1703  // Create an updated elementwise Op without unit dim.
1704  Operation *elementwiseOp =
1705  rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1706  newResultVectorType, op->getAttrs());
1707 
1708  // Restore the unit dim by applying vector.shape_cast to the result.
1709  rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1710  elementwiseOp->getResult(0));
1711 
1712  return success();
1713  }
1714 };
1715 
1716 /// Pattern to eliminate redundant zero-constants added to reduction operands.
1717 /// It's enough for there to be one initial zero value, so we can eliminate the
1718 /// extra ones that feed into `vector.reduction <add>`. These get created by the
1719 /// `ChainedReduction` pattern.
1720 ///
1721 /// ```mlir
1722 /// %a = arith.addf %x, %zero
1723 /// %b = arith.addf %a, %y
1724 /// %c = vector.reduction <add> %b, %acc
1725 /// ==>
1726 /// %b = arith.addf %a, %y
1727 /// %c = vector.reduction <add> %b, %acc
1728 /// ```
1729 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1730  using OpRewritePattern::OpRewritePattern;
1731 
1732  LogicalResult matchAndRewrite(vector::ReductionOp op,
1733  PatternRewriter &rewriter) const override {
1734  // TODO: Handle other reduction kinds and their identity values.
1735  if (op.getKind() != vector::CombiningKind::ADD)
1736  return failure();
1737 
1738  Type elemType = op.getSourceVectorType().getElementType();
1739  // The integer case should be handled by `arith.addi` folders; only check
1740  // for floats here.
1741  if (!isa<FloatType>(elemType))
1742  return failure();
1743 
1744  auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1745  if (!vAdd)
1746  return failure();
1747  auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1748  if (!addLhs)
1749  return failure();
1750 
1751  if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1752  return failure();
1753 
1754  auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1755  vAdd.getRhs());
1756  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1757  op.getAcc());
1758  return success();
1759  }
1760 };
1761 
1762 /// Example:
1763 /// ```
1764 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
1765 /// ```
1766 /// is transformed into:
1767 /// ```
1768 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
1769 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
1770 /// %a = arith.addf %y, %z : f32
1771 /// ```
1772 struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1773  BreakDownVectorReduction(MLIRContext *context,
1774  unsigned maxNumElementsToExtract,
1775  PatternBenefit benefit)
1776  : OpRewritePattern(context, benefit),
1777  maxNumElementsToExtract(maxNumElementsToExtract) {}
1778 
1779  LogicalResult matchAndRewrite(vector::ReductionOp op,
1780  PatternRewriter &rewriter) const override {
1781  VectorType type = op.getSourceVectorType();
1782  if (type.isScalable() || op.isMasked())
1783  return failure();
1784  assert(type.getRank() == 1 && "Expected a 1-d vector");
1785 
1786  int64_t numElems = type.getNumElements();
1787  if (numElems > maxNumElementsToExtract) {
1788  return rewriter.notifyMatchFailure(
1789  op, llvm::formatv("has too many vector elements ({0}) to break down "
1790  "(max allowed: {1})",
1791  numElems, maxNumElementsToExtract));
1792  }
1793 
1794  Location loc = op.getLoc();
1795  SmallVector<Value> extracted(numElems, nullptr);
1796  for (auto [idx, extractedElem] : llvm::enumerate(extracted))
1797  extractedElem = rewriter.create<vector::ExtractOp>(
1798  loc, op.getVector(), static_cast<int64_t>(idx));
1799 
1800  Value res = extracted.front();
1801  for (auto extractedElem : llvm::drop_begin(extracted))
1802  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
1803  extractedElem, op.getFastmathAttr());
1804  if (Value acc = op.getAcc())
1805  res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
1806  op.getFastmathAttr());
1807 
1808  rewriter.replaceOp(op, res);
1809  return success();
1810  }
1811 
1812 private:
1813  unsigned maxNumElementsToExtract = 0;
1814 };
1815 
1816 } // namespace
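// A typical way to drive the populate* entry points below from a pass (a
// minimal sketch; the pass class, the chosen pattern sets, and the use of the
// greedy rewrite driver are illustrative assumptions, not part of this file):
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void MyVectorCleanupPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateFoldArithExtensionPatterns(patterns);
//     vector::populateChainedVectorReductionFoldingPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }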
1817 
1818 void mlir::vector::populateFoldArithExtensionPatterns(
1819  RewritePatternSet &patterns) {
1820  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
1821 }
1822 
1823 void mlir::vector::populateVectorMaskMaterializationPatterns(
1824  RewritePatternSet &patterns, bool force32BitVectorIndices,
1825  PatternBenefit benefit) {
1826  patterns.add<VectorCreateMaskOpConversion,
1827  MaterializeTransferMask<vector::TransferReadOp>,
1828  MaterializeTransferMask<vector::TransferWriteOp>>(
1829  patterns.getContext(), force32BitVectorIndices, benefit);
1830  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
1831 }
1832 
1833 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1834  PatternBenefit benefit) {
1835  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
1836 }
1837 
1838 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
1839  RewritePatternSet &patterns, PatternBenefit benefit) {
1840  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1841  patterns.getContext(), benefit);
1842 }
1843 
1844 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
1845  RewritePatternSet &patterns, PatternBenefit benefit) {
1846  patterns.add<BubbleDownVectorBitCastForExtract,
1847  BubbleDownBitCastForStridedSliceExtract,
1848  BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
1849  patterns.getContext(), benefit);
1850 }
1851 
1852 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
1853  RewritePatternSet &patterns,
1854  std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
1855  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
1856  std::move(controlFn), benefit);
1857 }
1858 
1859 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
1860  RewritePatternSet &patterns,
1861  std::function<LogicalResult(vector::ContractionOp)> constraint,
1862  PatternBenefit benefit) {
1863  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
1864  std::move(constraint));
1865 }
1866 
1867 void mlir::vector::populateVectorReductionToContractPatterns(
1868  RewritePatternSet &patterns, PatternBenefit benefit) {
1869  patterns.add<MultiReduceToContract, CombineContractBroadcast,
1870  CombineContractABTranspose, CombineContractResultTranspose,
1871  ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
1872  patterns.getContext(), benefit);
1873 }
1874 
1875 void mlir::vector::
1876  populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
1877  RewritePatternSet &patterns, PatternBenefit benefit) {
1878  patterns.add<DropInnerMostUnitDimsTransferRead,
1879  DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
1880  benefit);
1881 }
1882 
1883 void mlir::vector::populateSinkVectorBroadcastPatterns(
1884  RewritePatternSet &patterns, PatternBenefit benefit) {
1885  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
1886  patterns.getContext(), benefit);
1887 }
1888 
1889 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
1890  RewritePatternSet &patterns, PatternBenefit benefit) {
1891  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
1892  patterns.add<ReduceRedundantZero>(patterns.getContext(),
1893  PatternBenefit(benefit.getBenefit() + 1));
1894 }
1895 
1896 void mlir::vector::populateBreakDownVectorReductionPatterns(
1897  RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
1898  PatternBenefit benefit) {
1899  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
1900  maxNumElementsToExtract, benefit);
1901 }
1902 
1903 //===----------------------------------------------------------------------===//
1904 // TableGen'd enum attribute definitions
1905 //===----------------------------------------------------------------------===//
1906 
1907 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"