1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <cstdint>
16 #include <functional>
17 #include <optional>
18 #include <type_traits>
19 
20 #include "mlir/Dialect/Affine/IR/AffineOps.h"
21 #include "mlir/Dialect/Arith/IR/Arith.h"
22 #include "mlir/Dialect/Arith/Utils/Utils.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Tensor/IR/Tensor.h"
27 #include "mlir/Dialect/Utils/IndexingUtils.h"
28 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
29 #include "mlir/Dialect/Vector/IR/VectorOps.h"
30 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
31 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
32 #include "mlir/IR/BuiltinAttributeInterfaces.h"
33 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/Location.h"
36 #include "mlir/IR/Matchers.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
41 
42 #include "llvm/ADT/DenseSet.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/raw_ostream.h"
48 
49 #define DEBUG_TYPE "vector-to-vector"
50 
51 using namespace mlir;
52 using namespace mlir::vector;
53 
54 template <typename IntType>
55 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
56  return llvm::to_vector<4>(llvm::map_range(
57  arrayAttr.getAsRange<IntegerAttr>(),
58  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
59 }
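// For example, given an ArrayAttr of IntegerAttrs holding [2, 0, 1],
// extractVector<int64_t>(attr) returns SmallVector<int64_t>{2, 0, 1}.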
60 
61 // Helper to find an index in an affine map.
62 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
63  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
64  int64_t idx = map.getDimPosition(i);
65  if (idx == index)
66  return i;
67  }
68  return std::nullopt;
69 }
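// For example, for affine_map<(d0, d1, d2) -> (d2, d0)>, getResultIndex(map, 0)
// returns 1 (d0 is the second result), while getResultIndex(map, 1) returns
// std::nullopt (d1 does not appear in any result).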
70 
71 namespace {
72 
73 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
74 //
75 // Example:
76 //
77 // The following MLIR with cancelling ShapeCastOps:
78 //
79 // %0 = source : vector<5x4x2xf32>
80 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
81 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
82 // %3 = user %2 : vector<5x4x2xf32>
83 //
84 // Should canonicalize to the following:
85 //
86 // %0 = source : vector<5x4x2xf32>
87 // %1 = user %0 : vector<5x4x2xf32>
88 //
89 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
90  using OpRewritePattern::OpRewritePattern;
91 
92  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
93  PatternRewriter &rewriter) const override {
94  // Check if 'shapeCastOp' has vector source/result type.
95  auto sourceVectorType =
96  dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
97  auto resultVectorType =
98  dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
99  if (!sourceVectorType || !resultVectorType)
100  return failure();
101 
102  // Check if shape cast op source operand is also a shape cast op.
103  auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
104  shapeCastOp.getSource().getDefiningOp());
105  if (!sourceShapeCastOp)
106  return failure();
107  auto operandSourceVectorType =
108  cast<VectorType>(sourceShapeCastOp.getSource().getType());
109  auto operandResultVectorType = sourceShapeCastOp.getType();
110 
111  // Check if shape cast operations invert each other.
112  if (operandSourceVectorType != resultVectorType ||
113  operandResultVectorType != sourceVectorType)
114  return failure();
115 
116  rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
117  return success();
118  }
119 };
120 
121 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
122 /// Ex:
123 /// ```
124 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
125 /// %1 = vector.multi_reduction add, %0 [2]
126 /// : vector<8x32x16xf32> to vector<8x32xf32>
127 /// ```
128 /// Gets converted to:
129 /// ```
130 /// %1 = vector.contract {indexing_maps = [
131 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
132 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
133 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
134 /// iterator_types = ["parallel", "parallel", "reduction"],
135 /// kind = add} %0, %arg1, %cst_f0
136 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
137 /// ```
138 struct MultiReduceToContract
139  : public OpRewritePattern<vector::MultiDimReductionOp> {
140  using OpRewritePattern::OpRewritePattern;
141 
142  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
143  PatternRewriter &rewriter) const override {
144  if (reduceOp.getKind() != vector::CombiningKind::ADD)
145  return failure();
146  Operation *mulOp = reduceOp.getSource().getDefiningOp();
147  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
148  return failure();
149  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
150  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
151  SmallVector<AffineExpr> exprs;
152  SmallVector<vector::IteratorType> iteratorTypes;
153  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
154  if (!isReduceDim.value()) {
155  iteratorTypes.push_back(vector::IteratorType::parallel);
156  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
157  } else {
158  iteratorTypes.push_back(vector::IteratorType::reduction);
159  }
160  }
161  auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
162  /*symCount=*/0, exprs, reduceOp.getContext());
163  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
164  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
165  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
166  rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
167  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
168  return IteratorTypeAttr::get(rewriter.getContext(), t);
169  }))));
170  return success();
171  }
172 };
173 
174 /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
175 /// Ex:
176 /// ```
177 /// %0 = vector.transpose %arg0, [2, 0, 1]
178 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
179 /// %1 = vector.contract {indexing_maps = [
180 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
181 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
182 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
183 /// iterator_types = ["parallel", "parallel", "reduction"],
184 /// kind = add} %0, %arg1, %cst_f0
185 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
186 /// ```
187 /// Gets converted to:
188 /// ```
189 /// %1 = vector.contract {indexing_maps = [
190 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
191 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
192 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
193 /// iterator_types = ["parallel", "parallel", "reduction"],
194 /// kind = add} %arg0, %arg1, %cst_f0
195 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
196 /// ```
197 struct CombineContractABTranspose final
198  : public OpRewritePattern<vector::ContractionOp> {
199  using OpRewritePattern::OpRewritePattern;
200 
201  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
202  PatternRewriter &rewriter) const override {
203  SmallVector<AffineMap> maps =
204  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
205  Value lhs = contractOp.getLhs();
206  Value rhs = contractOp.getRhs();
207  size_t index = 0;
208  bool changed = false;
209  for (Value *operand : {&lhs, &rhs}) {
210  AffineMap &map = maps[index++];
211  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
212  if (!transposeOp)
213  continue;
214  AffineMap permutationMap = AffineMap::getPermutationMap(
215  transposeOp.getPermutation(), contractOp.getContext());
216  map = inversePermutation(permutationMap).compose(map);
217  *operand = transposeOp.getVector();
218  changed = true;
219  }
220  if (!changed)
221  return failure();
222  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
223  contractOp, lhs, rhs, contractOp.getAcc(),
224  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
225  return success();
226  }
227 };
228 
229 /// Merges accumulator and result transposes into contract.
230 ///
231 /// For example:
232 /// ```mlir
233 /// %accT = vector.transpose %acc, [0, 2, 1]
234 /// : vector<2x8x4xf32> to vector<2x4x8xf32>
235 /// %contract = vector.contract {
236 /// indexing_maps = [
237 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
238 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
239 /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
240 /// ],
241 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
242 /// kind = #vector.kind<add>
243 /// } %lhs, %rhs, %accT
244 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
245 /// %0 = vector.transpose %contract, [0, 2, 1]
246 /// : vector<2x4x8xf32> to vector<2x8x4xf32>
247 /// ```
248 /// Becomes:
249 /// ```mlir
250 /// %0 = vector.contract {
251 /// indexing_maps = [
252 /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
253 /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
254 /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
255 /// ],
256 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
257 /// kind = #vector.kind<add>
258 /// } %lhs, %rhs, %acc
259 /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
260 /// ```
261 struct CombineContractResultTranspose final
262  : public OpRewritePattern<vector::TransposeOp> {
263  using OpRewritePattern::OpRewritePattern;
264 
265  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
266  PatternRewriter &rewriter) const override {
267  auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
268  if (!contractOp || !contractOp->hasOneUse())
269  return failure();
270 
271  auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
272  if (!accTOp)
273  return failure();
274 
275  MLIRContext *context = contractOp.getContext();
276  auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
277  AffineMap contractMap = maps.back();
278 
279  // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
280  // To index into A in contract, we need revert(f)(g(C)) -> A.
281  auto accTMap =
282  AffineMap::getPermutationMap(accTOp.getPermutation(), context);
283 
284  // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
285  // To index into E in contract, we need h(g(C)) -> E.
286  auto resTMap =
287  AffineMap::getPermutationMap(resTOp.getPermutation(), context);
288  auto combinedResMap = resTMap.compose(contractMap);
289 
290  // The accumulator and result share the same indexing map, so the two
291  // transposes can only be merged if combinedResMap equals
292  // inversePermutation(accTMap).compose(contractMap), i.e. if resTMap is the inverse of accTMap.
293  if (inversePermutation(accTMap) != resTMap)
294  return failure();
295  maps.back() = combinedResMap;
296 
297  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
298  resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
299  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
300  return success();
301  }
302 };
303 
304 /// Merge BroadcastOp into ContractionOp user.
305 /// Ex:
306 /// ```
307 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
308 /// %1 = vector.contract {indexing_maps = [
309 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
310 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
311 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
312 /// iterator_types = ["parallel", "parallel", "reduction"],
313 /// kind = add} %0, %arg1, %cst_f0
314 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
315 /// ```
316 /// Gets converted to:
317 /// ```
318 /// %1 = vector.contract {indexing_maps = [
319 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
320 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
321 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
322 /// iterator_types = ["parallel", "parallel", "reduction"],
323 /// kind = add} %arg0, %arg1, %cst_f0
324 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
325 /// ```
326 struct CombineContractBroadcast
327  : public OpRewritePattern<vector::ContractionOp> {
328  using OpRewritePattern::OpRewritePattern;
329 
330  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
331  PatternRewriter &rewriter) const override {
332  SmallVector<AffineMap> maps =
333  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
334  Value lhs = contractOp.getLhs();
335  Value rhs = contractOp.getRhs();
336  size_t index = 0;
337  bool changed = false;
338  for (Value *operand : {&lhs, &rhs}) {
339  AffineMap &map = maps[index++];
340  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
341  if (!broadcast)
342  continue;
343  // ContractionOp can only take vectors as operands.
344  auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
345  if (!srcType ||
346  srcType.getRank() == broadcast.getResultVectorType().getRank())
347  continue;
348  int64_t rankDiff =
349  broadcast.getResultVectorType().getRank() - srcType.getRank();
350  bool innerDimBroadcast = false;
351  SmallVector<AffineExpr> originalDims;
352  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
353  if (dim.value() != broadcast.getResultVectorType().getDimSize(
354  rankDiff + dim.index())) {
355  innerDimBroadcast = true;
356  break;
357  }
358  originalDims.push_back(
359  rewriter.getAffineDimExpr(dim.index() + rankDiff));
360  }
361  // Contract doesn't support inner dimension broadcast. Once this is
362  // relaxed we can remove this case.
363  if (innerDimBroadcast)
364  continue;
365 
366  // It would be incorrect to fold a broadcast onto a reduction dimension
367  // of non-unit size.
368  bool nonUnitDimReductionBroadcast = false;
369  for (int64_t i = 0; i < rankDiff; ++i) {
370  if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
371  isReductionIterator(contractOp.getIteratorTypes()
372  .getValue()[map.getDimPosition(i)])) {
373  nonUnitDimReductionBroadcast = true;
374  break;
375  }
376  }
377  if (nonUnitDimReductionBroadcast)
378  continue;
379 
380  AffineMap broadcastMap =
381  AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
382  originalDims, contractOp.getContext());
383  map = broadcastMap.compose(map);
384  *operand = broadcast.getSource();
385  changed = true;
386  }
387 
388  if (!changed)
389  return failure();
390 
391  // Determine which dims are unused, now that the maps have been composed
392  // with the broadcast maps.
393  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
394  // Compress unused dims.
395  for (auto &m : maps)
396  m = compressDims(m, unusedDimsBitVector);
397  // Compute the combined iterators.
398  SmallVector<Attribute> iterators;
399  for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
400  if (!unusedDimsBitVector.test(i))
401  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
402  }
403  // Check that compressing unused dims isn't removing all reduction dimension
404  // pairs. For example, if the vector.contract had only one reduction
405  // iterator and that was a unit-dimension created by a broadcast,
406  // then we should bail here, otherwise we would create a contract without
407  // a reduction dimension pair.
408  bool hasReductionIteratorApplyingOnBothSides = false;
409  for (unsigned i = 0; i < iterators.size(); ++i) {
410  if (!isReductionIterator(iterators[i]))
411  continue;
412  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
413  hasReductionIteratorApplyingOnBothSides = true;
414  break;
415  }
416  }
417  if (!hasReductionIteratorApplyingOnBothSides)
418  return failure();
419 
420  // If the compressed maps have a dimension that is not used by either LHS or
421  // RHS then the ContractionOp verifier would fail.
422  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
423  return failure();
424  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
425  contractOp, lhs, rhs, contractOp.getAcc(),
426  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
427  return success();
428  }
429 };
430 
431 /// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops
432 /// closer to contraction ops, enabling the CombineContractBroadcast pattern
433 /// to kick in when casting ops sit between these operations.
434 /// Ex:
435 /// ```
436 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
437 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
438 /// ```
439 /// Gets converted to:
440 /// ```
441 /// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
442 /// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
443 /// ```
444 struct ReorderCastOpsOnBroadcast
445  : public OpInterfaceRewritePattern<CastOpInterface> {
446  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
447 
448  LogicalResult matchAndRewrite(CastOpInterface op,
449  PatternRewriter &rewriter) const override {
450  if (op->getNumOperands() != 1)
451  return failure();
452  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
453  if (!bcastOp)
454  return failure();
455 
456  Type castResTy = getElementTypeOrSelf(op->getResult(0));
457  if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
458  castResTy = vecTy.clone(castResTy);
459  auto *castOp =
460  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
461  bcastOp.getSource(), castResTy, op->getAttrs());
462  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
463  op, op->getResult(0).getType(), castOp->getResult(0));
464  return success();
465  }
466 };
467 
468 /// Reorders elementwise(transpose) to transpose(elementwise). This moves
469 /// transpose ops closer to contraction ops, enabling the
470 /// CombineContractABTranspose pattern to kick in when elementwise ops sit
471 /// between these operations. Ex:
472 /// ```
473 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
474 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
475 /// %r = arith.addf %at, %bt : vector<2x4xf32>
476 /// ```
477 /// Gets converted to:
478 /// ```
479 /// %0 = arith.addf %a, %b : vector<4x2xf32>
480 /// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
481 /// ```
482 struct ReorderElementwiseOpsOnTranspose final
483  : public OpTraitRewritePattern<OpTrait::Elementwise> {
484  using OpTraitRewritePattern::OpTraitRewritePattern;
485  LogicalResult matchAndRewrite(Operation *op,
486  PatternRewriter &rewriter) const override {
487  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
488  return failure();
489 
490  // Make sure all operands are transpose/constant ops and collect their
491  // transposition maps.
492  SmallVector<ArrayRef<int64_t>> transposeMaps;
493  transposeMaps.reserve(op->getNumOperands());
494  // Record the initial type before transposition. We'll use its shape later.
495  // Any type will do here as we will check all transpose maps are the same.
496  VectorType srcType;
497  for (Value operand : op->getOperands()) {
498  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
499  if (transposeOp) {
500  transposeMaps.push_back(transposeOp.getPermutation());
501  srcType = transposeOp.getSourceVectorType();
502  } else if (!matchPattern(operand, m_Constant())) {
503  return failure();
504  }
505  }
506  if (transposeMaps.empty())
507  return failure();
508  // This is an elementwise op, so all transposed operands should have the
509  // same type. We need to additionally check that all transposes use the
510  // same map.
511  if (!llvm::all_equal(transposeMaps))
512  return rewriter.notifyMatchFailure(op, "different transpose map");
513 
514  SmallVector<Value> srcValues;
515  srcValues.reserve(op->getNumOperands());
516 
517  // If there are constant operands, we need to insert inverse transposes for
518  // them. Calculate the inverse order first.
519  auto order = transposeMaps.front();
520  SmallVector<int64_t> invOrder(order.size());
521  for (int i = 0, e = order.size(); i < e; ++i)
522  invOrder[order[i]] = i;
523 
524  for (Value operand : op->getOperands()) {
525  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
526  if (transposeOp) {
527  srcValues.push_back(transposeOp.getVector());
528  } else {
529  // This is a constant. Create a reverse transpose op for it.
530  auto vectorType =
531  srcType.clone(cast<VectorType>(operand.getType()).getElementType());
532  srcValues.push_back(rewriter.create<vector::TransposeOp>(
533  operand.getLoc(), vectorType, operand, invOrder));
534  }
535  }
536 
537  auto vectorType = srcType.clone(
538  cast<VectorType>(op->getResultTypes()[0]).getElementType());
539  Operation *elementwiseOp =
540  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
541  vectorType, op->getAttrs());
542  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
543  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
544  transposeMaps.front());
545  return success();
546  }
547 };
548 
549 // Returns the values in `arrayAttr` as an integer vector.
550 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
551  return llvm::to_vector<4>(
552  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
553  [](IntegerAttr attr) { return attr.getInt(); }));
554 }
555 
556 // Shuffles vector.bitcast op after vector.extract op.
557 //
558 // This transforms IR like:
559 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
560 // %1 = vector.extract %0[3] : f16 from vector<8xf16>
561 // Into:
562 // %0 = vector.extract %src[1] : f32 from vector<4xf32>
563 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
564 // %2 = vector.extract %1[1] : f16 from vector<2xf16>
565 struct BubbleDownVectorBitCastForExtract
566  : public OpRewritePattern<vector::ExtractOp> {
567  using OpRewritePattern::OpRewritePattern;
568 
569  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
570  PatternRewriter &rewriter) const override {
571  // Only support extracting scalars for now.
572  if (extractOp.getSourceVectorType().getRank() != 1)
573  return failure();
574 
575  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
576  if (!castOp)
577  return failure();
578 
579  VectorType castSrcType = castOp.getSourceVectorType();
580  VectorType castDstType = castOp.getResultVectorType();
581  assert(castSrcType.getRank() == castDstType.getRank());
582 
583  // Fail to match if we only have one element in the cast op source.
584  // This avoids an infinite loop, given that this pattern can generate
585  // such cases.
586  if (castSrcType.getNumElements() == 1)
587  return failure();
588 
589  // Only support casting to a larger number of elements for now.
590  // E.g., vector<4xf32> -> vector<8xf16>.
591  if (castSrcType.getNumElements() > castDstType.getNumElements())
592  return failure();
593 
594  unsigned expandRatio =
595  castDstType.getNumElements() / castSrcType.getNumElements();
596 
597  auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
598  assert(values[0].is<Attribute>() && "Unexpected non-constant index");
599  return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
600  };
601 
602  uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
603 
604  // Get the single scalar (as a vector) in the source value that packs the
605  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
606  Location loc = extractOp.getLoc();
607  Value packedValue = rewriter.create<vector::ExtractOp>(
608  loc, castOp.getSource(), index / expandRatio);
609  Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
610  Value zero = rewriter.create<arith::ConstantOp>(
611  loc, packedVecType, rewriter.getZeroAttr(packedVecType));
612  packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
613  /*position=*/0);
614 
615  // Cast it to a vector with the desired scalar's type.
616  // E.g. f32 -> vector<2xf16>
617  VectorType packedType =
618  VectorType::get({expandRatio}, castDstType.getElementType());
619  Value castedValue =
620  rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
621 
622  // Finally extract the desired scalar.
623  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
624  index % expandRatio);
625  return success();
626  }
627 };
628 
629 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
630 //
631 // This transforms IR like:
632 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
633 // %0 = vector.extract_strided_slice %cast {
634 // offsets = [4], sizes = [4], strides = [1]
635 // } : vector<8xf16> to vector<4xf16>
636 // Into:
637 // %0 = vector.extract_strided_slice %src {
638 // offsets = [2], sizes = [2], strides = [1]
639 // } : vector<4xf32> to vector<2xf32>
640 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
641 struct BubbleDownBitCastForStridedSliceExtract
642  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
643  using OpRewritePattern::OpRewritePattern;
644 
645  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
646  PatternRewriter &rewriter) const override {
647  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
648  if (!castOp)
649  return failure();
650 
651  VectorType castSrcType = castOp.getSourceVectorType();
652  VectorType castDstType = castOp.getResultVectorType();
653  assert(castSrcType.getRank() == castDstType.getRank());
654 
655  int64_t castSrcLastDim = castSrcType.getShape().back();
656  int64_t castDstLastDim = castDstType.getShape().back();
657  // Require casting to more elements for now; other cases to be implemented.
658  if (castSrcLastDim > castDstLastDim)
659  return failure();
660 
661  // Only accept all one strides for now.
662  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
663  [](const APInt &val) { return !val.isOne(); }))
664  return failure();
665 
666  unsigned rank = extractOp.getSourceVectorType().getRank();
667  assert(castDstLastDim % castSrcLastDim == 0);
668  int64_t expandRatio = castDstLastDim / castSrcLastDim;
669 
670  // If we have fewer offsets than the rank, then implicitly we are
671  // selecting the full range for the last bitcasted dimension; other
672  // dimensions aren't affected. Otherwise, we need to scale down the last
673  // dimension's offset given we are extracting from fewer elements now.
674  ArrayAttr newOffsets = extractOp.getOffsets();
675  if (newOffsets.size() == rank) {
676  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
677  if (offsets.back() % expandRatio != 0)
678  return failure();
679  offsets.back() = offsets.back() / expandRatio;
680  newOffsets = rewriter.getI64ArrayAttr(offsets);
681  }
682 
683  // Similarly for sizes.
684  ArrayAttr newSizes = extractOp.getSizes();
685  if (newSizes.size() == rank) {
686  SmallVector<int64_t> sizes = getIntValueVector(newSizes);
687  if (sizes.back() % expandRatio != 0)
688  return failure();
689  sizes.back() = sizes.back() / expandRatio;
690  newSizes = rewriter.getI64ArrayAttr(sizes);
691  }
692 
693  SmallVector<int64_t> dims =
694  llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
695  dims.back() = dims.back() / expandRatio;
696  VectorType newExtractType =
697  VectorType::get(dims, castSrcType.getElementType());
698 
699  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
700  extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
701  newSizes, extractOp.getStrides());
702 
703  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
704  extractOp, extractOp.getType(), newExtractOp);
705 
706  return success();
707  }
708 };
709 
710 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
711 //
712 // This transforms IR like:
713 // %0 = vector.insert_strided_slice %src, %dst {
714 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
715 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
716 // Into:
717 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
718 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
719 // %2 = vector.insert_strided_slice %0, %1 {
720 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
721 struct BubbleUpBitCastForStridedSliceInsert
722  : public OpRewritePattern<vector::BitCastOp> {
723  using OpRewritePattern::OpRewritePattern;
724 
725  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
726  PatternRewriter &rewriter) const override {
727  VectorType castSrcType = bitcastOp.getSourceVectorType();
728  VectorType castDstType = bitcastOp.getResultVectorType();
729  assert(castSrcType.getRank() == castDstType.getRank());
730  // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
731  if (castSrcType.getRank() == 0)
732  return failure();
733 
734  int64_t castSrcLastDim = castSrcType.getShape().back();
735  int64_t castDstLastDim = castDstType.getShape().back();
736  // Require casting to fewer elements for now; other cases to be implemented.
737  if (castSrcLastDim < castDstLastDim)
738  return failure();
739 
740  assert(castSrcLastDim % castDstLastDim == 0);
741  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
742 
743  auto insertOp =
744  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
745  if (!insertOp)
746  return failure();
747 
748  // Only accept all one strides for now.
749  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
750  [](const APInt &val) { return !val.isOne(); }))
751  return failure();
752 
753  unsigned rank = insertOp.getSourceVectorType().getRank();
754  // Require insert op to have the same rank for the source and destination
755  // vector; other cases to be implemented.
756  if (rank != insertOp.getDestVectorType().getRank())
757  return failure();
758 
759  // Require that the shape of the insert op source is castable to dstType.
760  unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
761  unsigned destinationWidth =
762  castDstType.getElementType().getIntOrFloatBitWidth();
763  unsigned numElements = destinationWidth / sourceWidth;
764  if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
765  return failure();
766 
767  ArrayAttr newOffsets = insertOp.getOffsets();
768  assert(newOffsets.size() == rank);
769  SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
770  if (offsets.back() % shrinkRatio != 0)
771  return failure();
772  offsets.back() = offsets.back() / shrinkRatio;
773  newOffsets = rewriter.getI64ArrayAttr(offsets);
774 
775  SmallVector<int64_t> srcDims =
776  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
777  srcDims.back() = srcDims.back() / shrinkRatio;
778  VectorType newCastSrcType =
779  VectorType::get(srcDims, castDstType.getElementType());
780 
781  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
782  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
783 
784  SmallVector<int64_t> dstDims =
785  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
786  dstDims.back() = dstDims.back() / shrinkRatio;
787  VectorType newCastDstType =
788  VectorType::get(dstDims, castDstType.getElementType());
789 
790  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
791  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
792 
793  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
794  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
795  insertOp.getStrides());
796 
797  return success();
798  }
799 };
800 
801 // Breaks down vector.bitcast op
802 //
803 // This transforms IR like:
804 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
805 // Into:
806 // %cst = vector.splat %c0_f32 : vector<4xf32>
807 // %1 = vector.extract_strided_slice %0 {
808 // offsets = [0], sizes = [4], strides = [1]
809 // } : vector<8xf16> to vector<4xf16>
810 // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
811 // %4 = vector.insert_strided_slice %2, %cst {
812 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
813 // %5 = vector.extract_strided_slice %0 {
814 // offsets = [4], sizes = [4], strides = [1]
815 // } : vector<8xf16> to vector<4xf16>
816 // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
817 // %7 = vector.insert_strided_slice %6, %cst {
818 // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
819 struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
820  using OpRewritePattern::OpRewritePattern;
821 
822 public:
823  BreakDownVectorBitCast(MLIRContext *context,
824  std::function<bool(vector::BitCastOp)> controlFn,
825  PatternBenefit benefit)
826  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
827 
828  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
829  PatternRewriter &rewriter) const override {
830 
831  if (controlFn && !controlFn(bitcastOp))
832  return failure();
833 
834  VectorType castSrcType = bitcastOp.getSourceVectorType();
835  VectorType castDstType = bitcastOp.getResultVectorType();
836  assert(castSrcType.getRank() == castDstType.getRank());
837 
838  // Only support rank 1 case for now.
839  if (castSrcType.getRank() != 1)
840  return failure();
841 
842  int64_t castSrcLastDim = castSrcType.getShape().back();
843  int64_t castDstLastDim = castDstType.getShape().back();
844  // Require casting to fewer elements for now; other cases to be implemented.
845  if (castSrcLastDim < castDstLastDim)
846  return failure();
847 
848  assert(castSrcLastDim % castDstLastDim == 0);
849  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
850  // Nothing to do if it is already bitcasting to a single element.
851  if (castSrcLastDim == shrinkRatio)
852  return failure();
853 
854  Location loc = bitcastOp.getLoc();
855  Type elemType = castDstType.getElementType();
856  assert(elemType.isSignlessIntOrIndexOrFloat());
857 
858  Value zero = rewriter.create<arith::ConstantOp>(
859  loc, elemType, rewriter.getZeroAttr(elemType));
860  Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
861 
862  SmallVector<int64_t> sliceShape{castDstLastDim};
863  SmallVector<int64_t> strides{1};
864  VectorType newCastDstType =
865  VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
866  castDstType.getElementType());
867 
868  for (int i = 0, e = shrinkRatio; i < e; ++i) {
869  Value extracted = rewriter.create<ExtractStridedSliceOp>(
870  loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
871  sliceShape, strides);
872  Value bitcast =
873  rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
874  res = rewriter.create<InsertStridedSliceOp>(
875  loc, bitcast, res,
876  ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
877  }
878  rewriter.replaceOp(bitcastOp, res);
879  return success();
880  }
881 
882 private:
883  std::function<bool(BitCastOp)> controlFn;
884 };
885 
886 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
887 /// ```
888 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
889 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
890 /// %r = arith.addi %a, %b : vector<1x4xindex>
891 /// ```
892 /// Gets converted to:
893 /// ```
894 /// %r = arith.addi %arg1, %arg2 : index
895 /// %b = vector.broadcast %r : index to vector<1x4xindex>
896 /// ```
897 ///
898 /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
899 /// ops.
900 struct ReorderElementwiseOpsOnBroadcast final
901  : public OpTraitRewritePattern<OpTrait::Elementwise> {
902  using OpTraitRewritePattern::OpTraitRewritePattern;
903  LogicalResult matchAndRewrite(Operation *op,
904  PatternRewriter &rewriter) const override {
905  if (op->getNumResults() != 1)
906  return failure();
907  if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
908  return failure();
909  if (!OpTrait::hasElementwiseMappableTraits(op))
910  return failure();
911  if (op->getNumOperands() == 0 ||
912  op->getResults()[0].getType() != op->getOperand(0).getType()) {
913  return failure();
914  }
915  // Avoid operations that only accept vector types, since the broadcast
916  // source might be a scalar type.
917  if (isa<vector::FMAOp>(op)) {
918  return failure();
919  }
920 
921  // Get the type of the lhs operand
922  auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
923  if (!lhsBcastOrSplat ||
924  !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
925  return failure();
926  auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
927 
928  // Make sure that all operands are broadcast from identical types:
929  // * scalar (`vector.broadcast` + `vector.splat`), or
930  // * vector (`vector.broadcast`).
931  // Otherwise the re-ordering wouldn't be safe.
932  if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
933  auto bcast = val.getDefiningOp<vector::BroadcastOp>();
934  if (bcast)
935  return (bcast.getOperand().getType() == lhsBcastOrSplatType);
936  auto splat = val.getDefiningOp<vector::SplatOp>();
937  if (splat)
938  return (splat.getOperand().getType() == lhsBcastOrSplatType);
939  return false;
940  })) {
941  return failure();
942  }
943 
944  // Collect the source values before broadcasting
945  SmallVector<Value> srcValues;
946  srcValues.reserve(op->getNumOperands());
947  for (Value operand : op->getOperands()) {
948  srcValues.push_back(operand.getDefiningOp()->getOperand(0));
949  }
950 
951  // Create the "elementwise" Op
952  Operation *elementwiseOp =
953  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
954  lhsBcastOrSplatType, op->getAttrs());
955 
956  // Replace the original Op with the elementwise Op
957  auto vectorType = op->getResultTypes()[0];
958  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
959  op, vectorType, elementwiseOp->getResults());
960 
961  return success();
962  }
963 };
964 
965 // Helper that returns a vector comparison that constructs a mask:
966 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
967 //
968 // If `dim == 0` then the result will be a 0-D vector.
969 //
970 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
971 // much more compact, IR for this operation, but LLVM eventually
972 // generates more elaborate instructions for this intrinsic since it
973 // is very conservative on the boundary conditions.
974 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
975  bool force32BitVectorIndices, int64_t dim,
976  Value b, Value *off = nullptr) {
977  auto loc = op->getLoc();
978  // If we can assume all indices fit in 32-bit, we perform the vector
979  // comparison in 32-bit to get a higher degree of SIMD parallelism.
980  // Otherwise we perform the vector comparison using 64-bit indices.
981  Type idxType =
982  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
983  DenseIntElementsAttr indicesAttr;
984  if (dim == 0 && force32BitVectorIndices) {
985  indicesAttr = DenseIntElementsAttr::get(
986  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
987  } else if (dim == 0) {
988  indicesAttr = DenseIntElementsAttr::get(
989  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
990  } else if (force32BitVectorIndices) {
991  indicesAttr = rewriter.getI32VectorAttr(
992  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
993  } else {
994  indicesAttr = rewriter.getI64VectorAttr(
995  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
996  }
997  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
998  // Add in an offset if requested.
999  if (off) {
1000  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1001  Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
1002  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
1003  }
1004  // Construct the vector comparison.
1005  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1006  Value bounds =
1007  rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
1008  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1009  bounds);
1010 }
1011 
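/// Materializes an explicit mask for a rank-1 vector.transfer_read or
/// vector.transfer_write that has an out-of-bounds dimension, so that the
/// access itself can be marked in-bounds. Illustrative sketch only (SSA names,
/// shapes, and the `%c0` constant are assumptions, not from this file):
/// ```
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// roughly becomes
/// ```
///   %d = memref.dim %A, %c0 : memref<?xf32>
///   %n = arith.subi %d, %i : index
///   %m = vector.create_mask %n : vector<8xi1>
///   %v = vector.transfer_read %A[%i], %pad, %m {in_bounds = [true]}
///       : memref<?xf32>, vector<8xf32>
/// ```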
1012 template <typename ConcreteOp>
1013 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1014 public:
1015  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1016  PatternBenefit benefit = 1)
1017  : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1018  force32BitVectorIndices(enableIndexOpt) {}
1019 
1020  LogicalResult matchAndRewrite(ConcreteOp xferOp,
1021  PatternRewriter &rewriter) const override {
1022  if (!xferOp.hasOutOfBoundsDim())
1023  return failure();
1024 
1025  if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1026  return failure();
1027 
1028  Location loc = xferOp->getLoc();
1029  VectorType vtp = xferOp.getVectorType();
1030 
1031  // Create the in-bounds mask with all elements between [0 .. dim - offset)
1032  // set and [dim - offset .. vector_length) unset.
1033  //
1034  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1035  // dimensions here.
1036  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1037  Value off = xferOp.getIndices()[lastIndex];
1038  Value dim =
1039  vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1040  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
1041  Value mask = rewriter.create<vector::CreateMaskOp>(
1042  loc,
1043  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1044  vtp.getScalableDims()),
1045  b);
1046  if (xferOp.getMask()) {
1047  // Intersect the in-bounds with the mask specified as an op parameter.
1048  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
1049  }
1050 
1051  rewriter.updateRootInPlace(xferOp, [&]() {
1052  xferOp.getMaskMutable().assign(mask);
1053  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1054  });
1055 
1056  return success();
1057  }
1058 
1059 private:
1060  const bool force32BitVectorIndices;
1061 };
1062 
1063 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
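/// Illustrative sketch only (names and the i32 width choice are assumptions):
/// a 1-D `%m = vector.create_mask %b : vector<4xi1>` becomes roughly
/// ```
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %bnd = vector.splat %b_i32 : vector<4xi32>
///   %m = arith.cmpi slt, %idx, %bnd : vector<4xi32>
/// ```
/// where `%b_i32` is `%b` cast to the index width selected by
/// `force32BitVectorIndices`.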
1064 class VectorCreateMaskOpConversion
1065  : public OpRewritePattern<vector::CreateMaskOp> {
1066 public:
1067  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1068  bool enableIndexOpt,
1069  PatternBenefit benefit = 1)
1070  : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1071  force32BitVectorIndices(enableIndexOpt) {}
1072 
1073  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1074  PatternRewriter &rewriter) const override {
1075  auto dstType = op.getType();
1076  if (cast<VectorType>(dstType).isScalable())
1077  return failure();
1078  int64_t rank = dstType.getRank();
1079  if (rank > 1)
1080  return failure();
1081  rewriter.replaceOp(
1082  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1083  rank == 0 ? 0 : dstType.getDimSize(0),
1084  op.getOperand(0)));
1085  return success();
1086  }
1087 
1088 private:
1089  const bool force32BitVectorIndices;
1090 };
1091 
1092 /// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1093 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1094  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1095  // TODO: Support non-dense constant.
1096  if (!denseAttr)
1097  return false;
1098 
1099  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1100  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1101 }
1102 
1103 /// Folds a select operation between an all-true and all-false vector. For now,
1104 /// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1105 ///
1106 /// %true = arith.constant dense<true> : vector<1xi1>
1107 /// %false = arith.constant dense<false> : vector<1xi1>
1108 /// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1109 /// =>
1110 /// %result = vector.broadcast %cond : i1 to vector<1xi1>
1111 ///
1112 /// InstCombine seems to handle vectors with multiple elements but not the
1113 /// single element ones.
1114 struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1115  using OpRewritePattern::OpRewritePattern;
1116 
1117  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1118  PatternRewriter &rewriter) const override {
1119  auto vecType = dyn_cast<VectorType>(selectOp.getType());
1120  if (!vecType || !vecType.getElementType().isInteger(1))
1121  return failure();
1122 
1123  // Only scalar conditions can be folded.
1124  Value cond = selectOp.getCondition();
1125  if (isa<VectorType>(cond.getType()))
1126  return failure();
1127 
1128  // TODO: Support n-D and scalable vectors.
1129  if (vecType.getRank() != 1 || vecType.isScalable())
1130  return failure();
1131 
1132  // TODO: Support vectors with multiple elements.
1133  if (vecType.getShape()[0] != 1)
1134  return failure();
1135 
1136  auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1137  if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1138  return failure();
1139 
1140  auto falseConst =
1141  selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1142  if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1143  return failure();
1144 
1145  // Replace select with its condition broadcasted to single element vector.
1146  auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1147  auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1148  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1149  return success();
1150  }
1151 };
1152 
1153 // Drop innermost contiguous unit dimensions from the transfer_read operand.
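// Illustrative sketch only (names, shapes, and constants are assumptions):
//   %v = vector.transfer_read %A[%c0, %c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<8x4x1xf32>, vector<4x1xf32>
// is rewritten along the lines of
//   %sv = memref.subview %A[0, 0, 0] [8, 4, 1] [1, 1, 1]
//       : memref<8x4x1xf32> to memref<8x4xf32>
//   %r = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true]}
//       : memref<8x4xf32>, vector<4xf32>
//   %v = vector.shape_cast %r : vector<4xf32> to vector<4x1xf32>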
1154 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
1155  using OpRewritePattern::OpRewritePattern;
1156 
1157  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1158  PatternRewriter &rewriter) const override {
1159  // TODO: support 0-d corner case.
1160  if (readOp.getTransferRank() == 0)
1161  return failure();
1162 
1163  // TODO: support mask.
1164  if (readOp.getMask())
1165  return failure();
1166 
1167  auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1168  if (!srcType || !srcType.hasStaticShape())
1169  return failure();
1170 
1171  if (!readOp.getPermutationMap().isMinorIdentity())
1172  return failure();
1173 
1174  auto targetType = readOp.getVectorType();
1175  if (targetType.getRank() <= 1)
1176  return failure();
1177 
1178  SmallVector<int64_t> srcStrides;
1179  int64_t srcOffset;
1180  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1181  return failure();
1182 
1183  // According to vector.transfer_read semantics, the result can be a slice.
1184  // It pads the indices with `1` starting from the beginning. Thus, we have to
1185  // offset the check index with `rankDiff` in `srcStrides` and source dim
1186  // sizes.
1187  size_t dimsToDrop = 0;
1188  int rankDiff = srcType.getRank() - targetType.getRank();
1189  for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
1190  // Check that the inner dim size is 1 for both memref/tensor type and
1191  // vector slice. It can be folded only if they are 1 and the stride is 1.
1192  int dim = targetType.getRank() - i - 1;
1193  if (srcStrides[dim + rankDiff] == 1 &&
1194  srcType.getDimSize(dim + rankDiff) == 1 &&
1195  targetType.getDimSize(dim) == 1) {
1196  dimsToDrop++;
1197  } else {
1198  break;
1199  }
1200  }
1201  if (dimsToDrop == 0)
1202  return failure();
1203 
1204  auto resultTargetVecType =
1205  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1206  targetType.getElementType());
1207 
1208  MemRefType resultMemrefType;
1209  MemRefLayoutAttrInterface layout = srcType.getLayout();
1210  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
1211  resultMemrefType = MemRefType::get(
1212  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
1213  nullptr, srcType.getMemorySpace());
1214  } else {
1215  MemRefLayoutAttrInterface updatedLayout;
1216  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
1217  auto strides =
1218  llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
1219  updatedLayout = StridedLayoutAttr::get(strided.getContext(),
1220  strided.getOffset(), strides);
1221  } else {
1222  AffineMap map = srcType.getLayout().getAffineMap();
1223  int numSymbols = map.getNumSymbols();
1224  for (size_t i = 0; i < dimsToDrop; ++i) {
1225  int dim = srcType.getRank() - i - 1;
1226  map = map.replace(rewriter.getAffineDimExpr(dim),
1227  rewriter.getAffineConstantExpr(0),
1228  map.getNumDims() - 1, numSymbols);
1229  }
1230  }
1231  resultMemrefType = MemRefType::get(
1232  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
1233  updatedLayout, srcType.getMemorySpace());
1234  }
1235 
1236  auto loc = readOp.getLoc();
1237  SmallVector<int64_t> offsets(srcType.getRank(), 0);
1238  SmallVector<int64_t> strides(srcType.getRank(), 1);
1239 
1240  ArrayAttr inBoundsAttr =
1241  readOp.getInBounds()
1242  ? rewriter.getArrayAttr(
1243  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1244  : ArrayAttr();
1245  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1246  loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
1247  strides);
1248  auto permMap = getTransferMinorIdentityMap(
1249  cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1250  Value result = rewriter.create<vector::TransferReadOp>(
1251  loc, resultTargetVecType, rankedReducedView,
1252  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1253  readOp.getPadding(),
1254  // TODO: support mask.
1255  /*mask=*/Value(), inBoundsAttr);
1256  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1257  result);
1258  return success();
1259  }
1260 };
1261 
1262 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1263 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
1264 /// with the RHS transposed) lowering.
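/// For example (illustrative; dims written as (m, n, k) rather than d0..d2):
/// a plain row-major matmul with indexing maps
///   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)
/// gets its RHS transposed so that the maps become the canonical
///   (m, n, k) -> (m, k), (m, n, k) -> (n, k), (m, n, k) -> (m, n)
/// form ("TNT": A row-major, B col-major, C row-major) expected by MMT
/// lowerings.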
1265 struct CanonicalizeContractMatmulToMMT final
1266  : OpRewritePattern<vector::ContractionOp> {
1267  using OpRewritePattern::OpRewritePattern;
1268 
1269  using FilterConstraintType =
1270  std::function<LogicalResult(vector::ContractionOp op)>;
1271 
1272  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1273  FilterConstraintType constraint)
1274  : OpRewritePattern<vector::ContractionOp>(context, benefit),
1275  filter(std::move(constraint)) {}
1276 
1277  LogicalResult matchAndRewrite(vector::ContractionOp op,
1278  PatternRewriter &rewriter) const override {
1279  if (failed(filter(op)))
1280  return failure();
1281 
1282  Location loc = op.getLoc();
1283  Value lhs = op.getLhs();
1284  Value rhs = op.getRhs();
1285  Value res = op.getAcc();
1286 
1287  // Set up the parallel/reduction structure in right form.
1288  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1289  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1290  AffineExpr m;
1291  AffineExpr n;
1292  AffineExpr k;
1293  bindDims(rewriter.getContext(), m, n, k);
1294  static constexpr std::array<int64_t, 2> perm = {1, 0};
1295  auto iteratorTypes = op.getIteratorTypes().getValue();
1296  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1297  if (iteratorTypes.size() != 3 ||
1298  !vector::isParallelIterator(iteratorTypes[0]) ||
1299  !vector::isParallelIterator(iteratorTypes[1]) ||
1300  !vector::isReductionIterator(iteratorTypes[2]))
1301  return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1302 
1303  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1304  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1305  if (maps == canonicalForm)
1306  return rewriter.notifyMatchFailure(op, "already in the canonical form");
1307 
1308  // Create a vector transpose making sure to emit zero/sign-extend at the
1309  // end.
1310  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1311  if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1312  Value trans =
1313  rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1314  VectorType newType =
1315  cast<VectorType>(trans.getType())
1316  .clone(cast<VectorType>(mat.getType()).getElementType());
1317  return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1318  }
1319  if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1320  Value trans =
1321  rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1322  VectorType newType =
1323  VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1324  cast<VectorType>(mat.getType()).getElementType());
1325  return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1326  }
1327  return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1328  };
1329 
1330  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1331  rhs = createTranspose(rhs);
1332  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1333  lhs = createTranspose(lhs);
1334  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1335  rhs = createTranspose(rhs);
1336  lhs = createTranspose(lhs);
1337  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1338  std::swap(rhs, lhs);
1339  rhs = createTranspose(rhs);
1340  lhs = createTranspose(lhs);
1341  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1342  std::swap(rhs, lhs);
1343  rhs = createTranspose(rhs);
1344  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1345  std::swap(lhs, rhs);
1346  lhs = createTranspose(lhs);
1347  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1348  std::swap(lhs, rhs);
1349  } else {
1350  return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1351  }
1352  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1353  op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1354  op.getIteratorTypes());
1355  return success();
1356  };
1357 
1358 private:
1359  FilterConstraintType filter;
1360 };
1361 
1362 /// Pattern to fold arithmetic extensions on floating point data types into
1363 /// vector contraction operations. linalg.matmul introduces arithmetic
1364 /// extensions on its operands. Please see the MLIR snippets below for more details.
1365 /// ```mlir
1366 /// "linalg.matmul"(%lhs, %rhs, %acc) ({
1367 /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1368 /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1369 /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1370 /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1371 /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1372 /// "linalg.yield"(%acc) : (f32) -> ()
1373 /// })
1374 /// ```
1375 /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1376 /// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1377 /// This pattern folds the arithmetic extensions into the vector contraction and
1378 /// enables the usage of native mixed precision Tensor Core instructions.
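/// At the vector level the fold is, as a sketch (shapes are illustrative):
/// ```
///   %lhs32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
///   %rhs32 = arith.extf %rhs : vector<16x8xf16> to vector<16x8xf32>
///   %res = vector.contract {...} %lhs32, %rhs32, %acc
/// ```
/// becomes
/// ```
///   %res = vector.contract {...} %lhs, %rhs, %acc
/// ```
/// i.e. the f16 operands feed the contraction directly while the accumulator
/// and result stay in f32.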
1379 struct FoldArithExtIntoContractionOp
1380  : public OpRewritePattern<vector::ContractionOp> {
1381  using OpRewritePattern::OpRewritePattern;
1382 
1383  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1384  PatternRewriter &rewriter) const override {
1385 
1386  auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
1387  auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
1388 
1389  if (!lhsDefOp || !rhsDefOp) {
1390  return rewriter.notifyMatchFailure(contractOp,
1391  "no defining op on contract operands");
1392  }
1393 
1394  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1395  contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1396  contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1397  contractOp.getIteratorTypesAttr());
1398 
1399  return success();
1400  }
1401 };
1402 
1403 /// Pattern to fold a chained reduction into a series of vector additions and a
1404 /// final reduction. This form should require fewer subgroup operations.
1405 ///
1406 /// ```mlir
1407 /// %a = vector.reduction <add> %x, %acc
1408 /// %b = vector.reduction <add> %y, %a
1409 /// ==>
1410 /// %a = arith.addf %x, %y
1411 /// %b = vector.reduction <add> %a, %acc
1412 /// ```
1413 struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1414  using OpRewritePattern::OpRewritePattern;
1415 
1416  LogicalResult matchAndRewrite(vector::ReductionOp op,
1417  PatternRewriter &rewriter) const override {
1418  // TODO: Handle other combining kinds.
1419  if (op.getKind() != vector::CombiningKind::ADD)
1420  return failure();
1421 
1422  // Accumulator is optional.
1423  Value acc = op.getAcc();
1424  if (!acc)
1425  return failure();
1426 
1427  if (!acc.getType().isIntOrFloat())
1428  return failure();
1429 
1430  auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1431  if (!parentReduction)
1432  return failure();
1433 
1434  Location loc = op.getLoc();
1435  Value vAdd;
1436  if (isa<IntegerType>(acc.getType())) {
1437  vAdd = rewriter.createOrFold<arith::AddIOp>(
1438  loc, parentReduction.getVector(), op.getVector());
1439  } else {
1440  vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1441  op.getVector());
1442  }
1443  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1444  parentReduction.getAcc());
1445  return success();
1446  }
1447 };
1448 
1449 /// Pattern to eliminate redundant zero-constants added to reduction operands.
1450 /// It's enough for there to be one initial zero value, so we can eliminate the
1451 /// extra ones that feed into `vector.reduction <add>`. These get created by the
1452 /// `ChainedReduction` pattern.
1453 ///
1454 /// ```mlir
1455 /// %a = arith.addf %x, %zero
1456 /// %b = arith.addf %a, %y
1457 /// %c = vector.reduction <add> %b, %acc
1458 /// ==>
1459 /// %b = arith.addf %a, %y
1460 /// %c = vector.reduction <add> %b, %acc
1461 /// ```
1462 struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
1463  using OpRewritePattern::OpRewritePattern;
1464 
1465  LogicalResult matchAndRewrite(vector::ReductionOp op,
1466  PatternRewriter &rewriter) const override {
1467  // TODO: Handle other reduction kinds and their identity values.
1468  if (op.getKind() != vector::CombiningKind::ADD)
1469  return failure();
1470 
1471  Type elemType = op.getSourceVectorType().getElementType();
1472  // The integer case should be handled by `arith.addi` folders, only check
1473  // for floats here.
1474  if (!isa<FloatType>(elemType))
1475  return failure();
1476 
1477  auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1478  if (!vAdd)
1479  return failure();
1480  auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
1481  if (!addLhs)
1482  return failure();
1483 
1484  if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
1485  return failure();
1486 
1487  auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
1488  vAdd.getRhs());
1489  rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
1490  op.getAcc());
1491  return success();
1492  }
1493 };
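// For example (a sketch with illustrative types; %zero stands for the splat
// zero constant matched by m_AnyZeroFloat):
//
// ```mlir
//   %zero = arith.constant dense<0.0> : vector<8xf32>
//   %a = arith.addf %x, %zero : vector<8xf32>
//   %b = arith.addf %a, %y : vector<8xf32>
//   %c = vector.reduction <add>, %b, %acc : vector<8xf32> into f32
// ==>
//   %b = arith.addf %x, %y : vector<8xf32>
//   %c = vector.reduction <add>, %b, %acc : vector<8xf32> into f32
// ```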
1494 
1495 } // namespace
1496 
1497 void mlir::vector::populateFoldArithExtensionPatterns(
1498  RewritePatternSet &patterns) {
1499  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
1500 }
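// A minimal usage sketch, not part of this file: inside an OperationPass, the
// populated set would typically be handed to the greedy driver declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h.
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateFoldArithExtensionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();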
1501 
1502 void mlir::vector::populateVectorMaskMaterializationPatterns(
1503  RewritePatternSet &patterns, bool force32BitVectorIndices,
1504  PatternBenefit benefit) {
1505  patterns.add<VectorCreateMaskOpConversion,
1506  MaterializeTransferMask<vector::TransferReadOp>,
1507  MaterializeTransferMask<vector::TransferWriteOp>>(
1508  patterns.getContext(), force32BitVectorIndices, benefit);
1509  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
1510 }
1511 
1512 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1513  PatternBenefit benefit) {
1514  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
1515 }
1516 
1517 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
1518  RewritePatternSet &patterns, PatternBenefit benefit) {
1519  patterns.add<BubbleDownVectorBitCastForExtract,
1520  BubbleDownBitCastForStridedSliceExtract,
1521  BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
1522  benefit);
1523 }
1524 
1525 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
1526  RewritePatternSet &patterns,
1527  std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
1528  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
1529  std::move(controlFn), benefit);
1530 }
1531 
1532 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
1533  RewritePatternSet &patterns,
1534  std::function<LogicalResult(vector::ContractionOp)> constraint,
1535  PatternBenefit benefit) {
1536  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
1537  std::move(constraint));
1538 }
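// Illustrative only: callers can restrict which contractions are rewritten by
// passing a constraint; the f16 filter below is a made-up example, not a
// constraint that MLIR itself installs.
//
//   populateVectorContractCanonicalizeMatmulToMMT(
//       patterns, [](vector::ContractionOp op) {
//         return success(op.getLhsType().getElementType().isF16());
//       });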
1539 
1540 void mlir::vector::populateVectorReductionToContractPatterns(
1541  RewritePatternSet &patterns, PatternBenefit benefit) {
1542  patterns.add<MultiReduceToContract, CombineContractBroadcast,
1543  CombineContractABTranspose, CombineContractResultTranspose,
1544  ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
1545  patterns.getContext(), benefit);
1546 }
1547 
1548 void mlir::vector::
1549  populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
1550  RewritePatternSet &patterns, PatternBenefit benefit) {
1551  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
1552 }
1553 
1554 void mlir::vector::populateSinkVectorBroadcastPatterns(
1555  RewritePatternSet &patterns, PatternBenefit benefit) {
1556  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
1557  patterns.getContext(), benefit);
1558 }
1559 
1560 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
1561  RewritePatternSet &patterns, PatternBenefit benefit) {
1562  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
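  // ReduceRedundantZero is registered with a slightly higher benefit so that,
  // when both patterns match the same vector.reduction, the greedy driver
  // tries the zero-elimination cleanup first.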
1563  patterns.add<ReduceRedundantZero>(patterns.getContext(),
1564  PatternBenefit(benefit.getBenefit() + 1));
1565 }
1566 
1567 //===----------------------------------------------------------------------===//
1568 // TableGen'd enum attribute definitions
1569 //===----------------------------------------------------------------------===//
1570 
1571 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"