Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinTypeInterfaces.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/OpDefinition.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/Value.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/MathExtras.h"
42 #include "llvm/Support/raw_ostream.h"
43 #include <optional>
44 
45 using namespace mlir;
46 using namespace mlir::linalg;
47 
48 #define DEBUG_TYPE "linalg-vectorization"
49 
50 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
51 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
52 
53 /// Try to vectorize `convOp` as a convolution.
54 static FailureOr<Operation *>
55 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
56  ArrayRef<int64_t> inputVecSizes = {},
57  ArrayRef<bool> inputVecScalableFlags = {},
58  bool flatten1DDepthwiseConv = false);
59 
60 /// Vectorize tensor::InsertSliceOp with:
61 /// * vector::TransferReadOp + vector::TransferWriteOp
62 /// The vector sizes are either:
63 /// * user-provided in `inputVectorSizes`, or
64 /// * inferred from the static dims in the input and output tensors.
65 /// Bails out if:
66 /// * vector sizes are not user-provided, and
67 /// * at least one dim is dynamic (in both the input and output tensors).
68 ///
69 /// Before:
70 /// !t_in_type = tensor<1x2x3xf32>
71 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
72 /// !v_type = vector<1x2x3xf32>
73 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
74 /// into !t_out_type
75 /// After:
76 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
77 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
78 static LogicalResult
79 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
80  ArrayRef<int64_t> inputVectorSizes,
81  SmallVectorImpl<Value> &newResults);
82 
83 /// Returns the effective Pad value for the input op, provided it's a scalar.
84 ///
85 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
86 /// this Op performs padding, retrieve the padding value provided that it's
87 /// a scalar and static/fixed for all the padded values. Returns an empty value
88 /// otherwise.
89 static Value getStaticPadVal(Operation *op);
90 
91 /// Return the unique instance of OpType in `block` if it is indeed unique.
92 /// Return null if none or more than 1 instances exist.
93 template <typename OpType>
94 static OpType getSingleOpOfType(Block &block) {
95  OpType res;
96  block.walk([&](OpType op) {
97  if (res) {
98  res = nullptr;
99  return WalkResult::interrupt();
100  }
101  res = op;
102  return WalkResult::advance();
103  });
104  return res;
105 }
106 
107 /// Helper function to extract the input slices after filter is unrolled along
108 /// kw.
109 static SmallVector<Value>
110 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
111  int64_t nSize, int64_t wSize, int64_t cSize,
112  int64_t kwSize, int strideW, int dilationW,
113  int64_t wSizeStep, bool isSingleChanneled) {
114  SmallVector<Value> result;
115  if (isSingleChanneled) {
116  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
117  // convolution.
118  SmallVector<int64_t> sizes = {wSizeStep};
119  SmallVector<int64_t> strides = {1};
120  for (int64_t kw = 0; kw < kwSize; ++kw) {
121  for (int64_t w = 0; w < wSize; w += wSizeStep) {
122  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
123  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
124  }
125  }
126  } else {
127  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128  // for channeled convolution.
129  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130  SmallVector<int64_t> strides = {1, 1, 1};
131  for (int64_t kw = 0; kw < kwSize; ++kw) {
132  for (int64_t w = 0; w < wSize; w += wSizeStep) {
133  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
134  loc, input,
135  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136  sizes, strides));
137  }
138  }
139  }
140  return result;
141 }
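// Example (schematic, assuming wSize = 4, wSizeStep = 2, kwSize = 2 in the
// single-channeled case): the loops above emit four
// vector.extract_strided_slice ops at offsets [0] and [2] for kw = 0, and [1]
// and [3] for kw = 1, each slice of length wSizeStep = 2:
//
//   %s0 = vector.extract_strided_slice %input {offsets = [0], sizes = [2], strides = [1]}
//   %s1 = vector.extract_strided_slice %input {offsets = [2], sizes = [2], strides = [1]}
//   %s2 = vector.extract_strided_slice %input {offsets = [1], sizes = [2], strides = [1]}
//   %s3 = vector.extract_strided_slice %input {offsets = [3], sizes = [2], strides = [1]}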
142 
143 /// Helper function to extract the filter slices after filter is unrolled along
144 /// kw.
145 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
146  Location loc, Value filter,
147  int64_t kwSize) {
148  SmallVector<Value> result;
149  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150  // non-channeled convolution] @ [kw].
151  for (int64_t kw = 0; kw < kwSize; ++kw) {
152  result.push_back(rewriter.create<vector::ExtractOp>(
153  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154  }
155  return result;
156 }
157 
158 /// Helper function to extract the result slices after filter is unrolled along
159 /// kw.
160 static SmallVector<Value>
161 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
162  int64_t nSize, int64_t wSize, int64_t fSize,
163  int64_t wSizeStep, bool isSingleChanneled) {
164  SmallVector<Value> result;
165  if (isSingleChanneled) {
166  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167  SmallVector<int64_t> sizes = {wSizeStep};
168  SmallVector<int64_t> strides = {1};
169  for (int64_t w = 0; w < wSize; w += wSizeStep) {
170  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
171  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
172  }
173  } else {
174  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
175  // convolution.
176  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
177  SmallVector<int64_t> strides = {1, 1, 1};
178  for (int64_t w = 0; w < wSize; w += wSizeStep) {
179  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
180  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
181  }
182  }
183  return result;
184 }
185 
186 /// Helper function to insert the computed result slices.
187 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
188  Value res, int64_t wSize, int64_t wSizeStep,
189  SmallVectorImpl<Value> &resVals,
190  bool isSingleChanneled) {
191 
192  if (isSingleChanneled) {
193  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
194  // This does not depend on kw.
195  SmallVector<int64_t> strides = {1};
196  for (int64_t w = 0; w < wSize; w += wSizeStep) {
197  res = rewriter.create<vector::InsertStridedSliceOp>(
198  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
199  }
200  } else {
201  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
202  // convolution. This does not depend on kw.
203  SmallVector<int64_t> strides = {1, 1, 1};
204  for (int64_t w = 0; w < wSize; w += wSizeStep) {
205  res = rewriter.create<vector::InsertStridedSliceOp>(
206  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
207  strides);
208  }
209  }
210  return res;
211 }
212 
213 /// Contains the vectorization state and related methods used across the
214 /// vectorization process of a given operation.
215 struct VectorizationState {
216  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
217 
218  /// Initializes the vectorization state, including the computation of the
219  /// canonical vector shape for vectorization.
220  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
221  ArrayRef<int64_t> inputVectorSizes,
222  ArrayRef<bool> inputScalableVecDims);
223 
224  /// Returns the canonical vector shape used to vectorize the iteration space.
225  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
226 
227  /// Returns the vector dimensions that are scalable in the canonical vector
228  /// shape.
229  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
230 
231  /// Returns a vector type of the provided `elementType` with the canonical
232  /// vector shape and the corresponding fixed/scalable dimensions bit. If
233  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
234  /// accordingly.
235  VectorType getCanonicalVecType(
236  Type elementType,
237  std::optional<AffineMap> dimPermutation = std::nullopt) const {
238  SmallVector<int64_t> vectorShape;
239  SmallVector<bool> scalableDims;
240  if (dimPermutation.has_value()) {
241  vectorShape =
242  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
243  scalableDims =
244  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
245  } else {
246  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
247  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
248  }
249 
250  return VectorType::get(vectorShape, elementType, scalableDims);
251  }
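  // Example (schematic): with canonicalVecShape = [1, 16, 16, 4], no scalable
  // dims, and dimPermutation = (d0, d1, d2, d3) -> (d2, d1, d0), the permuted
  // shape is [16, 16, 1], so for an f32 element type the returned type is
  // vector<16x16x1xf32>.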
252 
253  /// Masks an operation with the canonical vector mask if the operation needs
254  /// masking. Returns the masked operation or the original operation if masking
255  /// is not needed. If provided, the canonical mask for this operation is
256  /// permuted using `maybeIndexingMap`.
257  Operation *
258  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
259  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
260 
261 private:
262  /// Initializes the iteration space static sizes using the Linalg op
263  /// information. This may become more complicated in the future.
264  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
265  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
266  }
267 
268  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
269  /// all the static and dynamic dimensions of the iteration space to be
270  /// vectorized and store them in `iterSpaceValueSizes`.
271  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
272  LinalgOp linalgOp);
273 
274  /// Create or retrieve an existing mask value to mask `opToMask` in the
275  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
276  /// permuted using that permutation map. If a new mask is created, it will be
277  /// cached for future users.
278  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
279  LinalgOp linalgOp,
280  std::optional<AffineMap> maybeMaskingMap);
281 
282  /// Check whether this permutation map can be used for masking. At the
283  /// moment we only make sure that there are no broadcast dimensions, but this
284  /// might change if indexing maps evolve.
285  bool isValidMaskingMap(AffineMap maskingMap) {
286  return maskingMap.getBroadcastDims().size() == 0;
287  }
288 
289  /// Turn the input indexing map into a valid masking map.
290  ///
291  /// The input indexing map may contain "zero" results, e.g.:
292  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
293  /// Applying such maps to canonical vector shapes like this one:
294  /// (1, 16, 16, 4)
295  /// would yield an invalid vector shape like this:
296  /// (16, 16, 1, 0)
297  /// Instead, drop the broadcasting dims that make no sense for masking perm.
298  /// maps:
299  /// (d0, d1, d2, d3) -> (d2, d1, d0)
300  /// This way, the corresponding vector/mask type will be:
301  /// vector<16x16x1xty>
302  /// rather than this invalid Vector type:
303  /// vector<16x16x1x0xty>
304  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
305  return indexingMap.dropZeroResults();
306  }
307 
308  // Holds the compile-time static sizes of the iteration space to vectorize.
309  // Dynamic dimensions are represented using ShapedType::kDynamic.
310  SmallVector<int64_t> iterSpaceStaticSizes;
311 
312  /// Holds the value sizes of the iteration space to vectorize. Static
313  /// dimensions are represented by 'arith.constant' and dynamic
314  /// dimensions by 'tensor/memref.dim'.
315  SmallVector<Value> iterSpaceValueSizes;
316 
317  /// Holds the canonical vector shape used to vectorize the iteration space.
318  SmallVector<int64_t> canonicalVecShape;
319 
320  /// Holds the vector dimensions that are scalable in the canonical vector
321  /// shape.
322  SmallVector<bool> scalableVecDims;
323 
324  /// Holds the active masks for permutations of the canonical vector iteration
325  /// space.
326  DenseMap<AffineMap, Value> activeMaskCache;
327 
328  /// Global vectorization guard for the incoming rewriter. It's initialized
329  /// when the vectorization state is initialized.
330  OpBuilder::InsertionGuard rewriterGuard;
331 };
332 
333 LogicalResult
334 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
335  LinalgOp linalgOp) {
336  // TODO: Support 0-d vectors.
337  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
338  if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
339  // Create constant index op for static dimensions.
340  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
341  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
342  continue;
343  }
344 
345  // Find an operand defined on this dimension of the iteration space to
346  // extract the runtime dimension size.
347  Value operand;
348  unsigned operandDimPos;
349  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
350  operandDimPos)))
351  return failure();
352 
353  Value dynamicDim = linalgOp.hasPureTensorSemantics()
354  ? (Value)rewriter.create<tensor::DimOp>(
355  linalgOp.getLoc(), operand, operandDimPos)
356  : (Value)rewriter.create<memref::DimOp>(
357  linalgOp.getLoc(), operand, operandDimPos);
358  iterSpaceValueSizes.push_back(dynamicDim);
359  }
360 
361  return success();
362 }
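// Example (schematic, assuming a 2-D iteration space with a static size of 8
// in dim 0 and a dynamic size in dim 1 coming from a tensor operand %arg0):
// the loop above materializes roughly
//
//   %c8 = arith.constant 8 : index
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %arg0, %c1 : tensor<8x?xf32>
//
// and records [%c8, %d1] in iterSpaceValueSizes.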
363 
364 /// Initializes the vectorization state, including the computation of the
365 /// canonical vector shape for vectorization.
366 // TODO: Move this to the constructor when we can remove the failure cases.
367 LogicalResult
368 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
369  ArrayRef<int64_t> inputVectorSizes,
370  ArrayRef<bool> inputScalableVecDims) {
371  // Initialize the insertion point.
372  rewriter.setInsertionPoint(linalgOp);
373 
374  if (!inputVectorSizes.empty()) {
375  // Get the canonical vector shape from the input vector sizes provided. This
376  // path should be taken to vectorize code with dynamic shapes and when using
377  // vector sizes greater than the iteration space sizes.
378  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
379  scalableVecDims.append(inputScalableVecDims.begin(),
380  inputScalableVecDims.end());
381  } else {
382  // Compute the canonical vector shape from the operation shape. If there are
383  // dynamic shapes, the operation won't be vectorized. We assume all the
384  // vector dimensions are fixed.
385  canonicalVecShape = linalgOp.getStaticLoopRanges();
386  scalableVecDims.append(linalgOp.getNumLoops(), false);
387  }
388 
389  LDBG("Canonical vector shape: ");
390  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
391  LLVM_DEBUG(llvm::dbgs() << "\n");
392  LDBG("Scalable vector dims: ");
393  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
394  LLVM_DEBUG(llvm::dbgs() << "\n");
395 
396  if (ShapedType::isDynamicShape(canonicalVecShape))
397  return failure();
398 
399  // Initialize iteration space static sizes.
400  initIterSpaceStaticSizes(linalgOp);
401 
402  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
403  // all the static and dynamic dimensions of the iteration space, needed to
404  // compute a mask during vectorization.
405  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
406  return failure();
407 
408  return success();
409 }
410 
411 /// Create or retrieve an existing mask value to mask `opToMask` in the
412 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
413 /// using that permutation map. If a new mask is created, it will be cached for
414 /// future users.
415 Value VectorizationState::getOrCreateMaskFor(
416  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
417  std::optional<AffineMap> maybeMaskingMap) {
418 
419  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
420  "Ill-formed masking map.");
421 
422  // No mask is needed if the operation is not maskable.
423  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
424  if (!maskableOp)
425  return Value();
426 
427  assert(!maskableOp.isMasked() &&
428  "Masking an operation that is already masked");
429 
430  // If no masking map was provided, use an identity map with the loop dims.
431  assert((!maybeMaskingMap || *maybeMaskingMap) &&
432  "Unexpected null mask permutation map");
433  AffineMap maskingMap =
434  maybeMaskingMap ? *maybeMaskingMap
435  : AffineMap::getMultiDimIdentityMap(
436  linalgOp.getNumLoops(), rewriter.getContext());
437 
438  LDBG("Masking map: " << maskingMap << "\n");
439 
440  // Return the active mask for the masking map of this operation if it was
441  // already created.
442  auto activeMaskIt = activeMaskCache.find(maskingMap);
443  if (activeMaskIt != activeMaskCache.end()) {
444  Value mask = activeMaskIt->second;
445  LDBG("Reusing mask: " << mask << "\n");
446  return mask;
447  }
448 
449  // Compute permuted projection of the iteration space to be masked and the
450  // corresponding mask shape. If the resulting iteration space dimensions are
451  // static and identical to the mask shape, masking is not needed for this
452  // operation.
453  // TODO: Improve this check. Only projected permutation indexing maps are
454  // supported.
455  SmallVector<int64_t> permutedStaticSizes =
456  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
457  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
458  auto maskShape = maskType.getShape();
459 
460  LDBG("Mask shape: ");
461  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
462  LLVM_DEBUG(llvm::dbgs() << "\n");
463 
464  if (permutedStaticSizes == maskShape) {
465  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
466  activeMaskCache[maskingMap] = Value();
467  return Value();
468  }
469 
470  // Permute the iteration space value sizes to compute the mask upper bounds.
471  SmallVector<Value> upperBounds =
472  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
473  assert(!maskShape.empty() && !upperBounds.empty() &&
474  "Masked 0-d vectors are not supported yet");
475 
476  // Create the mask based on the dimension values.
477  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
478  maskType, upperBounds);
479  LDBG("Creating new mask: " << mask << "\n");
480  activeMaskCache[maskingMap] = mask;
481  return mask;
482 }
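// Example (schematic): for a masking map (d0, d1) -> (d0, d1), iteration space
// value sizes [%d0, %d1] and a canonical mask type vector<4x8xi1>, the code
// above produces roughly
//
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
//
// whenever the (permuted) static sizes do not already match the mask shape,
// and caches the result so later operations with the same masking map reuse it.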
483 
484 Operation *
485 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
486  LinalgOp linalgOp,
487  std::optional<AffineMap> maybeIndexingMap) {
488  LDBG("Trying to mask: " << *opToMask << "\n");
489 
490  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
491  if (maybeIndexingMap)
492  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
493 
494  // Create or retrieve mask for this operation.
495  Value mask =
496  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
497 
498  if (!mask) {
499  LDBG("No mask required\n");
500  return opToMask;
501  }
502 
503  // Wrap the operation with a new `vector.mask` and update D-U chain.
504  assert(opToMask && "Expected a valid operation to mask");
505  auto maskOp = cast<vector::MaskOp>(
506  mlir::vector::maskOperation(rewriter, opToMask, mask));
507  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
508 
509  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
510  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
511  maskOpTerminator);
512 
513  LDBG("Masked operation: " << *maskOp << "\n");
514  return maskOp;
515 }
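// Example (schematic): masking a vector.transfer_read with the mask computed
// above yields IR along the lines of
//
//   %0 = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>
//
// with all uses of the original read (outside the mask region) redirected to
// the vector.mask results.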
516 
517 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
518 /// projectedPermutation, compress the unused dimensions to serve as a
519 /// permutation_map for a vector transfer operation.
520 /// For example, given a linalg op such as:
521 ///
522 /// ```
523 /// %0 = linalg.generic {
524 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
525 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
526 /// }
527 /// ins(%0 : tensor<2x3x4xf32>)
528 /// outs(%1 : tensor<5x6xf32>)
529 /// ```
530 ///
531 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
532 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
533 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
534 static AffineMap reindexIndexingMap(AffineMap map) {
535  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
536  "expected projected permutation");
537  auto res = compressUnusedDims(map);
538  assert(res.getNumDims() ==
539  (res.getNumResults() - res.getNumOfZeroResults()) &&
540  "expected reindexed map with same number of dims and results");
541  return res;
542 }
543 
544 /// Helper enum to represent conv1d input traversal order.
545 enum class Conv1DOpOrder {
546  W, // Corresponds to non-channeled 1D convolution operation.
547  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
548  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
549 };
550 
551 /// Helper data structure to represent the result of vectorization for a single
552 /// operation. In certain specific cases, like terminators, we do not want to
553 /// propagate.
554 enum VectorizationHookStatus {
555  /// Op failed to vectorize.
556  Failure = 0,
557  /// Op vectorized and custom function took care of replacement logic
558  NoReplace,
559  /// Op vectorized into a new Op whose results will replace original Op's
560  /// results.
561  NewOp
562  // TODO: support values if Op vectorized to Many-Ops whose results we need to
563  // aggregate for replacement.
564 };
565 /// VectorizationHookResult contains the vectorized op returned from a
566 /// CustomVectorizationHook. This is an internal implementation detail of
567 /// linalg vectorization, not to be confused with VectorizationResult.
568 struct VectorizationHookResult {
569  /// Return status from vectorizing the current op.
570  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
571  /// New vectorized operation to replace the current op.
572  /// Replacement behavior is specified by `status`.
573  Operation *newOp;
574 };
575 
576 std::optional<vector::CombiningKind>
577 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
578  using ::mlir::vector::CombiningKind;
579 
580  if (!combinerOp)
581  return std::nullopt;
582  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
583  .Case<arith::AddIOp, arith::AddFOp>(
584  [&](auto op) { return CombiningKind::ADD; })
585  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
586  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
587  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
588  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
589  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
590  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
591  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
592  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
593  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
594  .Case<arith::MulIOp, arith::MulFOp>(
595  [&](auto op) { return CombiningKind::MUL; })
596  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
597  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
598  .Default([&](auto op) { return std::nullopt; });
599 }
600 
601 /// Check whether `outputOperand` is a reduction with a single combiner
602 /// operation. Return the combiner operation of the reduction. Return
603 /// nullptr otherwise. Multiple reduction operations would impose an
604 /// ordering between reduction dimensions and are currently unsupported in
605 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
606 /// max(min(X))
607 // TODO: use in LinalgOp verification, there is a circular dependency atm.
608 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
609  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
610  unsigned outputPos =
611  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
612  // Only single combiner operations are supported for now.
613  SmallVector<Operation *, 4> combinerOps;
614  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
615  combinerOps.size() != 1)
616  return nullptr;
617 
618  // Return the combiner operation.
619  return combinerOps[0];
620 }
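// Example (schematic): for a linalg.generic whose region accumulates into the
// output block argument with a single arith.addf, e.g.
//
//   ^bb0(%in: f32, %acc: f32):
//     %sum = arith.addf %in, %acc : f32
//     linalg.yield %sum : f32
//
// this helper returns the arith.addf, which getCombinerOpKind then maps to
// vector::CombiningKind::ADD.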
621 
622 /// Broadcast `value` to a vector of `shape` if possible. Return value
623 /// otherwise.
624 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
625  auto dstVecType = dyn_cast<VectorType>(dstType);
626  // If no shape to broadcast to, just return `value`.
627  if (dstVecType.getRank() == 0)
628  return value;
629  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
630  vector::BroadcastableToResult::Success)
631  return value;
632  Location loc = b.getInsertionPoint()->getLoc();
633  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
634 }
635 
636 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
637 /// assumes that `reductionOp` has two operands and one of them is the reduction
638 // initial value.
639 // Note: this is a true builder that notifies the OpBuilder listener.
640 // TODO: Consider moving as a static helper on the ReduceOp.
641 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
642  Value valueToReduce, Value acc,
643  ArrayRef<bool> dimsToMask) {
644  auto maybeKind = getCombinerOpKind(reduceOp);
645  assert(maybeKind && "Failed precondition: could not get reduction kind");
646  return b.create<vector::MultiDimReductionOp>(
647  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
648 }
649 
650 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
651  return llvm::to_vector(
652  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
653 }
654 
655 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
656 /// reduction iterator.
657 static bool hasReductionIterator(LinalgOp &op) {
658  return isa<linalg::ReduceOp>(op) ||
659  (isa<linalg::GenericOp>(op) &&
660  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
661 }
662 
663 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
664 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
665 /// currently being vectorized. If `dest` has null rank, build a memref.store.
666 /// Return the produced value or null if no value is produced.
667 // Note: this is a true builder that notifies the OpBuilder listener.
668 // TODO: Consider moving as a static helper on the ReduceOp.
669 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
670  OpOperand *outputOperand,
671  VectorizationState &state) {
672  Location loc = value.getLoc();
673  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
674  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
675 
676  // Compute the vector type of the value to store. This type should be an
677  // identity or projection of the canonical vector type without any permutation
678  // applied, given that any permutation in a transfer write happens as part of
679  // the write itself.
680  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
681  opOperandMap.getContext(), opOperandMap.getNumInputs(),
682  [&](AffineDimExpr dimExpr) -> bool {
683  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
684  });
685  auto vectorType = state.getCanonicalVecType(
686  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
687 
688  Operation *write;
689  if (vectorType.getRank() > 0) {
690  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
691  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
692  rewriter.create<arith::ConstantIndexOp>(loc, 0));
693  value = broadcastIfNeeded(rewriter, value, vectorType);
694  assert(value.getType() == vectorType && "Incorrect type");
695  write = rewriter.create<vector::TransferWriteOp>(
696  loc, value, outputOperand->get(), indices, writeMap);
697  } else {
698  // 0-d case is still special: do not invert the reindexing writeMap.
699  if (!isa<VectorType>(value.getType()))
700  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
701  assert(value.getType() == vectorType && "Incorrect type");
702  write = rewriter.create<vector::TransferWriteOp>(
703  loc, value, outputOperand->get(), ValueRange{});
704  }
705 
706  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
707 
708  // If masked, set in-bounds to true. Masking guarantees that the access will
709  // be in-bounds.
710  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
711  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
712  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
713  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
714  }
715 
716  LDBG("vectorized op: " << *write << "\n");
717  if (!write->getResults().empty())
718  return write->getResult(0);
719  return Value();
720 }
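// Example (schematic): writing a vector<4x8xf32> value into a matching init
// tensor produces roughly
//
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//       : vector<4x8xf32>, tensor<?x?xf32>
//
// wrapped in a vector.mask (with in_bounds set to true) whenever the
// vectorization state decides that masking is required.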
721 
722 // Custom vectorization precondition function type. This is intended to be used
723 // with CustomVectorizationHook. Returns success if the corresponding custom
724 // hook can vectorize the op.
725 using CustomVectorizationPrecondition =
726  std::function<LogicalResult(Operation *, bool)>;
727 
728 // Custom vectorization function type. Produce a vector form of Operation*
729 // assuming all its vectorized operands are already in the IRMapping.
730 // Return nullptr if the Operation cannot be vectorized.
731 using CustomVectorizationHook =
732  std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
733 
734 /// Helper function to vectorize the terminator of a `linalgOp`. New result
735 /// vector values are appended to `newResults`. Return
736 /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
737 /// that it should not try to map produced operations and instead return the
738 /// results using the `newResults` vector making them available to the
739 /// vectorization algorithm for RAUW. This function is meant to be used as a
740 /// CustomVectorizationHook.
741 static VectorizationHookResult
742 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
743  const IRMapping &bvm, VectorizationState &state,
744  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
745  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
746  if (!yieldOp)
747  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
748  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
749  // TODO: Scan for an opportunity for reuse.
750  // TODO: use a map.
751  Value vectorValue = bvm.lookup(output.value());
752  Value newResult =
753  buildVectorWrite(rewriter, vectorValue,
754  linalgOp.getDpsInitOperand(output.index()), state);
755  if (newResult)
756  newResults.push_back(newResult);
757  }
758 
759  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
760 }
761 
762 /// Helper function to vectorize the index operations of a `linalgOp`. Return
763 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
764 /// should map the produced operations. This function is meant to be used as a
765 /// CustomVectorizationHook.
766 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
767  VectorizationState &state,
768  Operation *op,
769  LinalgOp linalgOp) {
770  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
771  if (!indexOp)
772  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
773  auto loc = indexOp.getLoc();
774  // Compute the static loop sizes of the index op.
775  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
776  auto dim = indexOp.getDim();
777  // Compute a one-dimensional index vector for the index op dimension.
778  auto indexVectorType =
779  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
780  state.getScalableVecDims()[dim]);
781  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
782  // Return the one-dimensional index vector if it lives in the trailing
783  // dimension of the iteration space since the vectorization algorithm in this
784  // case can handle the broadcast.
785  if (dim == targetShape.size() - 1)
786  return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
787  // Otherwise permute the targetShape to move the index dimension last,
788  // broadcast the one-dimensional index vector to the permuted shape, and
789  // finally transpose the broadcasted index vector to undo the permutation.
790  auto permPattern =
791  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
792  std::swap(permPattern[dim], permPattern.back());
793  auto permMap =
794  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
795 
796  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
797  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
798  indexSteps);
799  SmallVector<int64_t> transposition =
800  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
801  std::swap(transposition.back(), transposition[dim]);
802  auto transposeOp =
803  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
804  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
805 }
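// Example (schematic): for linalg.index 0 in a 2-D iteration space vectorized
// as vector<4x8xindex>, dim 0 is not the trailing dimension, so the code above
// emits roughly
//
//   %step = vector.step : vector<4xindex>
//   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
//   %idx = vector.transpose %bcast, [1, 0]
//       : vector<8x4xindex> to vector<4x8xindex>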
806 
807 /// Helper function to check if the tensor.extract can be vectorized by the
808 /// custom hook vectorizeTensorExtract.
809 static LogicalResult
810 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
811  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
812  if (!extractOp)
813  return failure();
814 
815  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
816  return failure();
817 
818  // Check the index type, but only for non 0-d tensors (for which we do need
819  // access indices).
820  if (not extractOp.getIndices().empty()) {
821  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
822  return failure();
823  }
824 
825  if (!llvm::all_of(extractOp->getResultTypes(),
826  VectorType::isValidElementType)) {
827  return failure();
828  }
829 
830  return success();
831 }
832 
833 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
834 /// generated from `tensor.extract`. The offset is calculated as follows
835 /// (example using scalar values):
836 ///
837 /// offset = extractOp.indices[0]
838 /// for (i = 1; i < numIndices; i++)
839 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
840 ///
841 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
842 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
843 static Value calculateGatherOffset(RewriterBase &rewriter,
844  VectorizationState &state,
845  tensor::ExtractOp extractOp,
846  const IRMapping &bvm) {
847  // The vector of indices for GatherOp should be shaped as the output vector.
848  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
849  auto loc = extractOp.getLoc();
850 
851  Value offset = broadcastIfNeeded(
852  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
853 
854  const size_t numIndices = extractOp.getIndices().size();
855  for (size_t i = 1; i < numIndices; i++) {
856  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
857 
858  auto dimSize = broadcastIfNeeded(
859  rewriter,
860  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
861  indexVecType);
862 
863  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
864 
865  auto extractOpIndex = broadcastIfNeeded(
866  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
867 
868  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
869  }
870 
871  return offset;
872 }
873 
874 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
875 
876 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
877 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
878 /// represents a contiguous load operation.
879 ///
880 /// Note that when calling this hook, it is assumed that the output vector is
881 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
882 /// labelled as a gather load before entering this method.
883 ///
884 /// Following on from the above, it is assumed that:
885 /// * for statically shaped loops, when no masks are used, only one dim is !=
886 /// 1 (that's what the shape of the output vector is based on).
887 /// * for dynamically shaped loops, there might be more non-unit dims
888 /// as the output vector type is user-specified.
889 ///
890 /// TODO: Statically shaped loops + vector masking
891 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
892  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
893  assert(
894  (linalgOp.hasDynamicShape() ||
895  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
896  "For statically shaped Linalg Ops, only one "
897  "non-unit loop dim is expected");
898  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
899 
900  size_t idx = loopRanges.size() - 1;
901  for (; idx != 0; idx--)
902  if (loopRanges[idx] != 1)
903  break;
904 
905  return idx;
906 }
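// Example: for static loop ranges [1, 8, 1] the backwards scan above skips the
// trailing unit dim and returns 1, the index of the trailing non-unit dim.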
907 
908 /// Checks whether `val` can be used for calculating a loop invariant index.
909 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
910  VectorType resType) {
911 
912  assert(((llvm::count_if(resType.getShape(),
913  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
914  "n-D vectors are not yet supported");
915 
916  // Blocks outside _this_ linalg.generic are effectively loop invariant.
917  // However, analysing block arguments for _this_ linalg.generic Op is a bit
918  // tricky. Just bail out in the latter case.
919  // TODO: We could try analysing the corresponding affine map here.
920  auto *block = linalgOp.getBlock();
921  if (isa<BlockArgument>(val))
922  return !llvm::is_contained(block->getArguments(), val);
923 
924  Operation *defOp = val.getDefiningOp();
925  assert(defOp && "This is neither a block argument nor an operation result");
926 
927  // IndexOp is loop invariant as long as its result remains constant across
928  // iterations. Note that for dynamic shapes, the corresponding dim will also
929  // be conservatively treated as != 1.
930  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
931  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
932  }
933 
934  auto *ancestor = block->findAncestorOpInBlock(*defOp);
935 
936  // Values defined outside `linalgOp` are loop invariant.
937  if (!ancestor)
938  return true;
939 
940  // Values defined inside `linalgOp`, which are constant, are loop invariant.
941  if (isa<arith::ConstantOp>(ancestor))
942  return true;
943 
944  bool result = true;
945  for (auto op : ancestor->getOperands())
946  result &= isLoopInvariantIdx(linalgOp, op, resType);
947 
948  return result;
949 }
950 
951 /// Check whether `val` could be used for calculating the trailing index for a
952 /// contiguous load operation.
953 ///
954 /// There are currently 3 types of values that are allowed here:
955 /// 1. loop-invariant values,
956 /// 2. values that increment by 1 with every loop iteration,
957 /// 3. results of basic arithmetic operations (linear and continuous)
958 /// involving 1., 2. and 3.
959 /// This method returns True if indeed only such values are used in calculating
960 /// `val.`
961 ///
962 /// Additionally, the trailing index for a contiguous load operation should
963 /// increment by 1 with every loop iteration, i.e. be based on:
964 /// * `linalg.index <dim>` ,
965 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
966 /// `linalg.index <dim>` increments by 1 with every loop iteration).
967 /// `foundIndexOp` is updated to `true` when such Op is found.
968 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
969  bool &foundIndexOp, VectorType resType) {
970 
971  assert(((llvm::count_if(resType.getShape(),
972  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
973  "n-D vectors are not yet supported");
974 
975  // Blocks outside _this_ linalg.generic are effectively loop invariant.
976  // However, analysing block arguments for _this_ linalg.generic Op is a bit
977  // tricky. Just bail out in the latter case.
978  // TODO: We could try analysing the corresponding affine map here.
979  auto *block = linalgOp.getBlock();
980  if (isa<BlockArgument>(val))
981  return !llvm::is_contained(block->getArguments(), val);
982 
983  Operation *defOp = val.getDefiningOp();
984  assert(defOp && "This is neither a block argument nor an operation result");
985 
986  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
987  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
988 
989  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
990  return true;
991  }
992 
993  auto *ancestor = block->findAncestorOpInBlock(*defOp);
994 
995  if (!ancestor)
996  return false;
997 
998  // Conservatively reject Ops that could lead to indices with stride other
999  // than 1.
1000  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1001  return false;
1002 
1003  bool result = false;
1004  for (auto op : ancestor->getOperands())
1005  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1006 
1007  return result;
1008 }
1009 
1010 /// Infer the memory access pattern for the input ExtractOp
1011 ///
1012 /// Based on the ExtractOp result shape and the access indices, decides whether
1013 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1014 /// or a gather load. When analysing the ExtractOp indices (to identify
1015 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1016 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1017 /// Op).
1018 ///
1019 /// Note that it is always safe to use gather load operations for contiguous
1020 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1021 /// that `extractOp` is a gather load.
1022 static VectorMemoryAccessKind
1023 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1024  LinalgOp &linalgOp, VectorType resType) {
1025 
1026  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1027 
1028  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1029  if (inputShape.getShape().empty())
1030  return VectorMemoryAccessKind::ScalarBroadcast;
1031 
1032  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1033  // otherwise.
1034  bool isOutput1DVector =
1035  (llvm::count_if(resType.getShape(),
1036  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1037  // 1. Assume that it's a gather load when reading non-1D vector.
1038  if (!isOutput1DVector)
1039  return VectorMemoryAccessKind::Gather;
1040 
1041  bool leadingIdxsLoopInvariant = true;
1042 
1043  // 2. Analyze the leading indices of `extractOp`.
1044  // Look at the way each index is calculated and decide whether it is suitable
1045  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1046  // gather load.
1047  auto indices = extractOp.getIndices();
1048  auto leadIndices = indices.drop_back(1);
1049 
1050  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1051  if (inputShape.getShape()[i] == 1)
1052  continue;
1053 
1054  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1055  }
1056 
1057  if (!leadingIdxsLoopInvariant) {
1058  LDBG("Found gather load: " << extractOp);
1059  return VectorMemoryAccessKind::Gather;
1060  }
1061 
1062  // 3. Analyze the trailing index for `extractOp`.
1063  // At this point we know that the leading indices are loop invariant. This
1064  // means that this is potentially a scalar or a contiguous load. We can decide
1065  // based on the trailing idx.
1066  auto extractOpTrailingIdx = indices.back();
1067 
1068  // 3a. Scalar broadcast load
1069  // If the trailing index is loop invariant then this is a scalar load.
1070  if (leadingIdxsLoopInvariant &&
1071  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1072  LDBG("Found scalar broadcast load: " << extractOp);
1073 
1074  return VectorMemoryAccessKind::ScalarBroadcast;
1075  }
1076 
1077  // 3b. Contiguous loads
1078  // The trailing `extractOp` index should increment with every loop iteration.
1079  // This effectively means that it must be based on the trailing loop index.
1080  // This is what the following bool captures.
1081  bool foundIndexOp = false;
1082  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1083  foundIndexOp, resType);
1084  // TODO: Support generating contiguous loads for column vectors - that will
1085  // require adding a permutation map to transfer_read Ops.
1086  bool isRowVector = resType.getShape().back() != 1;
1087  isContiguousLoad &= (foundIndexOp && isRowVector);
1088 
1089  if (isContiguousLoad) {
1090  LDBG("Found contiguous load: " << extractOp);
1091  return VectorMemoryAccessKind::Contiguous;
1092  }
1093 
1094  // 4. Fallback case - gather load.
1095  LDBG("Found gather load: " << extractOp);
1096  return VectorMemoryAccessKind::Gather;
1097 }
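// Example (schematic): inside a linalg.generic whose trailing loop dim is
// non-unit, an extract such as
//
//   %v = tensor.extract %src[%c0, %i] : tensor<4x8xf32>
//
// where %c0 is loop invariant and %i is derived from linalg.index on the
// trailing dim, is classified as Contiguous; if %i were an arbitrary
// data-dependent value instead, the classification would fall back to Gather.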
1098 
1099 /// Helper function to vectorize the tensor.extract operations. Returns
1100 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1101 /// should map the produced operations. This function is meant to be used as a
1102 /// CustomVectorizationHook.
1103 static VectorizationHookResult
1104 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1105  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1106  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1107  if (!extractOp)
1108  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1109  auto loc = extractOp.getLoc();
1110 
1111  // Compute the static loop sizes of the extract op.
1112  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1113  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1114  loc,
1115  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1116  /*value=*/true));
1117  auto passThruConstantOp =
1118  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1119 
1120  // Base indices are currently set to 0. We will need to re-visit if more
1121  // generic scenarios are to be supported.
1122  SmallVector<Value> baseIndices(
1123  extractOp.getIndices().size(),
1124  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1125 
1126  VectorMemoryAccessKind memAccessKind =
1127  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1128 
1129  // 1. Handle gather access
1130  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1131  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1132 
1133  // Generate the gather load
1134  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1135  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1136  maskConstantOp, passThruConstantOp);
1137  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1138 
1139  LDBG("Vectorised as gather load: " << extractOp << "\n");
1140  return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1141  }
1142 
1143  // 2. Handle:
1144  // a. scalar loads + broadcast,
1145  // b. contiguous loads.
1146  // Both cases use vector.transfer_read.
1147 
1148  // Collect indices for `vector.transfer_read`. At this point, the indices will
1149  // either be scalars or would have been broadcast to vectors matching the
1150  // result type. For indices that are vectors, there are two options:
1151  // * for non-trailing indices, all elements are identical (contiguous
1152  // loads are identified by looking for non-trailing indices that are
1153  // invariant with respect to the corresponding linalg.generic), or
1154  // * for trailing indices, the index vector will contain values with stride
1155  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1156  // needed.
1157  // This means that
1158  // * for scalar indices - just re-use it,
1159  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1160  // (0th) element and use that.
1161  SmallVector<Value> transferReadIdxs;
1162  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1163  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1164  if (idx.getType().isIndex()) {
1165  transferReadIdxs.push_back(idx);
1166  continue;
1167  }
1168 
1169  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1170  loc,
1171  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1172  resultType.getScalableDims().back()),
1173  idx);
1174  transferReadIdxs.push_back(
1175  rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1176  }
1177 
1178  // `tensor.extract_element` is always in-bounds, hence the following holds.
1179  auto dstRank = resultType.getRank();
1180  auto srcRank = extractOp.getTensor().getType().getRank();
1181  SmallVector<bool> inBounds(dstRank, true);
1182 
1183  // 2a. Handle scalar broadcast access.
1184  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1185  MLIRContext *ctx = rewriter.getContext();
1186  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1187  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1188 
1189  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1190  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1191  /*padding=*/std::nullopt, permutationMap, inBounds);
1192 
1193  // Mask this broadcasting xfer_read here rather than relying on the generic
1194  // path (the generic path assumes identity masking map, which wouldn't be
1195  // valid here).
1196  SmallVector<int64_t> readMaskShape = {1};
1197  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1198  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
1199  loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1200  auto *maskedReadOp =
1201  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1202 
1203  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1204  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1205  maskedReadOp};
1206  }
1207 
1208  // 2b. Handle contiguous access.
1209  auto permutationMap = AffineMap::getMinorIdentityMap(
1210  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1211 
1212  int32_t rankDiff = dstRank - srcRank;
1213  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1214  // dims so that the ranks match. This is done by extending the map with 0s.
1215  // For example, for dstRank = 3, srcRank = 2, the following map created
1216  // above:
1217  // (d0, d1) --> (d0, d1)
1218  // is extended as:
1219  // (d0, d1) --> (0, d0, d1)
1220  while (rankDiff > 0) {
1221  permutationMap = permutationMap.insertResult(
1222  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1223  rankDiff--;
1224  }
1225 
1226  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1227  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1228  /*padding=*/std::nullopt, permutationMap, inBounds);
1229 
1230  LDBG("Vectorised as contiguous load: " << extractOp);
1231  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1232  transferReadOp};
1233 }
1234 
1235 /// Emit reduction operations if the shape of the value to reduce is different
1236 /// from the result shape.
1237 // Note: this is a true builder that notifies the OpBuilder listener.
1238 // TODO: Consider moving as a static helper on the ReduceOp.
1239 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1240  Value reduceValue, Value initialValue,
1241  const IRMapping &bvm) {
1242  Value reduceVec = bvm.lookup(reduceValue);
1243  Value outputVec = bvm.lookup(initialValue);
1244  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1245  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1246  // Reduce only if needed as the value may already have been reduced for
1247  // contraction vectorization.
1248  if (!reduceType ||
1249  (outputType && reduceType.getShape() == outputType.getShape()))
1250  return nullptr;
1251  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1252  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1253 }
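// Example (schematic): reducing a vector<4x8xf32> value into a vector<4xf32>
// accumulator over the trailing (reduction) dim builds roughly
//
//   %r = vector.multi_reduction <add>, %reduceVec, %outputVec [1]
//       : vector<4x8xf32> to vector<4xf32>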
1254 
1255 /// Generic vectorization for a single operation `op`, given already vectorized
1256 /// operands carried by `bvm`. Vectorization occurs as follows:
1257 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1258 /// result on success.
1259 /// 2. Clone any constant in the current scope without vectorization: each
1260 /// consumer of the constant will later determine the shape to which the
1261 /// constant needs to be broadcast to.
1262 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1263 /// of the `customVectorizationHooks` to cover such cases.
1264 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1265 /// operand of maximal rank. Other operands have smaller rank and are
1266 /// broadcast accordingly. It is assumed this broadcast is always legal,
1267 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1268 ///
1269 /// This function assumes all operands of `op` have been vectorized and are in
1270 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1271 /// a topologically-sorted list of ops.
1272 /// This function does not update `bvm` but returns a VectorizationHookStatus
1273 /// that instructs the caller what `bvm` update needs to occur.
1274 static VectorizationHookResult
1275 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1276  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1277  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1278  LDBG("vectorize op " << *op << "\n");
1279 
1280  // 1. Try to apply any CustomVectorizationHook.
1281  if (!customVectorizationHooks.empty()) {
1282  for (auto &customFunc : customVectorizationHooks) {
1283  VectorizationHookResult result = customFunc(op, bvm);
1284  if (result.status == VectorizationHookStatus::Failure)
1285  continue;
1286  return result;
1287  }
1288  }
1289 
1290  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1291  // Clone so that the constant is not confined to the linalgOp block.
1292  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1293  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1294  rewriter.clone(*op)};
1295 
1296  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1297  if (!OpTrait::hasElementwiseMappableTraits(op))
1298  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1299 
1300  // 4 . Check if the operation is a reduction.
1301  SmallVector<std::pair<Value, Value>> reductionOperands;
1302  for (Value operand : op->getOperands()) {
1303  auto blockArg = dyn_cast<BlockArgument>(operand);
1304  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1305  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1306  continue;
1307  SmallVector<Operation *> reductionOps;
1308  Value reduceValue = matchReduction(
1309  linalgOp.getRegionOutputArgs(),
1310  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1311  if (!reduceValue)
1312  continue;
1313  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1314  }
1315  if (!reductionOperands.empty()) {
1316  assert(reductionOperands.size() == 1);
1317  Operation *reduceOp =
1318  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1319  reductionOperands[0].second, bvm);
1320  if (reduceOp)
1321  return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1322  }
1323 
1324  // 5. Generic vectorization path for ElementwiseMappable ops.
1325  // a. Get the first max ranked shape.
1326  VectorType firstMaxRankedType;
1327  for (Value operand : op->getOperands()) {
1328  auto vecOperand = bvm.lookup(operand);
1329  assert(vecOperand && "Vector operand couldn't be found");
1330 
1331  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1332  if (vecType && (!firstMaxRankedType ||
1333  firstMaxRankedType.getRank() < vecType.getRank()))
1334  firstMaxRankedType = vecType;
1335  }
1336  // b. Broadcast each op if needed.
1337  SmallVector<Value> vecOperands;
1338  for (Value scalarOperand : op->getOperands()) {
1339  Value vecOperand = bvm.lookup(scalarOperand);
1340  assert(vecOperand && "Vector operand couldn't be found");
1341 
1342  if (firstMaxRankedType) {
1343  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1344  getElementTypeOrSelf(vecOperand.getType()),
1345  firstMaxRankedType.getScalableDims());
1346  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1347  } else {
1348  vecOperands.push_back(vecOperand);
1349  }
1350  }
1351  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1352  SmallVector<Type> resultTypes;
1353  for (Type resultType : op->getResultTypes()) {
1354  resultTypes.push_back(
1355  firstMaxRankedType
1356  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1357  firstMaxRankedType.getScalableDims())
1358  : resultType);
1359  }
1360  // d. Build and return the new op.
1361  return VectorizationHookResult{
1362  VectorizationHookStatus::NewOp,
1363  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1364  resultTypes, op->getAttrs())};
1365 }
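// Example (schematic): for an elementwise arith.mulf between a cloned scalar
// constant and a vector<4x8xf32> operand, step 5 broadcasts the scalar via
// vector.broadcast and re-creates the op with vector operands and results:
//
//   %m = arith.mulf %lhs_vec, %rhs_vec : vector<4x8xf32>
//
// relying on arith.mulf being ElementwiseMappable so the same op name applies
// to vectors.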
1366 
1367 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1368 /// vector form. Generic vectorization proceeds as follows:
1369 /// 1. Verify the `linalgOp` has one non-empty region.
1370 /// 2. Values defined above the region are mapped to themselves and will be
1371 /// broadcasted on a per-need basis by their consumers.
1372 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1373 /// load).
1374 /// TODO: Reuse opportunities for RAR dependencies.
1375 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1376 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1377 /// iteration indices.
1378 /// 5. Iteratively call vectorizeOneOp on the region operations.
1379 ///
1380 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1381 /// performed to the maximal common vector size implied by the `linalgOp`
1382 /// iteration space. This eager broadcasting is introduced in the
1383 /// permutation_map of the vector.transfer_read operations. The eager
1384 /// broadcasting makes it trivial to determine where broadcast, transposes and
1385 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1386 /// the absence of good canonicalizations, the amount of work increases.
1387 /// This is not deemed a problem as we expect canonicalizations and foldings to
1388 /// aggressively clean up the useless work.
1389 static LogicalResult
1390 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1391  LinalgOp linalgOp,
1392  SmallVectorImpl<Value> &newResults) {
1393  LDBG("Vectorizing operation as linalg generic\n");
1394  Block *block = linalgOp.getBlock();
1395 
1396  // 2. Values defined above the region can only be broadcast for now. Make them
1397  // map to themselves.
1398  IRMapping bvm;
1399  SetVector<Value> valuesSet;
1400  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1401  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1402 
1403  if (linalgOp.getNumDpsInits() == 0)
1404  return failure();
1405 
1406  // 3. Turn all BBArgs into vector.transfer_read / load.
1407  Location loc = linalgOp.getLoc();
1408  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1409  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1410  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1411  if (linalgOp.isScalar(opOperand)) {
1412  bvm.map(bbarg, opOperand->get());
1413  continue;
1414  }
1415 
1416  // 3.a. Convert the indexing map for this input/output to a transfer read
1417  // permutation map and masking map.
1418  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1419 
1420  AffineMap readMap;
1421  VectorType readType;
1422  Type elemType = getElementTypeOrSelf(opOperand->get());
1423  if (linalgOp.isDpsInput(opOperand)) {
1424  // 3.a.i. For input reads we use the canonical vector shape.
1425  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1426  readType = state.getCanonicalVecType(elemType);
1427  } else {
1428  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1429  // reductions), the vector shape is computed by mapping the canonical
1430  // vector shape to the output domain and back to the canonical domain.
1431  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1432  readType =
1433  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1434  }
1435 
1436  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1437 
1438  Operation *read = rewriter.create<vector::TransferReadOp>(
1439  loc, readType, opOperand->get(), indices,
1440  /*padding=*/std::nullopt, readMap);
1441  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1442  Value readValue = read->getResult(0);
1443 
1444  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1445  // will be in-bounds.
1446  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1447  SmallVector<bool> inBounds(readType.getRank(), true);
1448  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1449  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1450  }
1451 
1452  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1453  // TODO: remove this.
1454  if (readType.getRank() == 0)
1455  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1456  ArrayRef<int64_t>());
1457 
1458  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1459  << "\n");
1460  bvm.map(bbarg, readValue);
1461  bvm.map(opOperand->get(), readValue);
1462  }
1463 
1464  SmallVector<CustomVectorizationHook> hooks;
1465  // 4a. Register CustomVectorizationHook for yieldOp.
1466  CustomVectorizationHook vectorizeYield =
1467  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1468  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1469  };
1470  hooks.push_back(vectorizeYield);
1471 
1472  // 4b. Register CustomVectorizationHook for indexOp.
1473  CustomVectorizationHook vectorizeIndex =
1474  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1475  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1476  };
1477  hooks.push_back(vectorizeIndex);
1478 
1479  // 4c. Register CustomVectorizationHook for extractOp.
1480  CustomVectorizationHook vectorizeExtract =
1481  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1482  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1483  };
1484  hooks.push_back(vectorizeExtract);
1485 
1486  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1487  for (Operation &op : block->getOperations()) {
1488  VectorizationHookResult result =
1489  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1490  if (result.status == VectorizationHookStatus::Failure) {
1491  LDBG("failed to vectorize: " << op << "\n");
1492  return failure();
1493  }
1494  if (result.status == VectorizationHookStatus::NewOp) {
1495  Operation *maybeMaskedOp =
1496  state.maskOperation(rewriter, result.newOp, linalgOp);
1497  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1498  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1499  }
1500  }
1501 
1502  return success();
1503 }
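
// Illustrative sketch (editor's addition, not part of the source): for a
// simple elementwise linalg.generic such as
//
//   %0 = linalg.generic {iterator_types = ["parallel"], indexing_maps = [...]}
//          ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
//          outs(%init : tensor<8xf32>) {
//        ^bb0(%x: f32, %y: f32, %out: f32):
//          %s = arith.addf %x, %y : f32
//          linalg.yield %s : f32
//   } -> tensor<8xf32>
//
// the steps above roughly produce:
//
//   %va = vector.transfer_read %a[%c0], ... : tensor<8xf32>, vector<8xf32>
//   %vb = vector.transfer_read %b[%c0], ... : tensor<8xf32>, vector<8xf32>
//   %vs = arith.addf %va, %vb : vector<8xf32>
//   %r  = vector.transfer_write %vs, %init[%c0] ... : vector<8xf32>, tensor<8xf32>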
1504 
1505 /// Given a linalg::PackOp, return the `dest` shape before any packing
1506 /// permutations.
1507 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1508  ArrayRef<int64_t> destShape) {
1509  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1510 }
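
// Illustrative example (editor's addition, hypothetical values): for a pack
// whose inverse dest permutation is [0, 1, 4, 2, 3] and whose `destShape`
// (outer dims followed by inner tiles) is [32, 4, 1, 16, 2], applyPermutation
// yields the pre-permutation ("tiled") shape [32, 4, 2, 1, 16] - the shape
// used for the shape_cast in vectorizeAsTensorPackOp below.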
1511 
1512 /// Determines whether a mask for xfer_write is trivially "all true"
1513 ///
1514 /// Given all the inputs required to generate a mask (mask sizes and shapes),
1515 /// and an xfer_write operation (write indices and the destination tensor
1516 /// shape), determines whether the corresponding mask would be trivially
1517 /// foldable (i.e., trivially "all true").
1518 ///
1519 /// Use this method to avoid generating spurious masks and relying on
1520 /// vectorization post-processing to remove them.
1521 ///
1522 /// Pre-conditions for a mask to be trivially foldable:
1523 /// * All involved shapes (mask + destination tensor) are static.
1524 /// * All write indices are constant.
1525 /// * All mask sizes are constant (including `arith.constant`).
1526 ///
1527 /// If the pre-conditions are met, the method checks for each destination
1528 /// dimension `d`:
1529 /// (1) destDimSize[rankDiff + d] <= maskShape[d]
1530 /// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1531 ///
1532 /// rankDiff = rank(dest) - rank(mask).
1533 ///
1534 /// This method takes a conservative view: it may return false even if the mask
1535 /// is technically foldable.
1536 ///
1537 /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1538 /// of the dest tensor):
1539 /// %c0 = arith.constant 0 : index
1540 /// %mask = vector.create_mask 5, 1
1541 /// vector.mask %mask {
1542 /// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1543 /// {in_bounds = [true, true]}
1544 /// : vector<5x1xi32>, tensor<5x1xi32>
1545 /// }
1546 ///
1547 /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1548 /// mask is required to avoid out-of-bounds write):
1549 /// %c0 = arith.constant 0 : index
1550 /// %mask = vector.create_mask 5, 1
1551 /// vector.mask %mask {
1552 /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1553 /// {in_bounds = [true, true]}
1554 /// : vector<8x1xi32>, tensor<5x1xi32>
1555 /// }
1556 ///
1557 /// TODO: Re-use in createReadOrMaskedRead
1558 static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1559  SmallVector<Value> &writeIdxs,
1560  ArrayRef<int64_t> destShape,
1561  ArrayRef<int64_t> maskShape) {
1562  // Masking is unavoidable in the case of dynamic tensors.
1563  if (ShapedType::isDynamicShape(destShape))
1564  return false;
1565 
1566  // Collect all constant mask sizes.
1567  SmallVector<int64_t, 4> cstMaskSizes;
1568  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1569  if (auto intSize = getConstantIntValue(dimSize)) {
1570  cstMaskSizes.push_back(*intSize);
1571  }
1572  }
1573 
1574  // If any of the mask sizes is non-constant, bail out.
1575  if (cstMaskSizes.size() != maskShape.size())
1576  return false;
1577 
1578  // Collect all constant write indices.
1579  SmallVector<int64_t, 4> cstWriteIdxs;
1580  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1581  APSInt intVal;
1582  if (matchPattern(idx, m_ConstantInt(&intVal))) {
1583  cstWriteIdxs.push_back(intVal.getSExtValue());
1584  }
1585  }
1586 
1587  // If any of the write indices is non-constant, bail out.
1588  if (cstWriteIdxs.size() != destShape.size())
1589  return false;
1590 
1591  // Go over all destination dims and check (1) and (2). Take into account that:
1592  // * The number of mask sizes will match the rank of the vector to store.
1593  // This could be lower than the rank of the destination tensor.
1594  // * Mask sizes could be larger than the corresponding mask shape (hence
1595  // `clamp`).
1596  // TODO: The 2nd item should be rejected by the verifier.
1597  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1598  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1599  if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1600  /*(2)*/ destShape[rankDiff + i] <
1601  (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1602  cstWriteIdxs[i]))
1603  return false;
1604  }
1605 
1606  return true;
1607 }
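
// Worked check (editor's addition, illustrative), using EXAMPLE 1 from the
// doc above:
//   destShape = [5, 1], maskShape = [5, 1] (shape of %vecToStore_1),
//   maskSizes = [5, 1], writeIdxs = [0, 0], rankDiff = 0.
//   Dim 0: 5 > 5 is false and 5 < (5 + 0) is false -> OK.
//   Dim 1: 1 > 1 is false and 1 < (1 + 0) is false -> OK.
// Hence the mask is trivially foldable and the caller can emit an unmasked
// xfer_write.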
1608 
1609 /// Creates an optionally masked TransferWriteOp
1610 ///
1611 /// Generates the following operation:
1612 /// %res = vector.transfer_write %vecToStore into %dest
1613 ///
1614 /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1615 ///
1616 /// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1617 /// %res = vector.mask %mask {
1618 /// vector.transfer_write %vecToStore into %dest
1619 /// }
1620 ///
1621 /// The mask shape is identical to `vecToStore` (with the element type ==
1622 /// i1), and the mask values are based on the shape of the `dest` tensor.
1623 ///
1624 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1625 /// is used instead of masking:
1626 ///
1627 /// in_bounds_flags = (...)  // derived from shape(vecToStore) vs shape(dest)
1628 ///
1629 /// %res = vector.transfer_write %vecToStore into %dest
1630 /// {in_bounds = in_bounds_flags}
1631 ///
1632 /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1633 /// are set to 0.
1634 static Operation *
1635 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1636  Value dest, SmallVector<Value> writeIndices = {},
1637  bool useInBoundsInsteadOfMasking = false) {
1638 
1639  ShapedType destType = cast<ShapedType>(dest.getType());
1640  int64_t destRank = destType.getRank();
1641  auto destShape = destType.getShape();
1642 
1643  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1644  int64_t vecToStoreRank = vecToStoreType.getRank();
1645  auto vecToStoreShape = vecToStoreType.getShape();
1646 
1647  // Compute the in_bounds attribute
1648  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1649  if (useInBoundsInsteadOfMasking) {
1650  // Update the inBounds attribute.
1651  // FIXME: This computation is too weak - it ignores the write indices.
1652  for (unsigned i = 0; i < vecToStoreRank; i++)
1653  inBoundsVal[i] =
1654  (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1655  ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1656  }
1657 
1658  // If missing, initialize the write indices to 0.
1659  assert((writeIndices.empty() ||
1660  writeIndices.size() == static_cast<size_t>(destRank)) &&
1661  "Invalid number of write indices!");
1662  if (writeIndices.empty()) {
1663  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1664  writeIndices.assign(destRank, zero);
1665  }
1666 
1667  // Generate the xfer_write Op
1668  Operation *write =
1669  builder.create<vector::TransferWriteOp>(loc,
1670  /*vector=*/vecToStore,
1671  /*source=*/dest,
1672  /*indices=*/writeIndices,
1673  /*inBounds=*/inBoundsVal);
1674 
1675  // If masking is disabled, exit.
1676  if (useInBoundsInsteadOfMasking)
1677  return write;
1678 
1679  // Check if masking is needed. If not, exit.
1680  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1681  return write;
1682 
1683  // Compute the mask and mask the write Op.
1684  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1685 
1686  SmallVector<OpFoldResult> destSizes =
1687  tensor::getMixedSizes(builder, loc, dest);
1688  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1689  destSizes.end());
1690 
1691  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1692  vecToStoreShape))
1693  return write;
1694 
1695  Value maskForWrite =
1696  builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1697  return mlir::vector::maskOperation(builder, write, maskForWrite);
1698 }
1699 
1700 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1701 /// padding value and (3) input vector sizes into:
1702 ///
1703 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1704 ///
1705 /// As in the following example:
1706 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1707 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1708 ///
1709 /// This pack would be vectorized to:
1710 ///
1711 /// %load = vector.mask %mask {
1712 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1713 /// {in_bounds = [true, true, true]} :
1714 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1715 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1716 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1717 /// to vector<32x4x2x1x16xf32>
1718 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1719 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1720 /// %write = vector.transfer_write %transpose,
1721 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1722 /// {in_bounds = [true, true, true, true, true]}
1723 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1724 ///
1725 /// If the (3) input vector sizes are not provided, the vector sizes are
1726 /// determined by the result tensor shape and the `in_bounds`
1727 /// attribute is used instead of masking to mark out-of-bounds accesses.
1728 ///
1729 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1730 /// outer dimensions of the output tensor. The remaining dimensions are
1731 /// computed based on, e.g., the static inner tiles.
1732 /// Supporting dynamic inner tiles will require the user to specify the
1733 /// missing vector sizes. This is left as a TODO.
1734 static LogicalResult
1735 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1736  ArrayRef<int64_t> inputVectorSizes,
1737  SmallVectorImpl<Value> &newResults) {
1738  // TODO: Introduce a parent class that will handle the insertion point update.
1739  OpBuilder::InsertionGuard g(rewriter);
1740  rewriter.setInsertionPoint(packOp);
1741 
1742  Location loc = packOp.getLoc();
1743  auto padValue = packOp.getPaddingValue();
1744  if (!padValue) {
1745  padValue = rewriter.create<arith::ConstantOp>(
1746  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1747  }
1748  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1749  LogicalResult status =
1750  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1751  .reifyResultShapes(rewriter, reifiedReturnShapes);
1752  (void)status; // prevent unused variable warning on non-assert builds.
1753  assert(succeeded(status) && "failed to reify result shapes");
1754 
1755  // If the input vector sizes are not provided, the vector sizes are
1756  // determined by the result tensor shape; in that case, we update the
1757  // inBounds attribute instead of masking.
1758  bool useInBoundsInsteadOfMasking = false;
1759  if (inputVectorSizes.empty()) {
1760  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1761  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1762  useInBoundsInsteadOfMasking = true;
1763  }
1764 
1765  // Create masked TransferReadOp.
1766  SmallVector<int64_t> inputShape(inputVectorSizes);
1767  auto innerTiles = packOp.getStaticInnerTiles();
1768  auto innerDimsPos = packOp.getInnerDimsPos();
1769  auto outerDimsPerm = packOp.getOuterDimsPerm();
1770  if (!outerDimsPerm.empty())
1771  applyPermutationToVector(inputShape,
1772  invertPermutationVector(outerDimsPerm));
1773  for (auto [idx, size] : enumerate(innerTiles))
1774  inputShape[innerDimsPos[idx]] *= size;
1775  auto maskedRead = vector::createReadOrMaskedRead(
1776  rewriter, loc, packOp.getSource(), inputShape, padValue,
1777  useInBoundsInsteadOfMasking);
1778 
1779  // Create ShapeCastOp.
1780  SmallVector<int64_t> destShape(inputVectorSizes);
1781  destShape.append(innerTiles.begin(), innerTiles.end());
1782  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1783  packOp.getDestType().getElementType());
1784  auto shapeCastOp =
1785  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1786 
1787  // Create TransposeOp.
1788  auto destPermutation =
1789  invertPermutationVector(getPackInverseDestPerm(packOp));
1790  auto transposeOp = rewriter.create<vector::TransposeOp>(
1791  loc, shapeCastOp.getResult(), destPermutation);
1792 
1793  // Create TransferWriteOp.
1794  Value dest = rewriter.create<tensor::EmptyOp>(
1795  loc, reifiedReturnShapes[0],
1796  transposeOp.getResult().getType().getElementType());
1797  Operation *write =
1798  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
1799  newResults.push_back(write->getResult(0));
1800  return success();
1801 }
1802 
1803 /// Vectorize a `linalg::UnPackOp` into these 4 ops:
1804 /// vector::TransferReadOp - reads a vector from the source tensor
1805 /// vector::TransposeOp - transposes the read vector
1806 /// vector::ShapeCastOp - reshapes the data based on the target (dest) shape
1807 /// vector::TransferWriteOp - writes the result vector back to the destination
1808 /// tensor.
1809 /// If the vector sizes are not provided:
1810 /// * the vector sizes are determined by the input operand and attributes,
1811 /// * update the inBounds attribute instead of masking.
1812 static LogicalResult
1813 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1814  ArrayRef<int64_t> inputVectorSizes,
1815  SmallVectorImpl<Value> &newResults) {
1816 
1817  // TODO: Introduce a parent class that will handle the insertion point update.
1818  OpBuilder::InsertionGuard g(rewriter);
1819  rewriter.setInsertionPoint(unpackOp);
1820 
1821  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1822 
1823  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1824  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1825  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1826  bool useInBoundsInsteadOfMasking = false;
1827  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1828 
1829  auto destSize = unpackOp.getDestRank();
1830 
1831  if (!inputVectorSizes.empty())
1832  assert(inputVectorSizes.size() == destSize &&
1833  "Incorrect number of input vector sizes");
1834 
1835  // vectorSizes is the shape of the vector that will be used to do final
1836  // write on the destination tensor. It is set like this: Let's say the
1837  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1838  // Thus:
1839  // 1. vectorSizes = sourceShape.take_front(N)
1840  // 2. if outer_dims_perm is present: do that permutation on vectorSizes.
1841  // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by the
1842  // innerTiles attribute value.
1843  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1844  if (vectorSizes.empty()) {
1845  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1846  if (!outerDimsPerm.empty())
1847  applyPermutationToVector(vectorSizes, outerDimsPerm);
1848  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1849  vectorSizes[pos] *= innerTiles[i];
1850 
1851  useInBoundsInsteadOfMasking = true;
1852  }
1853 
1854  // readVectorSizes is the shape of the vector used for the (masked) read. It
1855  // is set like this: let's say the vectorSizes (VS) array has size 'N', the
1856  // sourceShape (SS) has rank 'M' where M >= N, and the InnerTileSizes (IT)
1857  // have size M-N.
1858  // Thus:
1859  // - initially: readVectorSizes = vectorSizes
1860  // - Divide all the readVectorSizes locations pointed to by innerDimPos
1861  // by the innerTileSize attribute value.
1862  // - if outer_dims_perm is present: do that permutation on readVectorSizes.
1863  // - Append the remaining shape from SS
1864  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
1865  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1866  // 128] and outer_dims_perm is [1, 0] then read shape is:
1867  // ReadVectorSizes(initial): [512, 128]
1868  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1869  // = [16, 8]
1870  // After applying outer_dims_perm: [8, 16]
1871  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1872 
1873  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1874 
1875  for (auto [index, size] : enumerate(innerTiles)) {
1876  readVectorSizes[innerDimPos[index]] =
1877  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1878  }
1879  if (!outerDimsPerm.empty()) {
1880  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1881  }
1882  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1883  sourceShape.end());
1884 
1885  ReifiedRankedShapedTypeDims reifiedRetShapes;
1886  LogicalResult status =
1887  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1888  .reifyResultShapes(rewriter, reifiedRetShapes);
1889  if (status.failed()) {
1890  LDBG("Unable to reify result shapes of " << unpackOp);
1891  return failure();
1892  }
1893  Location loc = unpackOp->getLoc();
1894 
1895  auto padValue = rewriter.create<arith::ConstantOp>(
1896  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1897 
1898  // Read result, mask if necessary. If transferReadOp shape is not equal
1899  // to shape of source, then a mask is necessary.
1900  Value readResult = vector::createReadOrMaskedRead(
1901  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1902  /*useInBoundsInsteadOfMasking=*/false);
1903 
1904  PackingMetadata packMetadata;
1905  SmallVector<int64_t> lastDimToInsertPosPerm =
1906  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1907  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1908  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1909  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1910  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1911  RankedTensorType stripMineTensorType =
1912  RankedTensorType::get(stripMineShape, stripMineElemType);
1913  // Transpose the appropriate rows to match output.
1914  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1915  loc, readResult, lastDimToInsertPosPerm);
1916 
1917  // Collapse the vector to the size required by result.
1918  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1919  stripMineTensorType, packMetadata.reassociations);
1920  mlir::VectorType vecCollapsedType =
1921  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1922  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1923  loc, vecCollapsedType, transposeOp->getResult(0));
1924 
1925  // writeVectorSizes has to match the shape_cast result shape for dynamic
1926  // sizes, otherwise the verifier complains that the mask size is invalid.
1927  SmallVector<int64_t> writeVectorSizes(
1928  unpackOp.getDestType().hasStaticShape()
1929  ? vectorSizes
1930  : shapeCastOp.getResultVectorType().getShape());
1931  Value dest = rewriter.create<tensor::EmptyOp>(
1932  loc, reifiedRetShapes[0],
1933  shapeCastOp.getResult().getType().getElementType());
1934  Operation *write = createWriteOrMaskedWrite(
1935  rewriter, loc, shapeCastOp.getResult(), dest,
1936  /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1937  newResults.push_back(write->getResult(0));
1938  return success();
1939 }
1940 
1941 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1942 /// and (3) all-zero lowPad to
1943 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
1944 static LogicalResult
1945 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1946  ArrayRef<int64_t> inputVectorSizes,
1947  SmallVectorImpl<Value> &newResults) {
1948  auto padValue = padOp.getConstantPaddingValue();
1949  Location loc = padOp.getLoc();
1950 
1951  // TODO: Introduce a parent class that will handle the insertion point update.
1952  OpBuilder::InsertionGuard g(rewriter);
1953  rewriter.setInsertionPoint(padOp);
1954 
1955  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1956  LogicalResult status =
1957  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1958  .reifyResultShapes(rewriter, reifiedReturnShapes);
1959  (void)status; // prevent unused variable warning on non-assert builds
1960  assert(succeeded(status) && "failed to reify result shapes");
1961  auto maskedRead = vector::createReadOrMaskedRead(
1962  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1963  /*useInBoundsInsteadOfMasking=*/false);
1964 
1965  // Create Xfer write Op
1966  Value dest = rewriter.create<tensor::EmptyOp>(
1967  loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1968  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
1969  newResults.push_back(write->getResult(0));
1970  return success();
1971 }
1972 
1973 // TODO: probably need some extra checks for reduction followed by consumer
1974 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1975 static LogicalResult reductionPreconditions(LinalgOp op) {
1976  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1977  LDBG("reduction precondition failed: no reduction iterator\n");
1978  return failure();
1979  }
1980  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1981  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1982  if (indexingMap.isPermutation())
1983  continue;
1984 
1985  Operation *reduceOp = matchLinalgReduction(&opOperand);
1986  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1987  LDBG("reduction precondition failed: reduction detection failed\n");
1988  return failure();
1989  }
1990  }
1991  return success();
1992 }
1993 
1994 static LogicalResult
1995 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1996  bool flatten1DDepthwiseConv) {
1997  if (flatten1DDepthwiseConv) {
1998  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1999  "supported\n");
2000  return failure();
2001  }
2002 
2003  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2004  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
2005  return failure();
2006  }
2007 
2008  // Support dynamic shapes in 1D depthwise convolution, but only in the
2009  // _channel_ dimension.
2010  Value lhs = conv.getDpsInputOperand(0)->get();
2011  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2012  auto shapeWithoutCh = lhsShape.drop_back(1);
2013  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2014  LDBG("Dynamically-shaped op vectorization precondition failed: only "
2015  "channel dim can be dynamic\n");
2016  return failure();
2017  }
2018 
2019  return success();
2020 }
2021 
2022 static LogicalResult
2023 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2024  bool flatten1DDepthwiseConv) {
2025  if (isa<ConvolutionOpInterface>(op.getOperation()))
2026  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2027 
2028  if (hasReductionIterator(op))
2029  return reductionPreconditions(op);
2030 
2031  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2032  // linalg.copy ops and ops that implement ContractionOpInterface for now.
2033  if (!isElementwise(op) &&
2034  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2035  op.getOperation()))
2036  return failure();
2037 
2038  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
2039  return success();
2040 }
2041 
2042 /// Need to check if the inner-tiles are static/constant.
2043 static LogicalResult
2044 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2045  ArrayRef<int64_t> inputVectorSizes) {
2046 
2047  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
2048  return !getConstantIntValue(res).has_value();
2049  })) {
2050  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
2051  return failure();
2052  }
2053  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2054  bool satisfyEmptyCond = inputVectorSizes.empty() &&
2055  unpackOp.getDestType().hasStaticShape() &&
2056  unpackOp.getSourceType().hasStaticShape();
2057  if (!satisfyEmptyCond &&
2058  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2059  return failure();
2060 
2061  return success();
2062 }
2063 
2064 static LogicalResult
2065 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2066  ArrayRef<int64_t> inputVectorSizes) {
2067 
2068  TypedValue<RankedTensorType> source = sliceOp.getSource();
2069  auto sourceType = source.getType();
2070  if (!VectorType::isValidElementType(sourceType.getElementType()))
2071  return failure();
2072 
2073  // Get the pad value.
2074  // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2075  // scalar padding value. Note that:
2076  // * for in-bounds accesses,
2077  // the value is actually irrelevant. There are 2 cases in which xfer.read
2078  // accesses are known to be in-bounds:
2079  // 1. The source shape is static (output vector sizes would be based on
2080  // the source shape and hence all memory accesses would be in-bounds),
2081  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2082  // this case it is safe to assume that all memory accesses are in-bounds.
2083  //
2084  // When the value is not known and not needed, use 0. Otherwise, bail out.
2085  Value padValue = getStaticPadVal(sliceOp);
2086  bool isOutOfBoundsRead =
2087  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2088 
2089  if (!padValue && isOutOfBoundsRead) {
2090  LDBG("Failed to get a pad value for out-of-bounds read access\n");
2091  return failure();
2092  }
2093  return success();
2094 }
2095 
2096 namespace {
2097 enum class ConvOperationKind { Conv, Pool };
2098 } // namespace
2099 
2100 static bool isCastOfBlockArgument(Operation *op) {
2101  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2102  isa<BlockArgument>(op->getOperand(0));
2103 }
2104 
2105 // Returns the ConvOperationKind of the op using reduceOp of the generic
2106 // payload. If it is neither a convolution nor a pooling, it returns
2107 // std::nullopt.
2108 //
2109 // If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2110 // + yield) and the rhs is not used, then it is the body of a pooling.
2111 // If it is a convolution, check for a single `mul` predecessor. The `mul`
2112 // operands must be block arguments or extensions of block arguments.
2113 // Otherwise (pooling), check for one or zero `ext` predecessors. The `ext`
2114 // operands must be block arguments or extensions of block arguments.
2115 static std::optional<ConvOperationKind>
2116 getConvOperationKind(Operation *reduceOp) {
2117  int numBlockArguments =
2118  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2119 
2120  switch (numBlockArguments) {
2121  case 1: {
2122  // Will be convolution if feeder is a MulOp.
2123  // A strength reduced version of MulOp for i1 type is AndOp which is also
2124  // supported. Otherwise, it can be pooling. This strength reduction logic
2125  // is in `buildBinaryFn` helper in the Linalg dialect.
2126  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2127  llvm::IsaPred<BlockArgument>);
2128  assert(feedValIt != reduceOp->operand_end() &&
2129  "Expected a non-block argument operand");
2130  Operation *feedOp = (*feedValIt).getDefiningOp();
2131  if (isCastOfBlockArgument(feedOp)) {
2132  return ConvOperationKind::Pool;
2133  }
2134 
2135  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2136  (isa<arith::AndIOp>(feedOp) &&
2137  feedOp->getResultTypes()[0].isInteger(1))) &&
2138  llvm::all_of(feedOp->getOperands(), [](Value v) {
2139  if (isa<BlockArgument>(v))
2140  return true;
2141  if (Operation *op = v.getDefiningOp())
2142  return isCastOfBlockArgument(op);
2143  return false;
2144  }))) {
2145  return std::nullopt;
2146  }
2147 
2148  return ConvOperationKind::Conv;
2149  }
2150  case 2:
2151  // Must be pooling
2152  return ConvOperationKind::Pool;
2153  default:
2154  return std::nullopt;
2155  }
2156 }
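
// Illustrative payloads (editor's addition, hypothetical) for the two cases
// above, where %in, %filter and %acc are block arguments and %r is the value
// yielded back:
//
//   Convolution body (1 block-argument operand on the reduction op):
//     %m = arith.mulf %in, %filter : f32
//     %r = arith.addf %acc, %m : f32
//
//   Pooling body (2 block-argument operands on the reduction op):
//     %r = arith.addf %acc, %in : f32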
2157 
2158 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2159  switch (kind) {
2160  case vector::CombiningKind::ADD:
2161  case vector::CombiningKind::MAXNUMF:
2162  case vector::CombiningKind::MAXIMUMF:
2163  case vector::CombiningKind::MAXSI:
2164  case vector::CombiningKind::MAXUI:
2165  case vector::CombiningKind::MINNUMF:
2166  case vector::CombiningKind::MINIMUMF:
2167  case vector::CombiningKind::MINSI:
2168  case vector::CombiningKind::MINUI:
2169  return true;
2170  default:
2171  return false;
2172  }
2173 }
2174 
2175 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2176  auto getOperandType = [&](auto operand) {
2177  return dyn_cast<ShapedType>((operand->get()).getType());
2178  };
2179  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2180  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2181  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2182  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2183  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2184  // Note that this also ensures 2D and 3D convolutions are rejected.
2185  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2186  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2187  return failure();
2188 
2189  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2190  if (!reduceOp)
2191  return failure();
2192 
2193  auto maybeOper = getConvOperationKind(reduceOp);
2194  if (!maybeOper.has_value())
2195  return failure();
2196 
2197  auto maybeKind = getCombinerOpKind(reduceOp);
2198  // Typically convolution will have a `Add` CombiningKind but for i1 type it
2199  // can get strength reduced to `OR` which is also supported. This strength
2200  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2201  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2202  *maybeKind != vector::CombiningKind::OR) &&
2203  (*maybeOper != ConvOperationKind::Pool ||
2204  !isSupportedPoolKind(*maybeKind)))) {
2205  return failure();
2206  }
2207 
2208  auto rhsRank = rhsShapedType.getRank();
2209  if (*maybeOper == ConvOperationKind::Pool) {
2210  if (rhsRank != 1)
2211  return failure();
2212  } else {
2213  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2214  return failure();
2215  }
2216 
2217  return success();
2218 }
2219 
2220 static LogicalResult vectorizeLinalgOpPrecondition(
2221  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2222  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2223  // Tensors with a dimension of size 0 cannot be vectorized.
2224  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2225  return llvm::is_contained(linalgOp.getShape(&operand), 0);
2226  }))
2227  return failure();
2228  // Check API contract for input vector sizes.
2229  if (!inputVectorSizes.empty() &&
2230  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2231  inputVectorSizes)))
2232  return failure();
2233 
2234  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2235  linalgOp, flatten1DDepthwiseConv))) {
2236  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
2237  return failure();
2238  }
2239 
2239 
2240  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2241 
2242  // Register CustomVectorizationPrecondition for extractOp.
2243  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2244 
2245  // All types in the body should be a supported element type for VectorType.
2246  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2247  // Check if any custom hook can vectorize the inner op.
2248  if (llvm::any_of(
2249  customPreconditions,
2250  [&](const CustomVectorizationPrecondition &customPrecondition) {
2251  return succeeded(
2252  customPrecondition(&innerOp, vectorizeNDExtract));
2253  })) {
2254  continue;
2255  }
2256  if (!llvm::all_of(innerOp.getOperandTypes(),
2257  VectorType::isValidElementType)) {
2258  return failure();
2259  }
2260  if (!llvm::all_of(innerOp.getResultTypes(),
2261  VectorType::isValidElementType)) {
2262  return failure();
2263  }
2264  }
2265  if (isElementwise(linalgOp))
2266  return success();
2267 
2268  // TODO: isaConvolutionOpInterface that can also infer from generic
2269  // features. But we will still need stride/dilation attributes that will be
2270  // annoying to reverse-engineer...
2271  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2272  return vectorizeConvOpPrecondition(linalgOp);
2273 
2274  // TODO: the common vector shape is equal to the static loop sizes only when
2275  // all indexing maps are projected permutations. For convs and stencils the
2276  // logic will need to evolve.
2277  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2278  LDBG("precondition failed: not projected permutations\n");
2279  return failure();
2280  }
2281  if (failed(reductionPreconditions(linalgOp))) {
2282  LDBG("precondition failed: reduction preconditions\n");
2283  return failure();
2284  }
2285  return success();
2286 }
2287 
2288 static LogicalResult
2289 vectorizePackOpPrecondition(linalg::PackOp packOp,
2290  ArrayRef<int64_t> inputVectorSizes) {
2291  auto padValue = packOp.getPaddingValue();
2292  Attribute cstAttr;
2293  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2294  LDBG("pad value is not constant: " << packOp << "\n");
2295  return failure();
2296  }
2297  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2298  bool satisfyEmptyCond = true;
2299  if (inputVectorSizes.empty()) {
2300  if (!packOp.getDestType().hasStaticShape() ||
2301  !packOp.getSourceType().hasStaticShape())
2302  satisfyEmptyCond = false;
2303  }
2304 
2305  if (!satisfyEmptyCond &&
2306  failed(vector::isValidMaskedInputVector(
2307  resultTensorShape.take_front(packOp.getSourceRank()),
2308  inputVectorSizes)))
2309  return failure();
2310 
2311  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2312  return !getConstantIntValue(v).has_value();
2313  })) {
2314  LDBG("inner_tiles must be constant: " << packOp << "\n");
2315  return failure();
2316  }
2317 
2318  return success();
2319 }
2320 
2321 static LogicalResult
2322 vectorizePadOpPrecondition(tensor::PadOp padOp,
2323  ArrayRef<int64_t> inputVectorSizes) {
2324  auto padValue = padOp.getConstantPaddingValue();
2325  if (!padValue) {
2326  LDBG("pad value is not constant: " << padOp << "\n");
2327  return failure();
2328  }
2329 
2330  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2331  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2332  inputVectorSizes)))
2333  return failure();
2334 
2335  // Padding with non-zero low pad values is not supported, unless the
2336  // corresponding result dim is 1; supporting it in general would require
2337  // shifting the results to the right by the amount of low padding.
2338  // However, we do support low padding if the dims being low-padded have
2339  // result sizes of 1. The reason: when there is a low pad on a unit result
2340  // dim, the input size of that dimension is dynamically zero (the sum of the
2341  // low pad and the input dim size has to be one), so we create a zero
2342  // mask, as the lowering logic sizes the mask to the input dim size -
2343  // which is zero here. Hence we load the pad value, which is what we want
2344  // in this case. If the low pad is dynamically zero, then the lowering is
2345  // correct as well, as no shifts are necessary.
2346  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2347  Value padValue = en.value();
2348  unsigned pos = en.index();
2349  std::optional<int64_t> pad = getConstantIntValue(padValue);
2350  return (!pad.has_value() || pad.value() != 0) &&
2351  resultTensorShape[pos] != 1;
2352  })) {
2353  LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
2354  return failure();
2355  }
2356 
2357  return success();
2358 }
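
// Illustrative examples (editor's addition, hypothetical shapes) for the
// low-pad rule above:
//   - Rejected: tensor.pad with low = [1, 0] producing tensor<4x8xf32>
//               (non-zero low pad on a dim whose result size is 4).
//   - Accepted: tensor.pad with low = [1, 0] producing tensor<1x8xf32>
//               (the low-padded dim has result size 1, so the masked read
//               simply picks up the pad value for that dim).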
2359 
2360 /// Preconditions for scalable vectors. This is quite restrictive - it models
2361 /// the fact that in practice we would only make selected dimensions scalable.
2362 static LogicalResult
2363 vectorizeScalableVectorPrecondition(Operation *op,
2364  ArrayRef<int64_t> inputVectorSizes,
2365  ArrayRef<bool> inputScalableVecDims) {
2366  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2367  "Number of input vector sizes and scalable dims doesn't match");
2368 
2369  size_t numOfScalableDims =
2370  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2371 
2372  if (numOfScalableDims == 0)
2373  return success();
2374 
2375  auto linalgOp = dyn_cast<LinalgOp>(op);
2376 
2377  // Cond 1: There's been no need for scalable vectorisation of
2378  // non-linalg Ops so far
2379  if (!linalgOp)
2380  return failure();
2381 
2382  // Cond 2: There's been no need for more than 2 scalable dims so far
2383  if (numOfScalableDims > 2)
2384  return failure();
2385 
2386  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2387  // it matches one of the supported cases:
2388  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2389  // (*).
2390  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2391  // parallel dims.
2392  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2393  // dim.
2394  // The 2nd restriction above means that only Matmul-like Ops are supported
2395  // when 2 dims are scalable, e.g. :
2396  // * iterators = [parallel, parallel, reduction]
2397  // * scalable flags = [true, true, false]
2398  //
2399  // (*) Unit dims get folded away in practice.
2400  // TODO: Relax these conditions as good motivating examples are identified.
2401 
2402  // Find the first scalable flag.
2403  bool seenNonUnitParallel = false;
2404  auto iterators = linalgOp.getIteratorTypesArray();
2405  SmallVector<bool> scalableFlags(inputScalableVecDims);
2406  int64_t idx = scalableFlags.size() - 1;
2407  while (!scalableFlags[idx]) {
2408  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2409  seenNonUnitParallel |=
2410  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2411 
2412  iterators.pop_back();
2413  scalableFlags.pop_back();
2414  --idx;
2415  }
2416 
2417  // Analyze the iterator corresponding to the first scalable dim.
2418  switch (iterators.back()) {
2419  case utils::IteratorType::reduction: {
2420  // Check 3. above is met.
2421  if (iterators.size() != inputVectorSizes.size()) {
2422  LDBG("Non-trailing reduction dim requested for scalable "
2423  "vectorization\n");
2424  return failure();
2425  }
2426  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2427  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2428  "is not supported\n");
2429  return failure();
2430  }
2431  break;
2432  }
2433  case utils::IteratorType::parallel: {
2434  // Check 1. and 2. above are met.
2435  if (seenNonUnitParallel) {
2436  LDBG("Inner parallel dim not requested for scalable "
2437  "vectorization\n");
2438  return failure();
2439  }
2440  break;
2441  }
2442  }
2443 
2444  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2445  // supported, for which we expect the following config:
2446  // * iterators = [parallel, parallel, reduction]
2447  // * scalable flags = [true, true, false]
2448  if (numOfScalableDims == 2) {
2449  // Disallow below case which breaks 3. above:
2450  // * iterators = [..., parallel, reduction]
2451  // * scalable flags = [..., true, true]
2452  if (iterators.back() == utils::IteratorType::reduction) {
2453  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2454  "vectorization\n");
2455  return failure();
2456  }
2457  scalableFlags.pop_back();
2458  iterators.pop_back();
2459 
2460  if (!scalableFlags.back() ||
2461  (iterators.back() != utils::IteratorType::parallel))
2462  return failure();
2463  }
2464 
2465  // Do not let matmul ops with extended semantics (i.e. with user-defined
2466  // indexing maps) through this transform.
2467  if (linalgOp.hasUserDefinedMaps())
2468  return failure();
2469 
2470  // Cond 4: Only the following ops are supported in the
2471  // presence of scalable vectors
2472  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2473  isa<linalg::MatmulTransposeAOp>(op) ||
2474  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2475  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2476 }
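
// Illustrative configurations (editor's addition, hypothetical) against the
// conditions above, using linalg.matmul
// (iterators = [parallel, parallel, reduction]):
//   - vector sizes [8, [16], 1]  (scalable flags [false, true, false]):
//     accepted - the single scalable dim is the last non-unit parallel dim.
//   - vector sizes [[4], [4], 1] (scalable flags [true, true, false]):
//     accepted - the two scalable dims are the last two adjacent parallel dims.
//   - vector sizes [8, 16, [4]]  (scalable flags [false, false, true]):
//     rejected - scalable reduction dims are not supported for Matmul-like ops.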
2477 
2478 static LogicalResult vectorizeOpPrecondition(
2479  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2480  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2481  bool flatten1DDepthwiseConv) {
2482 
2483  if (!hasVectorizationImpl(op))
2484  return failure();
2485 
2486  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2487  inputScalableVecDims)))
2488  return failure();
2489 
2490  return TypeSwitch<Operation *, LogicalResult>(op)
2491  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2492  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2493  vectorizeNDExtract,
2494  flatten1DDepthwiseConv);
2495  })
2496  .Case<tensor::PadOp>([&](auto padOp) {
2497  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2498  })
2499  .Case<linalg::PackOp>([&](auto packOp) {
2500  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2501  })
2502  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2503  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2504  })
2505  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2506  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2507  })
2508  .Default([](auto) { return failure(); });
2509 }
2510 
2511 /// Converts affine.apply Ops to arithmetic operations.
2512 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2513  OpBuilder::InsertionGuard g(rewriter);
2514  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2515 
2516  for (auto op : make_early_inc_range(toReplace)) {
2517  rewriter.setInsertionPoint(op);
2518  auto expanded = affine::expandAffineExpr(
2519  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2520  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2521  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2522  rewriter.replaceOp(op, expanded);
2523  }
2524 }
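
// Illustrative example (editor's addition, hypothetical): an affine.apply in
// the linalg body such as
//   %i = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%j)[%k]
// is expanded by affine::expandAffineExpr into plain arithmetic, roughly:
//   %i = arith.addi %j, %k : index
// so the resulting ops can be handled by the generic vectorization path.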
2525 
2527  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2528  tensor::InsertSliceOp>(op);
2529 }
2530 
2531 FailureOr<VectorizationResult>
2532 mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2533  ArrayRef<int64_t> inputVectorSizes,
2534  ArrayRef<bool> inputScalableVecDims,
2535  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2536  LDBG("Attempting to vectorize:\n" << *op << "\n");
2537  LDBG("Input vector sizes: ");
2538  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2539  LLVM_DEBUG(llvm::dbgs() << "\n");
2540  LDBG("Input scalable vector dims: ");
2541  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2542  LLVM_DEBUG(llvm::dbgs() << "\n");
2543 
2544  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2545  vectorizeNDExtract,
2546  flatten1DDepthwiseConv))) {
2547  LDBG("Vectorization pre-conditions failed\n");
2548  return failure();
2549  }
2550 
2551  // Initialize vectorization state.
2552  VectorizationState state(rewriter);
2553  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2554  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2555  inputScalableVecDims))) {
2556  LDBG("Vectorization state couldn't be initialized\n");
2557  return failure();
2558  }
2559  }
2560 
2561  SmallVector<Value> results;
2562  auto vectorizeResult =
2563  TypeSwitch<Operation *, LogicalResult>(op)
2564  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2565  // TODO: isaConvolutionOpInterface that can also infer from
2566  // generic features. Will require stride/dilation attributes
2567  // inference.
2568  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2569  FailureOr<Operation *> convOr = vectorizeConvolution(
2570  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2571  flatten1DDepthwiseConv);
2572  if (succeeded(convOr)) {
2573  llvm::append_range(results, (*convOr)->getResults());
2574  return success();
2575  }
2576 
2577  LDBG("Unsupported convolution can't be vectorized.\n");
2578  return failure();
2579  }
2580 
2581  LDBG("Vectorize generic by broadcasting to the canonical vector "
2582  "shape\n");
2583 
2584  // Pre-process before proceeding.
2585  convertAffineApply(rewriter, linalgOp);
2586 
2587  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2588  // to 'OpBuilder' when it is passed over to some methods like
2589  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2590  // erase an op within these methods, the actual rewriter won't be
2591  // notified and we will end up with read-after-free issues!
2592  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2593  })
2594  .Case<tensor::PadOp>([&](auto padOp) {
2595  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2596  results);
2597  })
2598  .Case<linalg::PackOp>([&](auto packOp) {
2599  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2600  results);
2601  })
2602  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2603  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2604  inputVectorSizes, results);
2605  })
2606  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2607  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2608  results);
2609  })
2610  .Default([](auto) { return failure(); });
2611 
2612  if (failed(vectorizeResult)) {
2613  LDBG("Vectorization failed\n");
2614  return failure();
2615  }
2616 
2617  return VectorizationResult{results};
2618 }
2619 
2621  memref::CopyOp copyOp) {
2622  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2623  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2624  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2625  return failure();
2626 
2627  auto srcElementType = getElementTypeOrSelf(srcType);
2628  auto dstElementType = getElementTypeOrSelf(dstType);
2629  if (!VectorType::isValidElementType(srcElementType) ||
2630  !VectorType::isValidElementType(dstElementType))
2631  return failure();
2632 
2633  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2634  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2635 
2636  Location loc = copyOp->getLoc();
2637  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2638  SmallVector<Value> indices(srcType.getRank(), zero);
2639 
2640  Value readValue = rewriter.create<vector::TransferReadOp>(
2641  loc, readType, copyOp.getSource(), indices,
2642  /*padding=*/std::nullopt,
2643  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2644  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2645  readValue =
2646  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2647  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2648  }
2649  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2650  loc, readValue, copyOp.getTarget(), indices,
2651  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2652  rewriter.replaceOp(copyOp, writeValue->getResults());
2653  return success();
2654 }
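
// Illustrative example (editor's addition, hypothetical shapes): a
// statically-shaped copy
//   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
// is rewritten by vectorizeCopy to roughly:
//   %c0 = arith.constant 0 : index
//   %v  = vector.transfer_read %src[%c0, %c0], ... : memref<4x8xf32>, vector<4x8xf32>
//   vector.transfer_write %v, %dst[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>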
2655 
2656 //----------------------------------------------------------------------------//
2657 // Misc. vectorization patterns.
2658 //----------------------------------------------------------------------------//
2659 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2660 /// given operation type OpTy.
2661 template <typename OpTy>
2662 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2663  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2664 
2665  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2666  PatternRewriter &rewriter) const final {
2667  bool changed = false;
2668  // Gather users into a vector first, because some may be replaced/removed during iteration.
2669  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2670  if (auto op = dyn_cast<OpTy>(user))
2671  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2672  return success(changed);
2673  }
2674 
2675 protected:
2676  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2677  tensor::PadOp padOp, OpTy op) const = 0;
2678 };
2679 
2680 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2681 /// ```
2682 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2683 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2684 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2685 /// ```
2686 /// is rewritten to:
2687 /// ```
2688 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2689 /// {in_bounds = [true, true]}
2690 /// : tensor<?x?xf32>, vector<17x5xf32>
2691 /// ```
2692 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2693 /// sure that the original padding value %cst was never used.
2694 ///
2695 /// This rewrite is possible if:
2696 /// - `xferOp` has no out-of-bounds dims or mask.
2697 /// - Low padding is static 0.
2698 /// - Single, scalar padding value.
2699 struct PadOpVectorizationWithTransferReadPattern
2700  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2701  using VectorizePadOpUserPattern<
2702  vector::TransferReadOp>::VectorizePadOpUserPattern;
2703 
2704  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2705  vector::TransferReadOp xferOp) const override {
2706  // Low padding must be static 0.
2707  if (!padOp.hasZeroLowPad())
2708  return failure();
2709  // Pad value must be a constant.
2710  auto padValue = padOp.getConstantPaddingValue();
2711  if (!padValue)
2712  return failure();
2713  // Padding value of existing `xferOp` is unused.
2714  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2715  return failure();
2716 
2717  rewriter.modifyOpInPlace(xferOp, [&]() {
2718  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2719  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2720  rewriter.getBoolArrayAttr(inBounds));
2721  xferOp.getBaseMutable().assign(padOp.getSource());
2722  xferOp.getPaddingMutable().assign(padValue);
2723  });
2724 
2725  return success();
2726  }
2727 };
2728 
2729 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2730 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2731 /// value, where the same amount of padding is immediately removed again after
2732 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2733 /// tensor value and apply out-of-bounds masking. E.g.:
2734 /// ```
2735 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2736 /// : tensor<...> to tensor<?x?xf32>
2737 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2738 /// %2 = vector.transfer_write %vec, %1[...]
2739 /// : vector<17x5xf32>, tensor<17x5xf32>
2740 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2741 /// : tensor<17x5xf32> to tensor<?x?xf32>
2742 /// ```
2743 /// is rewritten to:
2744 /// ```
2745 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2746 /// : tensor<...> to tensor<?x?xf32>
2747 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2748 /// tensor<?x?xf32>
2749 /// ```
2750 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2751 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2752 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2753 /// from %r's old dimensions.
2754 ///
2755 /// This rewrite is possible if:
2756 /// - Low padding is static 0.
2757 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2758 /// ExtractSliceOp trims the same amount of padding that was added
2759 /// beforehand.
2760 /// - Single, scalar padding value.
2761 struct PadOpVectorizationWithTransferWritePattern
2762  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2763  using VectorizePadOpUserPattern<
2764  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2765 
2766  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2767  vector::TransferWriteOp xferOp) const override {
2768  // TODO: support 0-d corner case.
2769  if (xferOp.getTransferRank() == 0)
2770  return failure();
2771 
2772  // Low padding must be static 0.
2773  if (!padOp.hasZeroLowPad())
2774  return failure();
2775  // Pad value must be a constant.
2776  auto padValue = padOp.getConstantPaddingValue();
2777  if (!padValue)
2778  return failure();
2779  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2780  if (!xferOp->hasOneUse())
2781  return failure();
2782  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2783  if (!trimPadding)
2784  return failure();
2785  // Only static zero offsets supported when trimming padding.
2786  if (!trimPadding.hasZeroOffset())
2787  return failure();
2788  // trimPadding must remove the amount of padding that was added earlier.
2789  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2790  return failure();
2791 
2792  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2793  rewriter.setInsertionPoint(xferOp);
2794 
2795  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2796  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2797  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2798  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2799  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2800  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2801 
2802  return success();
2803  }
2804 
2805  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2806  /// i.e., same dimensions.
2807  ///
2808  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2809  /// dimensions, this function tries to infer the (static) tensor size by
2810  /// looking at the defining op and utilizing op-specific knowledge.
2811  ///
2812  /// This is a conservative analysis. In case equal tensor sizes cannot be
2813  /// proven statically, this analysis returns `false` even though the tensor
2814  /// sizes may turn out to be equal at runtime.
2815  bool hasSameTensorSize(Value beforePadding,
2816  tensor::ExtractSliceOp afterTrimming) const {
2817  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2818  // result and CastOp operand.
2819  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2820  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2821  return true;
2822 
2823  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2824  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2825  // Only RankedTensorType supported.
2826  if (!t1 || !t2)
2827  return false;
2828  // Rank of both values must be the same.
2829  if (t1.getRank() != t2.getRank())
2830  return false;
2831 
2832  // All static dimensions must be the same. Mixed cases (e.g., dimension
2833  // static in `t1` but dynamic in `t2`) are not supported.
2834  for (unsigned i = 0; i < t1.getRank(); ++i) {
2835  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2836  return false;
2837  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2838  return false;
2839  }
2840 
2841  // Nothing more to check if all dimensions are static.
2842  if (t1.getNumDynamicDims() == 0)
2843  return true;
2844 
2845  // All dynamic sizes must be the same. The only supported case at the
2846  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2847  // thereof).
2848 
2849  // Apart from CastOp, only ExtractSliceOp is supported.
2850  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2851  if (!beforeSlice)
2852  return false;
2853 
2854  assert(static_cast<size_t>(t1.getRank()) ==
2855  beforeSlice.getMixedSizes().size());
2856  assert(static_cast<size_t>(t2.getRank()) ==
2857  afterTrimming.getMixedSizes().size());
2858 
2859  for (unsigned i = 0; i < t1.getRank(); ++i) {
2860  // Skip static dimensions.
2861  if (!t1.isDynamicDim(i))
2862  continue;
2863  auto size1 = beforeSlice.getMixedSizes()[i];
2864  auto size2 = afterTrimming.getMixedSizes()[i];
2865 
2866  // Case 1: Same value or same constant int.
2867  if (isEqualConstantIntOrValue(size1, size2))
2868  continue;
2869 
2870  // Other cases: Take a deeper look at defining ops of values.
2871  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2872  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2873  if (!v1 || !v2)
2874  return false;
2875 
2876  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2877  // CSE is run.)
2878  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2879  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2880  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2881  minOp1.getOperands() == minOp2.getOperands())
2882  continue;
2883 
2884  // Add additional cases as needed.
2885  }
2886 
2887  // All tests passed.
2888  return true;
2889  }
2890 };
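// Example (an illustrative sketch): a dynamic-shape case that
// hasSameTensorSize accepts, because both ExtractSliceOps use the very same
// %s0/%s1 size values (Case 1 above). Two distinct affine.min ops with
// identical maps and operands would instead be accepted via Case 2.
// ```
// %0 = tensor.extract_slice %t[0, 0] [%s0, %s1] [1, 1]
//        : tensor<20x20xf32> to tensor<?x?xf32>
// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
// %2 = vector.transfer_write %vec, %1[%c0, %c0]
//        : vector<17x5xf32>, tensor<17x5xf32>
// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
//        : tensor<17x5xf32> to tensor<?x?xf32>
// ```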
2891 
2892 /// Returns the effective Pad value for the input op, provided it's a scalar.
2893 ///
2894 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2895 /// this Op performs padding, retrieve the padding value provided that it's
2896 /// a scalar and static/fixed for all the padded values. Returns an empty value
2897 /// otherwise.
2898 ///
2899 /// TODO: This is used twice (when checking vectorization pre-conditions and
2900 /// when vectorizing). Cache results instead of re-running.
2901 static Value getStaticPadVal(Operation *op) {
2902  if (!op)
2903  return {};
2904 
2905  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2906  // being broadcast, provided that it's a scalar.
2907  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2908  auto source = bcast.getSource();
2909  if (llvm::dyn_cast<VectorType>(source.getType()))
2910  return {};
2911 
2912  return source;
2913  }
2914 
2915  // 2. linalg.fill - use the scalar input value that used to fill the output
2916  // tensor.
2917  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2918  return fill.getInputs()[0];
2919  }
2920 
2921  // 3. tensor.generate - can't guarantee the value is fixed without
2922  // analysing it, so bail out.
2923  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2924  return {};
2925  }
2926 
2927  // 4. vector.transfer_write - inspect the vector that is being written. If
2928  // it contains a single value that has been broadcast (e.g. via
2929  // vector.broadcast), extract it; fail otherwise.
2930  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2931  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2932 
2933  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2934  // than the input tensor, then, provided it's constant, we'll extract the
2935  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2936  // TODO: Clarify the semantics when the input tensor is larger than the
2937  // destination.
2938  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2939  return getStaticPadVal(slice.getDest().getDefiningOp());
2940 
2941  return {};
2942 }
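// Example (an illustrative sketch): for the chain below, getStaticPadVal on
// the tensor.insert_slice follows its destination to the linalg.fill and
// returns %pad (case 5 followed by case 2).
// ```
// %empty = tensor.empty() : tensor<8x16xf32>
// %fill  = linalg.fill ins(%pad : f32) outs(%empty : tensor<8x16xf32>)
//            -> tensor<8x16xf32>
// %res   = tensor.insert_slice %src into %fill[0, 0] [4, 8] [1, 1]
//            : tensor<4x8xf32> into tensor<8x16xf32>
// ```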
2943 
2944 static LogicalResult
2945 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2946  ArrayRef<int64_t> inputVectorSizes,
2947  SmallVectorImpl<Value> &newResults) {
2948  // TODO: Introduce a parent class that will handle the insertion point update.
2949  OpBuilder::InsertionGuard g(rewriter);
2950  rewriter.setInsertionPoint(sliceOp);
2951 
2952  TypedValue<RankedTensorType> source = sliceOp.getSource();
2953  auto sourceType = source.getType();
2954  auto resultType = sliceOp.getResultType();
2955 
2956  Value padValue = getStaticPadVal(sliceOp);
2957 
2958  if (!padValue) {
2959  auto elemType = sourceType.getElementType();
2960  padValue = rewriter.create<arith::ConstantOp>(
2961  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2962  }
2963 
2964  // 2. Get the vector shape
2965  SmallVector<int64_t> vecShape;
2966  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2967  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2968  if (!inputVectorSizes.empty()) {
2969  vecShape.push_back(inputVectorSizes[i]);
2970  } else if (!sourceType.isDynamicDim(i)) {
2971  vecShape.push_back(sourceType.getDimSize(i));
2972  } else if (!resultType.isDynamicDim(i)) {
2973  // Source shape is not statically known, but result shape is.
2974  // Vectorize with size of result shape. This may be larger than the
2975  // source size.
2976  // FIXME: Using rankDiff implies that the source tensor is inserted at
2977  // the end of the destination tensor. However, that's not required.
2978  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2979  } else {
2980  // Neither the source nor the result dim of the slice is static. Cannot
2981  // vectorize the copy.
2982  return failure();
2983  }
2984  }
2985  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2986 
2987  // 3. Generate TransferReadOp + TransferWriteOp
2988  auto loc = sliceOp.getLoc();
2989 
2990  // Create read
2991  SmallVector<Value> readIndices(
2992  vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
2993  Value read = mlir::vector::createReadOrMaskedRead(
2994  rewriter, loc, source, vecType.getShape(), padValue,
2995  /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
2996 
2997  // Create write
2998  auto writeIndices =
2999  getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3000  Operation *write =
3001  createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3002  writeIndices, inputVectorSizes.empty());
3003 
3004  // 4. Finalize
3005  newResults.push_back(write->getResult(0));
3006 
3007  return success();
3008 }
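// Example (an illustrative sketch): below, dim 0 of the source is dynamic but
// the destination is fully static, so the copy is vectorized with
// vector<8x3xf32> inferred from the destination shape (the rankDiff branch
// above); %s itself stays dynamic.
// ```
// %r = tensor.insert_slice %src into %dest[0, 0] [%s, 3] [1, 1]
//        : tensor<?x3xf32> into tensor<8x3xf32>
// ```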
3009 
3010 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3011 /// ```
3012 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3013 /// %r = tensor.insert_slice %0
3014 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3015 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3016 /// ```
3017 /// is rewritten to:
3018 /// ```
3019 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
3020 /// : tensor<?x?xf32>, vector<17x5xf32>
3021 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3022 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3023 /// ```
3024 ///
3025 /// This rewrite is possible if:
3026 /// - Low padding is static 0.
3027 /// - `padOp` result shape is static.
3028 /// - The entire padded tensor is inserted.
3029 /// (Implies that sizes of `insertOp` are all static.)
3030 /// - Only unit strides in `insertOp`.
3031 /// - Single, scalar padding value.
3032 /// - `padOp` result not used as destination.
3033 struct PadOpVectorizationWithInsertSlicePattern
3034  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3035  using VectorizePadOpUserPattern<
3036  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3037 
3038  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3039  tensor::InsertSliceOp insertOp) const override {
3040  // Low padding must be static 0.
3041  if (!padOp.hasZeroLowPad())
3042  return failure();
3043  // Only unit stride supported.
3044  if (!insertOp.hasUnitStride())
3045  return failure();
3046  // Pad value must be a constant.
3047  auto padValue = padOp.getConstantPaddingValue();
3048  if (!padValue)
3049  return failure();
3050  // Dynamic shapes not supported.
3051  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3052  return failure();
3053  // Pad result not used as destination.
3054  if (insertOp.getDest() == padOp.getResult())
3055  return failure();
3056 
3057  auto vecType = VectorType::get(padOp.getType().getShape(),
3058  padOp.getType().getElementType());
3059  unsigned vecRank = vecType.getRank();
3060  unsigned tensorRank = insertOp.getType().getRank();
3061 
3062  // Check if sizes match: Insert the entire tensor into most minor dims.
3063  // (No permutations allowed.)
3064  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3065  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3066  if (!llvm::all_of(
3067  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3068  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3069  }))
3070  return failure();
3071 
3072  // Insert the TransferReadOp and TransferWriteOp at the position of the
3073  // InsertSliceOp.
3074  rewriter.setInsertionPoint(insertOp);
3075 
3076  // Generate TransferReadOp: Read entire source tensor and add high
3077  // padding.
3078  SmallVector<Value> readIndices(
3079  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
3080  auto read = rewriter.create<vector::TransferReadOp>(
3081  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
3082 
3083  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3084  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3085  // source must fit into the destination at the specified offsets.
3086  auto writeIndices = getValueOrCreateConstantIndexOp(
3087  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3088  SmallVector<bool> inBounds(vecRank, true);
3089  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3090  insertOp, read, insertOp.getDest(), writeIndices,
3091  ArrayRef<bool>{inBounds});
3092 
3093  return success();
3094  }
3095 };
3096 
3097 void mlir::linalg::populatePadOpVectorizationPatterns(
3098  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3099  patterns.add<PadOpVectorizationWithTransferReadPattern,
3100  PadOpVectorizationWithTransferWritePattern,
3101  PadOpVectorizationWithInsertSlicePattern>(
3102  patterns.getContext(), baseBenefit.getBenefit() + 1);
3103 }
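// Example use (a minimal sketch; `funcOp` is assumed, and the greedy driver
// comes from "mlir/Transforms/GreedyPatternRewriteDriver.h"):
// ```
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
// ```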
3104 
3105 //----------------------------------------------------------------------------//
3106 // Forwarding patterns
3107 //----------------------------------------------------------------------------//
3108 
3109 /// Check whether there is any interleaved use of any `values` between
3110 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3111 /// is in a different block.
3112 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3113  ValueRange values) {
3114  if (firstOp->getBlock() != secondOp->getBlock() ||
3115  !firstOp->isBeforeInBlock(secondOp)) {
3116  LDBG("interleavedUses precondition failed, firstOp: "
3117  << *firstOp << ", second op: " << *secondOp << "\n");
3118  return true;
3119  }
3120  for (auto v : values) {
3121  for (auto &u : v.getUses()) {
3122  Operation *owner = u.getOwner();
3123  if (owner == firstOp || owner == secondOp)
3124  continue;
3125  // TODO: this is too conservative, use dominance info in the future.
3126  if (owner->getBlock() == firstOp->getBlock() &&
3127  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3128  continue;
3129  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3130  << ", second op: " << *secondOp << "\n");
3131  return true;
3132  }
3133  }
3134  return false;
3135 }
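// Example (an illustrative sketch): with firstOp = the memref.copy and
// secondOp = the vector.transfer_read below, the intervening memref.store to
// %sv is an interleaved use of the subview, so this helper returns true and
// the forwarding patterns below bail out.
// ```
// memref.copy %in, %sv : ...
// memref.store %f, %sv[%c0] : ...
// %v = vector.transfer_read %alloc[%c0], %cst : memref<8xf32>, vector<8xf32>
// ```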
3136 
3137 /// Return the unique subview use of `v` if it is indeed unique, null
3138 /// otherwise.
3139 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3140  memref::SubViewOp subViewOp;
3141  for (auto &u : v.getUses()) {
3142  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3143  if (subViewOp)
3144  return memref::SubViewOp();
3145  subViewOp = newSubViewOp;
3146  }
3147  }
3148  return subViewOp;
3149 }
3150 
3151 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3152 /// when available.
3153 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3154  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3155 
3156  // TODO: support mask.
3157  if (xferOp.getMask())
3158  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3159 
3160  // Transfer into `view`.
3161  Value viewOrAlloc = xferOp.getBase();
3162  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3163  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3164  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3165 
3166  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3167  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3168  if (!subViewOp)
3169  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3170  Value subView = subViewOp.getResult();
3171 
3172  // Find the copy into `subView` without interleaved uses.
3173  memref::CopyOp copyOp;
3174  for (auto &u : subView.getUses()) {
3175  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3176  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3177  if (newCopyOp.getTarget() != subView)
3178  continue;
3179  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3180  continue;
3181  copyOp = newCopyOp;
3182  break;
3183  }
3184  }
3185  if (!copyOp)
3186  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3187 
3188  // Find the fill into `viewOrAlloc` without interleaved uses before the
3189  // copy.
3190  FillOp maybeFillOp;
3191  for (auto &u : viewOrAlloc.getUses()) {
3192  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3193  assert(isa<MemRefType>(newFillOp.output().getType()));
3194  if (newFillOp.output() != viewOrAlloc)
3195  continue;
3196  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3197  continue;
3198  maybeFillOp = newFillOp;
3199  break;
3200  }
3201  }
3202  // Ensure padding matches.
3203  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3204  return rewriter.notifyMatchFailure(xferOp,
3205  "padding value does not match fill");
3206 
3207  // `in` is the subview that memref.copy reads. Replace it.
3208  Value in = copyOp.getSource();
3209 
3210  // memref.copy + linalg.fill can be used to create a padded local buffer.
3211  // The `in_bounds` attribute is only valid on this padded buffer; when
3212  // forwarding to vector.transfer_read, it must be reset conservatively
3213  // (all dims potentially out-of-bounds).
3214  auto vectorType = xferOp.getVectorType();
3215  Value res = rewriter.create<vector::TransferReadOp>(
3216  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3217  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3218  rewriter.getBoolArrayAttr(
3219  SmallVector<bool>(vectorType.getRank(), false)));
3220 
3221  if (maybeFillOp)
3222  rewriter.eraseOp(maybeFillOp);
3223  rewriter.eraseOp(copyOp);
3224  rewriter.replaceOp(xferOp, res);
3225 
3226  return success();
3227 }
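// Example (a rough sketch): the fill/copy pair below builds a padded local
// buffer; the pattern forwards the read to the copy source, resets in_bounds
// conservatively, and erases the fill and the copy.
// ```
// Before:
//   linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
//   %sv = memref.subview %alloc[2, 2] [28, 28] [1, 1] : ...
//   memref.copy %in, %sv : ...
//   %v = vector.transfer_read %alloc[%c0, %c0], %cst
//          : memref<32x32xf32>, vector<32x32xf32>
// After:
//   %v = vector.transfer_read %in[%c0, %c0], %cst {in_bounds = [false, false]}
//          : memref<28x28xf32>, vector<32x32xf32>
// ```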
3228 
3229 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3230 /// when available.
3231 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3232  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3233  // TODO: support mask.
3234  if (xferOp.getMask())
3235  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3236 
3237  // Transfer into `viewOrAlloc`.
3238  Value viewOrAlloc = xferOp.getBase();
3239  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3240  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3241  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3242 
3243  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3244  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3245  if (!subViewOp)
3246  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3247  Value subView = subViewOp.getResult();
3248 
3249  // Find the copy from `subView` without interleaved uses.
3250  memref::CopyOp copyOp;
3251  for (auto &u : subViewOp.getResult().getUses()) {
3252  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3253  if (newCopyOp.getSource() != subView)
3254  continue;
3255  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3256  continue;
3257  copyOp = newCopyOp;
3258  break;
3259  }
3260  }
3261  if (!copyOp)
3262  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3263 
3264  // `out` is the subview that is copied into; it is what we replace.
3265  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3266  Value out = copyOp.getTarget();
3267 
3268  // Forward vector.transfer into copy.
3269  // memref.copy + linalg.fill can be used to create a padded local buffer.
3270  // The `in_bounds` attribute is only valid on this padded buffer; when
3271  // forwarding to vector.transfer_write, it must be reset conservatively
3272  // (all dims potentially out-of-bounds).
3273  auto vector = xferOp.getVector();
3274  rewriter.create<vector::TransferWriteOp>(
3275  xferOp.getLoc(), vector, out, xferOp.getIndices(),
3276  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3277  rewriter.getBoolArrayAttr(SmallVector<bool>(
3278  dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3279 
3280  rewriter.eraseOp(copyOp);
3281  rewriter.eraseOp(xferOp);
3282 
3283  return success();
3284 }
3285 
3286 //===----------------------------------------------------------------------===//
3287 // Convolution vectorization patterns
3288 //===----------------------------------------------------------------------===//
3289 
3290 template <int N>
3291 static void bindShapeDims(ShapedType shapedType) {}
3292 
3293 template <int N, typename IntTy, typename... IntTy2>
3294 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3295  val = shapedType.getShape()[N];
3296  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3297 }
3298 
3299 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3300 template <typename... IntTy>
3301 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3302  bindShapeDims<0>(shapedType, vals...);
3303 }
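// Example: for a ShapedType with shape {4, 16, 8},
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c);
// binds n = 4, w = 16 and c = 8.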
3304 
3305 namespace {
3306 /// Generate a vector implementation for either:
3307 /// ```
3308 /// Op def: ( w, kw )
3309 /// Iters: ({Par(), Red()})
3310 /// Layout: {{w + kw}, {kw}, {w}}
3311 /// ```
3312 /// kw is unrolled.
3313 ///
3314 /// or
3315 ///
3316 /// ```
3317 /// Op def: ( n, w, c, kw, f )
3318 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3319 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3320 /// ```
3321 /// kw is unrolled, w is unrolled iff dilationW > 1.
3322 ///
3323 /// or
3324 ///
3325 /// ```
3326 /// Op def: ( n, c, w, f, kw )
3327 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3328 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3329 /// ```
3330 /// kw is unrolled, w is unrolled iff dilationW > 1.
3331 ///
3332 /// or
3333 ///
3334 /// ```
3335 /// Op def: ( n, w, c, kw )
3336 /// Iters: ({Par(), Par(), Par(), Red()})
3337 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3338 /// ```
3339 /// kw is unrolled, w is unrolled iff dilationW > 1.
3340 struct Conv1DGenerator
3341  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3342  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3343  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3344 
3345  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3346  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3347  resShaped = linalgOp.getDpsInitOperand(0)->get();
3348  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3349  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3350  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3351 
3352  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3353  redOp = reduceOp->getName().getIdentifier();
3354 
3355  setConvOperationKind(reduceOp);
3356 
3357  auto maybeKind = getCombinerOpKind(reduceOp);
3358  reductionKind = maybeKind.value();
3359 
3360  // The ConvolutionOpInterface gives us guarantees of existence for
3361  // strides/dilations. However, we do not need to rely on those; we simply
3362  // use them if present, and otherwise fall back to the defaults and let the
3363  // generic conv matcher in the ConvGenerator succeed or fail.
3364  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3365  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3366  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3367  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3368  }
3369 
3370  /// Generate a vector implementation for:
3371  /// ```
3372  /// Op def: ( w, kw )
3373  /// Iters: ({Par(), Red()})
3374  /// Layout: {{w + kw}, {kw}, {w}}
3375  /// ```
3376  /// kw is always unrolled.
3377  ///
3378  /// or
3379  ///
3380  /// ```
3381  /// Op def: ( n, w, c, kw, f )
3382  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3383  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3384  /// ```
3385  /// kw is always unrolled.
3386  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3387  /// > 1.
3388  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3389  int64_t nSize, wSize, cSize, kwSize, fSize;
3390  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3391  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3392  switch (conv1DOpOrder) {
3393  case Conv1DOpOrder::W:
3394  // Initialize unused dimensions
3395  nSize = fSize = cSize = 0;
3396  // out{W}
3397  bindShapeDims(resShapedType, wSize);
3398  // kernel{kw}
3399  bindShapeDims(rhsShapedType, kwSize);
3400  lhsShape = {// iw = ow + kw - 1
3401  // (i.e. 16 convolved with 3 -> 14)
3402  (wSize + kwSize - 1)};
3403  rhsShape = {kwSize};
3404  resShape = {wSize};
3405  break;
3406  case Conv1DOpOrder::Nwc:
3407  // out{n, w, f}
3408  bindShapeDims(resShapedType, nSize, wSize, fSize);
3409  switch (oper) {
3410  case ConvOperationKind::Conv:
3411  // kernel{kw, c, f}
3412  bindShapeDims(rhsShapedType, kwSize, cSize);
3413  break;
3414  case ConvOperationKind::Pool:
3415  // kernel{kw}
3416  bindShapeDims(rhsShapedType, kwSize);
3417  cSize = fSize;
3418  break;
3419  }
3420  lhsShape = {nSize,
3421  // iw = ow * sw + kw * dw - 1
3422  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3423  // Perform the proper inclusive -> exclusive -> inclusive.
3424  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3425  1,
3426  cSize};
3427  switch (oper) {
3428  case ConvOperationKind::Conv:
3429  rhsShape = {kwSize, cSize, fSize};
3430  break;
3431  case ConvOperationKind::Pool:
3432  rhsShape = {kwSize};
3433  break;
3434  }
3435  resShape = {nSize, wSize, fSize};
3436  break;
3437  case Conv1DOpOrder::Ncw:
3438  // out{n, f, w}
3439  bindShapeDims(resShapedType, nSize, fSize, wSize);
3440  switch (oper) {
3441  case ConvOperationKind::Conv:
3442  // kernel{f, c, kw}
3443  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3444  break;
3445  case ConvOperationKind::Pool:
3446  // kernel{kw}
3447  bindShapeDims(rhsShapedType, kwSize);
3448  cSize = fSize;
3449  break;
3450  }
3451  lhsShape = {nSize, cSize,
3452  // iw = ow * sw + kw * dw - 1
3453  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3454  // Perform the proper inclusive -> exclusive -> inclusive.
3455  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3456  1};
3457  switch (oper) {
3458  case ConvOperationKind::Conv:
3459  rhsShape = {fSize, cSize, kwSize};
3460  break;
3461  case ConvOperationKind::Pool:
3462  rhsShape = {kwSize};
3463  break;
3464  }
3465  resShape = {nSize, fSize, wSize};
3466  break;
3467  }
3468 
3469  vector::TransferWriteOp write;
3470  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3471 
3472  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3473  // When strideW == 1, we can batch the contiguous loads and avoid
3474  // unrolling
3475  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3476 
3477  Type lhsEltType = lhsShapedType.getElementType();
3478  Type rhsEltType = rhsShapedType.getElementType();
3479  Type resEltType = resShapedType.getElementType();
3480  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3481  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3482  auto resType = VectorType::get(resShape, resEltType);
3483  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3484  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3485  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3486  SmallVector<Value> resPadding(resShape.size(), zero);
3487 
3488  // Read the whole lhs, rhs and res in one shot (with zero padding).
3489  Value lhs = rewriter.create<vector::TransferReadOp>(
3490  loc, lhsType, lhsShaped, lhsPadding,
3491  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3492  // This is needed only for Conv.
3493  Value rhs = nullptr;
3494  if (oper == ConvOperationKind::Conv)
3495  rhs = rewriter.create<vector::TransferReadOp>(
3496  loc, rhsType, rhsShaped, rhsPadding,
3497  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3498  Value res = rewriter.create<vector::TransferReadOp>(
3499  loc, resType, resShaped, resPadding,
3500  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3501 
3502  // The base vectorization case for channeled convolution is input:
3503  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3504  // vectorization case, we pre-transpose the input, weight, and output.
3505  switch (conv1DOpOrder) {
3506  case Conv1DOpOrder::W:
3507  case Conv1DOpOrder::Nwc:
3508  // Base case, so no transposes necessary.
3509  break;
3510  case Conv1DOpOrder::Ncw: {
3511  // To match base vectorization case, we pre-transpose current case.
3512  // ncw -> nwc
3513  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3514  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3515  // fcw -> wcf
3516  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3517 
3518  // This is needed only for Conv.
3519  if (oper == ConvOperationKind::Conv)
3520  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3521  // nfw -> nwf
3522  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3523  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3524  break;
3525  }
3526  }
3527 
3528  //===------------------------------------------------------------------===//
3529  // Begin vector-only rewrite part
3530  //===------------------------------------------------------------------===//
3531  // Unroll along kw and read slices of lhs and rhs.
3532  SmallVector<Value> lhsVals, rhsVals, resVals;
3533  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3534  kwSize, strideW, dilationW, wSizeStep,
3535  isSingleChanneled);
3536  // Do not do for pooling.
3537  if (oper == ConvOperationKind::Conv)
3538  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3539  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3540  wSizeStep, isSingleChanneled);
3541 
3542  auto linearIndex = [&](int64_t kw, int64_t w) {
3543  return kw * (wSize / wSizeStep) + w;
3544  };
3545 
3546  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3547  // or perform an outer product for non-channeled convolution, or a simple
3548  // arith operation for pooling.
3549  for (int64_t kw = 0; kw < kwSize; ++kw) {
3550  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3551  switch (oper) {
3552  case ConvOperationKind::Conv:
3553  if (isSingleChanneled) {
3554  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3555  lhsVals[linearIndex(kw, w)],
3556  rhsVals[kw], resVals[w]);
3557  } else {
3558  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3559  lhsVals[linearIndex(kw, w)],
3560  rhsVals[kw], resVals[w]);
3561  }
3562  break;
3563  case ConvOperationKind::Pool:
3564  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3565  resVals[w]);
3566  break;
3567  }
3568  }
3569  }
3570 
3571  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3572  isSingleChanneled);
3573  //===------------------------------------------------------------------===//
3574  // End vector-only rewrite part
3575  //===------------------------------------------------------------------===//
3576 
3577  // The base vectorization case for channeled convolution is output:
3578  // {n,w,f}. To reuse the result from the base-pattern vectorization case,
3579  // we post-transpose the base case result.
3580  switch (conv1DOpOrder) {
3581  case Conv1DOpOrder::W:
3582  case Conv1DOpOrder::Nwc:
3583  // Base case, so no transposes necessary.
3584  break;
3585  case Conv1DOpOrder::Ncw: {
3586  // nwf -> nfw
3587  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3588  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3589  break;
3590  }
3591  }
3592 
3593  return rewriter
3594  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3595  .getOperation();
3596  }
3597 
3598  // Take a value and widen it to have the same element type as `ty`.
3599  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3600  const Type srcElementType = getElementTypeOrSelf(val.getType());
3601  const Type dstElementType = getElementTypeOrSelf(ty);
3602  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3603  if (srcElementType == dstElementType)
3604  return val;
3605 
3606  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3607  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3608  const Type dstType =
3609  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3610 
3611  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3612  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3613  }
3614 
3615  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3616  srcWidth < dstWidth)
3617  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3618 
3619  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3620  srcWidth < dstWidth)
3621  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3622 
3623  assert(false && "unhandled promotion case");
3624  return nullptr;
3625  }
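// Example: promoting a vector<4xi8> value to match a vector<4xf32> result
// emits arith.sitofp; i8 -> i32 emits arith.extsi, and f16 -> f32 emits
// arith.extf.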
3626 
3627  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3628  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3629  Value lhs, Value rhs, Value res) {
3630  vector::IteratorType par = vector::IteratorType::parallel;
3631  vector::IteratorType red = vector::IteratorType::reduction;
3632  AffineExpr n, w, f, c;
3633  bindDims(ctx, n, w, f, c);
3634  lhs = promote(rewriter, loc, lhs, res.getType());
3635  rhs = promote(rewriter, loc, rhs, res.getType());
3636  auto contrationOp = rewriter.create<vector::ContractionOp>(
3637  loc, lhs, rhs, res,
3638  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3639  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3640  contrationOp.setKind(reductionKind);
3641  return contrationOp;
3642  }
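// Example (an illustrative sketch of the op built above, assuming f32
// operands and an `add` combiner):
// ```
// %res = vector.contract {
//          indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
//                           affine_map<(n, w, f, c) -> (c, f)>,
//                           affine_map<(n, w, f, c) -> (n, w, f)>],
//          iterator_types = ["parallel", "parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>} %lhs, %rhs, %acc
//        : vector<1x4x3xf32>, vector<3x8xf32> into vector<1x4x8xf32>
// ```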
3643 
3644  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3645  // convolution.
3646  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3647  Value lhs, Value rhs, Value res) {
3648  return rewriter.create<vector::OuterProductOp>(
3649  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3650  }
3651 
3652  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3653  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3654  Value res) {
3655  if (isPoolExt)
3656  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3657  return rewriter
3658  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3659  ->getResult(0);
3660  }
3661 
3662  /// Generate a vector implementation for:
3663  /// ```
3664  /// Op def: ( n, w, c, kw)
3665  /// Iters: ({Par(), Par(), Par(), Red()})
3666  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3667  /// ```
3668  /// kw is always unrolled.
3669  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3670  /// > 1.
3671  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3672  bool channelDimScalableFlag,
3673  bool flatten) {
3674  bool scalableChDim = false;
3675  bool useMasking = false;
3676  int64_t nSize, wSize, cSize, kwSize;
3677  // kernel{kw, c}
3678  bindShapeDims(rhsShapedType, kwSize, cSize);
3679  if (ShapedType::isDynamic(cSize)) {
3680  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3681  cSize = channelDimVecSize;
3682  // Scalable vectors are only used when both conditions are met:
3683  // 1. channel dim is dynamic
3684  // 2. channelDimScalableFlag is set
3685  scalableChDim = channelDimScalableFlag;
3686  useMasking = true;
3687  }
3688 
3689  assert(!(useMasking && flatten) &&
3690  "Unsupported flattened conv with dynamic shapes");
3691 
3692  // out{n, w, c}
3693  bindShapeDims(resShapedType, nSize, wSize);
3694 
3695  vector::TransferWriteOp write;
3696  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3697 
3698  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3699  // When strideW == 1, we can batch the contiguous loads and avoid
3700  // unrolling
3701  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3702 
3703  Type lhsEltType = lhsShapedType.getElementType();
3704  Type rhsEltType = rhsShapedType.getElementType();
3705  Type resEltType = resShapedType.getElementType();
3706  VectorType lhsType = VectorType::get(
3707  {nSize,
3708  // iw = ow * sw + kw * dw - 1
3709  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3710  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3711  cSize},
3712  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3713  VectorType rhsType =
3714  VectorType::get({kwSize, cSize}, rhsEltType,
3715  /*scalableDims=*/{false, scalableChDim});
3716  VectorType resType =
3717  VectorType::get({nSize, wSize, cSize}, resEltType,
3718  /*scalableDims=*/{false, false, scalableChDim});
3719 
3720  // Masks the input xfer Op along the channel dim when masking is required
3721  // (i.e. when the channel dim is dynamic).
3722  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3723  ArrayRef<bool> scalableDims,
3724  Operation *opToMask) {
3725  if (!useMasking)
3726  return opToMask;
3727  auto maskType =
3728  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3729 
3730  SmallVector<bool> inBounds(maskShape.size(), true);
3731  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3732  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3733  rewriter.getBoolArrayAttr(inBounds));
3734 
3735  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3736  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3737 
3738  Value maskOp =
3739  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3740 
3741  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3742  };
3743 
3744  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3745  // 0].
3746  Value lhs = rewriter.create<vector::TransferReadOp>(
3747  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3748  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3749  auto maybeMaskedLhs = maybeMaskXferOp(
3750  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3751 
3752  // Read rhs slice of size {kw, c} @ [0, 0].
3753  Value rhs = rewriter.create<vector::TransferReadOp>(
3754  loc, rhsType, rhsShaped, ValueRange{zero, zero},
3755  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3756  auto maybeMaskedRhs = maybeMaskXferOp(
3757  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3758 
3759  // Read res slice of size {n, w, c} @ [0, 0, 0].
3760  Value res = rewriter.create<vector::TransferReadOp>(
3761  loc, resType, resShaped, ValueRange{zero, zero, zero},
3762  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3763  auto maybeMaskedRes = maybeMaskXferOp(
3764  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3765 
3766  //===------------------------------------------------------------------===//
3767  // Begin vector-only rewrite part
3768  //===------------------------------------------------------------------===//
3769  // Unroll along kw and read slices of lhs and rhs.
3770  SmallVector<Value> lhsVals, rhsVals, resVals;
3771  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3772  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3773 
3774  // Extract lhs slice of size {n, wSizeStep, c}
3775  // @ [0, sw * w + dw * kw, 0].
3776  for (int64_t kw = 0; kw < kwSize; ++kw) {
3777  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3778  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3779  loc, maybeMaskedLhs->getResult(0),
3780  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3781  inOutSliceSizes, inOutStrides));
3782  }
3783  }
3784  // Extract rhs slice of size {c} @ [kw].
3785  for (int64_t kw = 0; kw < kwSize; ++kw) {
3786  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3787  loc, maybeMaskedRhs->getResult(0),
3788  /*offsets=*/ArrayRef<int64_t>{kw}));
3789  }
3790  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3791  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3792  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3793  loc, maybeMaskedRes->getResult(0),
3794  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3795  inOutStrides));
3796  }
3797 
3798  auto linearIndex = [&](int64_t kw, int64_t w) {
3799  return kw * (wSize / wSizeStep) + w;
3800  };
3801 
3802  // Note - the scalable flags are ignored as flattening combined with
3803  // scalable vectorization is not supported.
3804  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3805  auto lhsTypeAfterFlattening =
3806  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3807  auto resTypeAfterFlattening =
3808  VectorType::get(inOutFlattenSliceSizes, resEltType);
3809 
3810  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3811  for (int64_t kw = 0; kw < kwSize; ++kw) {
3812  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3813  Value lhsVal = lhsVals[linearIndex(kw, w)];
3814  Value resVal = resVals[w];
3815  if (flatten) {
3816  // Flatten the input and output vectors (collapse the channel
3817  // dimension)
3818  lhsVal = rewriter.create<vector::ShapeCastOp>(
3819  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3820  resVal = rewriter.create<vector::ShapeCastOp>(
3821  loc, resTypeAfterFlattening, resVals[w]);
3822  }
3823  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3824  rhsVals[kw], resVal, flatten);
3825  if (flatten) {
3826  // Un-flatten the output vector (restore the channel dimension)
3827  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3828  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3829  }
3830  }
3831  }
3832 
3833  // It's possible we failed to create the FMA.
3834  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3835  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3836  for (auto &collection :
3837  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3838  for (Value v : collection)
3839  rewriter.eraseOp(v.getDefiningOp());
3840  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3841  }
3842 
3843  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3844  // This does not depend on kw.
3845  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3846  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3847  loc, resVals[w], maybeMaskedRes->getResult(0),
3848  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3849  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3850  }
3851  //===------------------------------------------------------------------===//
3852  // End vector-only rewrite part
3853  //===------------------------------------------------------------------===//
3854 
3855  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3856  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3857  loc, maybeMaskedRes->getResult(0), resShaped,
3858  ValueRange{zero, zero, zero});
3859  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3860  resOut);
3861  }
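// Example (an illustrative sketch): with a dynamic channel dim and a scalable
// channel vector size of [4], the lhs read above ends up wrapped in
// vector.mask, e.g.:
// ```
// %mask = vector.create_mask %c1, %c16, %ch : vector<1x16x[4]xi1>
// %lhs  = vector.mask %mask {
//           vector.transfer_read %in[%c0, %c0, %c0], %cst
//             : memref<1x16x?xf32>, vector<1x16x[4]xf32>
//         } : vector<1x16x[4]xi1> -> vector<1x16x[4]xf32>
// ```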
3862 
3863  /// Lower:
3864  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3865  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3866  /// to MulAcc.
3867  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3868  Value lhs, Value rhs, Value res,
3869  bool flatten) {
3870  auto rhsTy = cast<ShapedType>(rhs.getType());
3871  auto resTy = cast<ShapedType>(res.getType());
3872 
3873  // TODO(suderman): Change this to use a vector.ima intrinsic.
3874  lhs = promote(rewriter, loc, lhs, resTy);
3875 
3876  if (flatten) {
3877  // NOTE: The following logic won't work for scalable vectors. For this
3878  // reason, "flattening" is not supported when shapes are dynamic (this
3879  // should be captured by one of the pre-conditions).
3880 
3881  // There are two options for handling the filter:
3882  // * shape_cast(broadcast(filter))
3883  // * broadcast(shuffle(filter))
3884  // Opt for the option without shape_cast to simplify the codegen.
3885  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3886  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3887 
3888  SmallVector<int64_t, 16> indices;
3889  for (int i = 0; i < resSize / rhsSize; ++i) {
3890  for (int j = 0; j < rhsSize; ++j)
3891  indices.push_back(j);
3892  }
3893 
3894  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3895  }
3896  // Broadcast the filter to match the output vector
3897  rhs = rewriter.create<vector::BroadcastOp>(
3898  loc, resTy.clone(rhsTy.getElementType()), rhs);
3899 
3900  rhs = promote(rewriter, loc, rhs, resTy);
3901 
3902  if (!lhs || !rhs)
3903  return nullptr;
3904 
3905  if (isa<FloatType>(resTy.getElementType()))
3906  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3907 
3908  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3909  return rewriter.create<arith::AddIOp>(loc, mul, res);
3910  }
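// Example (an illustrative sketch, flatten = false, f32): the filter slice is
// broadcast to the output shape and fused into an FMA.
// ```
// %bcast = vector.broadcast %rhs : vector<4xf32> to vector<1x8x4xf32>
// %res   = vector.fma %lhs, %bcast, %acc : vector<1x8x4xf32>
// ```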
3911 
3912  /// Entry point for non-channeled convolution:
3913  /// {{w + kw}, {kw}, {w}}
3914  FailureOr<Operation *> generateNonChanneledConv() {
3915  AffineExpr w, kw;
3916  bindDims(ctx, w, kw);
3917  if (!iters({Par(), Red()}))
3918  return rewriter.notifyMatchFailure(op,
3919  "failed to match conv::W 1-par 1-red");
3920 
3921  // No transposition needed.
3922  if (layout({/*lhsIndex*/ {w + kw},
3923  /*rhsIndex*/ {kw},
3924  /*resIndex*/ {w}}))
3925  return conv(Conv1DOpOrder::W);
3926 
3927  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3928  }
3929 
3930  /// Entry point that transposes into the common form:
3931  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3932  FailureOr<Operation *> generateNwcConv() {
3933  AffineExpr n, w, f, kw, c;
3934  bindDims(ctx, n, w, f, kw, c);
3935  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3936  return rewriter.notifyMatchFailure(
3937  op, "failed to match conv::Nwc 3-par 2-red");
3938 
3939  // No transposition needed.
3940  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3941  /*rhsIndex*/ {kw, c, f},
3942  /*resIndex*/ {n, w, f}}))
3943  return conv(Conv1DOpOrder::Nwc);
3944 
3945  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3946  }
3947 
3948  /// Entry point that transposes into the common form:
3949  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3950  FailureOr<Operation *> generateNcwConv() {
3951  AffineExpr n, w, f, kw, c;
3952  bindDims(ctx, n, f, w, c, kw);
3953  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3954  return rewriter.notifyMatchFailure(
3955  op, "failed to match conv::Ncw 3-par 2-red");
3956 
3957  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3958  /*rhsIndex*/ {f, c, kw},
3959  /*resIndex*/ {n, f, w}}))
3960  return conv(Conv1DOpOrder::Ncw);
3961 
3962  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3963  }
3964 
3965  /// Entry point that transposes into the common form:
3966  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3967  FailureOr<Operation *> generateNwcPooling() {
3968  AffineExpr n, w, c, kw;
3969  bindDims(ctx, n, w, c, kw);
3970  if (!iters({Par(), Par(), Par(), Red()}))
3971  return rewriter.notifyMatchFailure(op,
3972  "failed to match pooling 3-par 1-red");
3973 
3974  // No transposition needed.
3975  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3976  /*rhsIndex*/ {kw},
3977  /*resIndex*/ {n, w, c}}))
3978  return conv(Conv1DOpOrder::Nwc);
3979 
3980  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3981  }
3982 
3983  /// Entry point that transposes into the common form:
3984  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3985  FailureOr<Operation *> generateNcwPooling() {
3986  AffineExpr n, w, c, kw;
3987  bindDims(ctx, n, c, w, kw);
3988  if (!iters({Par(), Par(), Par(), Red()}))
3989  return rewriter.notifyMatchFailure(op,
3990  "failed to match pooling 3-par 1-red");
3991 
3992  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3993  /*rhsIndex*/ {kw},
3994  /*resIndex*/ {n, c, w}}))
3995  return conv(Conv1DOpOrder::Ncw);
3996 
3997  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3998  }
3999 
4000  /// Entry point that transposes into the common form:
4001  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4002  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4003  bool vecChDimScalableFlag = false,
4004  bool flatten = false) {
4005  AffineExpr n, w, c, kw;
4006  bindDims(ctx, n, w, c, kw);
4007  if (!iters({Par(), Par(), Par(), Red()}))
4008  return rewriter.notifyMatchFailure(
4009  op, "failed to match depthwise::Nwc conv 3-par 1-red");
4010 
4011  // No transposition needed.
4012  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4013  /*rhsIndex*/ {kw, c},
4014  /*resIndex*/ {n, w, c}}))
4015  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4016 
4017  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4018  }
4019 
4020 private:
4021  ConvOperationKind oper = ConvOperationKind::Conv;
4022  StringAttr redOp;
4023  StringAttr poolExtOp;
4024  bool isPoolExt = false;
4025  int strideW, dilationW;
4026  Value lhsShaped, rhsShaped, resShaped;
4027  ShapedType lhsShapedType, rhsShapedType, resShapedType;
4028  vector::CombiningKind reductionKind;
4029 
4030  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4031  void setConvOperationKind(Operation *reduceOp) {
4032  int numBlockArguments =
4033  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4034  if (numBlockArguments == 1) {
4035  // Will be convolution if feeder is a MulOp.
4036  // A strength-reduced version of MulOp for the i1 type is AndOp, which is
4037  // also supported. Otherwise, it can be pooling. This strength-reduction
4038  // logic is in the `buildBinaryFn` helper in the Linalg dialect.
4039  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4040  llvm::IsaPred<BlockArgument>);
4041  Operation *feedOp = (*feedValIt).getDefiningOp();
4042  if (isCastOfBlockArgument(feedOp)) {
4043  oper = ConvOperationKind::Pool;
4044  isPoolExt = true;
4045  poolExtOp = feedOp->getName().getIdentifier();
4046  return;
4047  }
4048  oper = ConvOperationKind::Conv;
4049  return;
4050  }
4051  // numBlockArguments == 2 and this is a pooling op.
4052  oper = ConvOperationKind::Pool;
4053  isPoolExt = false;
4054  }
4055 };
4056 } // namespace
4057 
4058 /// Helper function to vectorize a LinalgOp with convolution semantics.
4059 // TODO: extend the generic vectorization to support windows and drop this.
4060 static FailureOr<Operation *> vectorizeConvolution(
4061  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4062  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4063  Conv1DGenerator conv1dGen(rewriter, op);
4064  auto res = conv1dGen.generateNonChanneledConv();
4065  if (succeeded(res))
4066  return res;
4067  res = conv1dGen.generateNwcConv();
4068  if (succeeded(res))
4069  return res;
4070  res = conv1dGen.generateNcwConv();
4071  if (succeeded(res))
4072  return res;
4073  res = conv1dGen.generateNwcPooling();
4074  if (succeeded(res))
4075  return res;
4076  res = conv1dGen.generateNcwPooling();
4077  if (succeeded(res))
4078  return res;
4079 
4080  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4081  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4082  // masked/scalable) is the channel dim (i.e. the trailing dim).
4083  uint64_t vecChDimSize = ShapedType::kDynamic;
4084  bool vecChDimScalableFlag = false;
4085  if (!inputVecSizes.empty()) {
4086  // Only use the input vector size corresponding to the channel dim. Other
4087  // vector dims will be inferred from the Ops.
4088  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4089  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4090  "Not a 1D depthwise conv!");
4091  size_t chDimIdx =
4092  TypeSwitch<Operation *, size_t>(op)
4093  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4094  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4095 
4096  vecChDimSize = inputVecSizes[chDimIdx];
4097  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4098  }
4099  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4100  flatten1DDepthwiseConv);
4101 }
4102 
4103 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4104  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4105 
4106  LogicalResult matchAndRewrite(LinalgOp op,
4107  PatternRewriter &rewriter) const override {
4108  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4109  if (failed(resultOrFail))
4110  return failure();
4111  Operation *newOp = *resultOrFail;
4112  if (newOp->getNumResults() == 0) {
4113  rewriter.eraseOp(op.getOperation());
4114  return success();
4115  }
4116  assert(newOp->getNumResults() == 1 && "expected single result");
4117  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4118  return success();
4119  }
4120 };
4121 
4122 void mlir::linalg::populateConvolutionVectorizationPatterns(
4123  RewritePatternSet &patterns, PatternBenefit benefit) {
4124  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4125 }
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:131
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:138
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:600
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:157
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:641
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:265
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:665
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:97
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:218
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition: ArithOps.cpp:301
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:200
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:239
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:220
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
enum WinogradConv2DFmr uint32_t std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:651
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2799
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:808
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:715
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:332
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Transformation information returned after vectorizing.
Definition: Transforms.h:868
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.