MLIR  21.0.0git
Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58  ArrayRef<int64_t> inputVecSizes = {},
59  ArrayRef<bool> inputVecScalableFlags = {},
60  bool flatten1DDepthwiseConv = false);
61 
62 /// Vectorize tensor::InsertSliceOp with:
63 /// * vector::TransferReadOp + vector::TransferWriteOp
64 /// The vector sizes are either:
65 /// * user-provided in `inputVectorSizes`, or
66 /// * inferred from the static dims in the input and output tensors.
67 /// Bails out if:
68 /// * vector sizes are not user-provided, and
69 /// * at least one dim is dynamic (in both the input and output tensors).
70 ///
71 /// Before:
72 /// !t_in_type = tensor<1x2x3xf32>
73 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
74 /// !v_type = vector<1x2x3xf32>
75 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
76 /// into !t_out_type
77 /// After:
78 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
79 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
80 static LogicalResult
81 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
82  ArrayRef<int64_t> inputVectorSizes,
83  SmallVectorImpl<Value> &newResults);
84 
85 /// Returns the effective Pad value for the input op, provided it's a scalar.
86 ///
87 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
88 /// this Op performs padding, retrieve the padding value provided that it's
89 /// a scalar and static/fixed for all the padded values. Returns an empty value
90 /// otherwise.
91 static Value getStaticPadVal(Operation *op);
92 
93 /// Return the unique instance of OpType in `block` if it is indeed unique.
94 /// Return null if no instance or more than one instance exists.
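/// For example (illustrative): getSingleOpOfType<vector::TransferWriteOp>(block)
/// returns the only vector.transfer_write in `block`, or null when the block
/// contains zero or several of them.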
95 template <typename OpType>
96 static OpType getSingleOpOfType(Block &block) {
97  OpType res;
98  block.walk([&](OpType op) {
99  if (res) {
100  res = nullptr;
101  return WalkResult::interrupt();
102  }
103  res = op;
104  return WalkResult::advance();
105  });
106  return res;
107 }
108 
109 /// Helper function to extract the input slices after filter is unrolled along
110 /// kw.
111 static SmallVector<Value>
112 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
113                        int64_t nSize, int64_t wSize, int64_t cSize,
114  int64_t kwSize, int strideW, int dilationW,
115  int64_t wSizeStep, bool isSingleChanneled) {
116  SmallVector<Value> result;
117  if (isSingleChanneled) {
118  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
119  // convolution.
120  SmallVector<int64_t> sizes = {wSizeStep};
121  SmallVector<int64_t> strides = {1};
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  for (int64_t w = 0; w < wSize; w += wSizeStep) {
124  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
125  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
126  }
127  }
128  } else {
129  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
130  // for channeled convolution.
131  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
132  SmallVector<int64_t> strides = {1, 1, 1};
133  for (int64_t kw = 0; kw < kwSize; ++kw) {
134  for (int64_t w = 0; w < wSize; w += wSizeStep) {
135  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
136  loc, input,
137  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
138  sizes, strides));
139  }
140  }
141  }
142  return result;
143 }
144 
145 /// Helper function to extract the filter slices after filter is unrolled along
146 /// kw.
147 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
148                                                   Location loc, Value filter,
149  int64_t kwSize) {
150  SmallVector<Value> result;
151  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
152  // non-channeled convolution] @ [kw].
153  for (int64_t kw = 0; kw < kwSize; ++kw) {
154  result.push_back(rewriter.create<vector::ExtractOp>(
155  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
156  }
157  return result;
158 }
159 
160 /// Helper function to extract the result slices after filter is unrolled along
161 /// kw.
162 static SmallVector<Value>
163 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
164                         int64_t nSize, int64_t wSize, int64_t fSize,
165  int64_t wSizeStep, bool isSingleChanneled) {
166  SmallVector<Value> result;
167  if (isSingleChanneled) {
168  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
169  SmallVector<int64_t> sizes = {wSizeStep};
170  SmallVector<int64_t> strides = {1};
171  for (int64_t w = 0; w < wSize; w += wSizeStep) {
172  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
173  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
174  }
175  } else {
176  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
177  // convolution.
178  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
179  SmallVector<int64_t> strides = {1, 1, 1};
180  for (int64_t w = 0; w < wSize; w += wSizeStep) {
181  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
182  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
183  }
184  }
185  return result;
186 }
187 
188 /// Helper function to insert the computed result slices.
189 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190                                     Value res, int64_t wSize, int64_t wSizeStep,
191  SmallVectorImpl<Value> &resVals,
192  bool isSingleChanneled) {
193 
194  if (isSingleChanneled) {
195  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196  // This does not depend on kw.
197  SmallVector<int64_t> strides = {1};
198  for (int64_t w = 0; w < wSize; w += wSizeStep) {
199  res = rewriter.create<vector::InsertStridedSliceOp>(
200  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
201  }
202  } else {
203  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
204  // convolution. This does not depend on kw.
205  SmallVector<int64_t> strides = {1, 1, 1};
206  for (int64_t w = 0; w < wSize; w += wSizeStep) {
207  res = rewriter.create<vector::InsertStridedSliceOp>(
208  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
209  strides);
210  }
211  }
212  return res;
213 }
214 
215 /// Contains the vectorization state and related methods used across the
216 /// vectorization process of a given operation.
217 struct VectorizationState {
218   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
219 
220  /// Initializes the vectorization state, including the computation of the
221  /// canonical vector shape for vectorization.
222  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
223  ArrayRef<int64_t> inputVectorSizes,
224  ArrayRef<bool> inputScalableVecDims);
225 
226  /// Returns the canonical vector shape used to vectorize the iteration space.
227  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
228 
229  /// Returns the vector dimensions that are scalable in the canonical vector
230  /// shape.
231  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
232 
233  /// Returns a vector type of the provided `elementType` with the canonical
234  /// vector shape and the corresponding fixed/scalable dimensions bit. If
235  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
236  /// accordingly.
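 /// For example (illustrative): with a canonical vector shape of [8, 16] and
 /// no scalable dims, a `dimPermutation` of (d0, d1) -> (d1, d0) and an f32
 /// `elementType` yield vector<16x8xf32>.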
237   VectorType getCanonicalVecType(
238       Type elementType,
239       std::optional<AffineMap> dimPermutation = std::nullopt) const {
240     SmallVector<int64_t> vectorShape;
241     SmallVector<bool> scalableDims;
242  if (dimPermutation.has_value()) {
243  vectorShape =
244  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
245  scalableDims =
246  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
247  } else {
248  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
249  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
250  }
251 
252  return VectorType::get(vectorShape, elementType, scalableDims);
253  }
254 
255  /// Masks an operation with the canonical vector mask if the operation needs
256  /// masking. Returns the masked operation or the original operation if masking
257  /// is not needed. If provided, the canonical mask for this operation is
258  /// permuted using `maybeIndexingMap`.
259  Operation *
260  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
261  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
262 
263 private:
264  /// Initializes the iteration space static sizes using the Linalg op
265  /// information. This may become more complicated in the future.
266  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
267  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
268  }
269 
270  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
271  /// all the static and dynamic dimensions of the iteration space to be
272  /// vectorized and store them in `iterSpaceValueSizes`.
273  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
274  LinalgOp linalgOp);
275 
276  /// Create or retrieve an existing mask value to mask `opToMask` in the
277   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
278   /// mask is permuted using that permutation map. If a new mask is created, it
279   /// will be cached for future users.
280  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
281  LinalgOp linalgOp,
282  std::optional<AffineMap> maybeMaskingMap);
283 
284  /// Check whether this permutation map can be used for masking. At the
285  /// moment we only make sure that there are no broadcast dimensions, but this
286  /// might change if indexing maps evolve.
287  bool isValidMaskingMap(AffineMap maskingMap) {
288  return maskingMap.getBroadcastDims().size() == 0;
289  }
290 
291  /// Turn the input indexing map into a valid masking map.
292  ///
293  /// The input indexing map may contain "zero" results, e.g.:
294  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
295  /// Applying such maps to canonical vector shapes like this one:
296  /// (1, 16, 16, 4)
297  /// would yield an invalid vector shape like this:
298  /// (16, 16, 1, 0)
299  /// Instead, drop the broadcasting dims that make no sense for masking perm.
300  /// maps:
301  /// (d0, d1, d2, d3) -> (d2, d1, d0)
302  /// This way, the corresponding vector/mask type will be:
303  /// vector<16x16x1xty>
304  /// rather than this invalid Vector type:
305  /// vector<16x16x1x0xty>
306  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
307  return indexingMap.dropZeroResults();
308  }
309 
310  // Holds the compile-time static sizes of the iteration space to vectorize.
311  // Dynamic dimensions are represented using ShapedType::kDynamic.
312  SmallVector<int64_t> iterSpaceStaticSizes;
313 
314  /// Holds the value sizes of the iteration space to vectorize. Static
315  /// dimensions are represented by 'arith.constant' and dynamic
316  /// dimensions by 'tensor/memref.dim'.
317  SmallVector<Value> iterSpaceValueSizes;
318 
319  /// Holds the canonical vector shape used to vectorize the iteration space.
320  SmallVector<int64_t> canonicalVecShape;
321 
322  /// Holds the vector dimensions that are scalable in the canonical vector
323  /// shape.
324  SmallVector<bool> scalableVecDims;
325 
326  /// Holds the active masks for permutations of the canonical vector iteration
327  /// space.
328  DenseMap<AffineMap, Value> activeMaskCache;
329 
330  /// Global vectorization guard for the incoming rewriter. It's initialized
331  /// when the vectorization state is initialized.
332   OpBuilder::InsertionGuard rewriterGuard;
333 };
334 
335 LogicalResult
336 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
337  LinalgOp linalgOp) {
338  // TODO: Support 0-d vectors.
339  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
340  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
341  // Create constant index op for static dimensions.
342  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
343  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
344  continue;
345  }
346 
347  // Find an operand defined on this dimension of the iteration space to
348  // extract the runtime dimension size.
349  Value operand;
350  unsigned operandDimPos;
351  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
352  operandDimPos)))
353  return failure();
354 
355  Value dynamicDim = linalgOp.hasPureTensorSemantics()
356  ? (Value)rewriter.create<tensor::DimOp>(
357  linalgOp.getLoc(), operand, operandDimPos)
358  : (Value)rewriter.create<memref::DimOp>(
359  linalgOp.getLoc(), operand, operandDimPos);
360  iterSpaceValueSizes.push_back(dynamicDim);
361  }
362 
363  return success();
364 }
365 
366 /// Initializes the vectorization state, including the computation of the
367 /// canonical vector shape for vectorization.
368 // TODO: Move this to the constructor when we can remove the failure cases.
369 LogicalResult
370 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
371  ArrayRef<int64_t> inputVectorSizes,
372  ArrayRef<bool> inputScalableVecDims) {
373  // Initialize the insertion point.
374  rewriter.setInsertionPoint(linalgOp);
375 
376  if (!inputVectorSizes.empty()) {
377  // Get the canonical vector shape from the input vector sizes provided. This
378  // path should be taken to vectorize code with dynamic shapes and when using
379  // vector sizes greater than the iteration space sizes.
380  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
381  scalableVecDims.append(inputScalableVecDims.begin(),
382  inputScalableVecDims.end());
383  } else {
384  // Compute the canonical vector shape from the operation shape. If there are
385  // dynamic shapes, the operation won't be vectorized. We assume all the
386  // vector dimensions are fixed.
387  canonicalVecShape = linalgOp.getStaticLoopRanges();
388  scalableVecDims.append(linalgOp.getNumLoops(), false);
389  }
390 
391  LDBG("Canonical vector shape: ");
392  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
393  LLVM_DEBUG(llvm::dbgs() << "\n");
394  LDBG("Scalable vector dims: ");
395  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
396  LLVM_DEBUG(llvm::dbgs() << "\n");
397 
398  if (ShapedType::isDynamicShape(canonicalVecShape))
399  return failure();
400 
401  // Initialize iteration space static sizes.
402  initIterSpaceStaticSizes(linalgOp);
403 
404  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
405  // all the static and dynamic dimensions of the iteration space, needed to
406  // compute a mask during vectorization.
407  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
408  return failure();
409 
410  return success();
411 }
412 
413 /// Create or retrieve an existing mask value to mask `opToMask` in the
414 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
415 /// is permuted using that permutation map. If a new mask is created, it will be
416 /// cached for future users.
417 Value VectorizationState::getOrCreateMaskFor(
418  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
419  std::optional<AffineMap> maybeMaskingMap) {
420 
421  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
422  "Ill-formed masking map.");
423 
424  // No mask is needed if the operation is not maskable.
425  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
426  if (!maskableOp)
427  return Value();
428 
429  assert(!maskableOp.isMasked() &&
430  "Masking an operation that is already masked");
431 
432  // If no masking map was provided, use an identity map with the loop dims.
433  assert((!maybeMaskingMap || *maybeMaskingMap) &&
434  "Unexpected null mask permutation map");
435  AffineMap maskingMap =
436       maybeMaskingMap ? *maybeMaskingMap
437                       : AffineMap::getMultiDimIdentityMap(
438                             linalgOp.getNumLoops(), rewriter.getContext());
439 
440  LDBG("Masking map: " << maskingMap << "\n");
441 
442  // Return the active mask for the masking map of this operation if it was
443  // already created.
444  auto activeMaskIt = activeMaskCache.find(maskingMap);
445  if (activeMaskIt != activeMaskCache.end()) {
446  Value mask = activeMaskIt->second;
447  LDBG("Reusing mask: " << mask << "\n");
448  return mask;
449  }
450 
451  // Compute permuted projection of the iteration space to be masked and the
452  // corresponding mask shape. If the resulting iteration space dimensions are
453  // static and identical to the mask shape, masking is not needed for this
454  // operation.
455  // TODO: Improve this check. Only projected permutation indexing maps are
456  // supported.
457  SmallVector<int64_t> permutedStaticSizes =
458  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
459  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
460  auto maskShape = maskType.getShape();
461 
462  LDBG("Mask shape: ");
463  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
464  LLVM_DEBUG(llvm::dbgs() << "\n");
465 
466  if (permutedStaticSizes == maskShape) {
467  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
468  activeMaskCache[maskingMap] = Value();
469  return Value();
470  }
471 
472  // Permute the iteration space value sizes to compute the mask upper bounds.
473  SmallVector<Value> upperBounds =
474  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
475  assert(!maskShape.empty() && !upperBounds.empty() &&
476  "Masked 0-d vectors are not supported yet");
477 
478  // Create the mask based on the dimension values.
479  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
480  maskType, upperBounds);
481  LDBG("Creating new mask: " << mask << "\n");
482  activeMaskCache[maskingMap] = mask;
483  return mask;
484 }
485 
486 Operation *
487 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
488                                   LinalgOp linalgOp,
489  std::optional<AffineMap> maybeIndexingMap) {
490  LDBG("Trying to mask: " << *opToMask << "\n");
491 
492  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
493  if (maybeIndexingMap)
494  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
495 
496  // Create or retrieve mask for this operation.
497  Value mask =
498  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
499 
500  if (!mask) {
501  LDBG("No mask required\n");
502  return opToMask;
503  }
504 
505  // Wrap the operation with a new `vector.mask` and update D-U chain.
506  assert(opToMask && "Expected a valid operation to mask");
507  auto maskOp = cast<vector::MaskOp>(
508  mlir::vector::maskOperation(rewriter, opToMask, mask));
509  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
510 
511  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
512  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
513  maskOpTerminator);
514 
515  LDBG("Masked operation: " << *maskOp << "\n");
516  return maskOp;
517 }
518 
519 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
520 /// projectedPermutation, compress the unused dimensions to serve as a
521 /// permutation_map for a vector transfer operation.
522 /// For example, given a linalg op such as:
523 ///
524 /// ```
525 /// %0 = linalg.generic {
526 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
527 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
528 /// }
529 /// ins(%0 : tensor<2x3x4xf32>)
530 /// outs(%1 : tensor<5x6xf32>)
531 /// ```
532 ///
533 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
534 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
535 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
536 static AffineMap reindexIndexingMap(AffineMap map) {
537   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
538  "expected projected permutation");
539  auto res = compressUnusedDims(map);
540  assert(res.getNumDims() ==
541  (res.getNumResults() - res.getNumOfZeroResults()) &&
542  "expected reindexed map with same number of dims and results");
543  return res;
544 }
545 
546 /// Helper enum to represent conv1d input traversal order.
547 enum class Conv1DOpOrder {
548  W, // Corresponds to non-channeled 1D convolution operation.
549  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
550  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
551 };
552 
553 /// Helper data structure to represent the result of vectorization.
554 /// In certain specific cases, like terminators, we do not want to propagate or replace the results ourselves; the custom hook handles that.
555 enum VectorizationStatus {
556   /// Op failed to vectorize.
557   Failure = 0,
558   /// Op vectorized and custom function took care of replacement logic
559   NoReplace,
560   /// Op vectorized into a new Op whose results will replace original Op's
561   /// results.
562   NewOp
563   // TODO: support values if Op vectorized to Many-Ops whose results we need to
564   // aggregate for replacement.
565 };
566 struct VectorizationResult {
567   /// Return status from vectorizing the current op.
568   enum VectorizationStatus status = VectorizationStatus::Failure;
569   /// New vectorized operation to replace the current op.
570   /// Replacement behavior is specified by `status`.
571   Operation *newOp = nullptr;
572 };
573 
574 std::optional<vector::CombiningKind>
575 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
576   using ::mlir::vector::CombiningKind;
577 
578  if (!combinerOp)
579  return std::nullopt;
580   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
581       .Case<arith::AddIOp, arith::AddFOp>(
582  [&](auto op) { return CombiningKind::ADD; })
583  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
584  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
585  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
586  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
587  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
588  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
589  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
590  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
591  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
592  .Case<arith::MulIOp, arith::MulFOp>(
593  [&](auto op) { return CombiningKind::MUL; })
594  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
595  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
596  .Default([&](auto op) { return std::nullopt; });
597 }
598 
599 /// Check whether `outputOperand` is a reduction with a single combiner
600 /// operation. Return the combiner operation of the reduction. Return
601 /// nullptr otherwise. Multiple reduction operations would impose an
602 /// ordering between reduction dimensions and is currently unsupported in
603 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
604 /// max(min(X))
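/// For example (illustrative): for a linalg.generic whose region combines the
/// output block argument with arith.addf before yielding, this returns that
/// arith.addf op.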
605 // TODO: use in LinalgOp verification, there is a circular dependency atm.
606 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
607  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
608  unsigned outputPos =
609  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
610  // Only single combiner operations are supported for now.
611  SmallVector<Operation *, 4> combinerOps;
612  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
613  combinerOps.size() != 1)
614  return nullptr;
615 
616  // Return the combiner operation.
617  return combinerOps[0];
618 }
619 
620 /// Broadcast `value` to a vector of `dstType` if possible. Return `value`
621 /// otherwise.
622 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
623  auto dstVecType = dyn_cast<VectorType>(dstType);
624  // If no shape to broadcast to, just return `value`.
625  if (dstVecType.getRank() == 0)
626  return value;
627  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
628       vector::BroadcastableToResult::Success)
629     return value;
630  Location loc = b.getInsertionPoint()->getLoc();
631  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
632 }
633 
634 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
635 /// assumes that `reductionOp` has two operands and one of them is the reduction
636 /// initial value.
637 // Note: this is a true builder that notifies the OpBuilder listener.
638 // TODO: Consider moving as a static helper on the ReduceOp.
639 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
640                                       Value valueToReduce, Value acc,
641  ArrayRef<bool> dimsToMask) {
642  auto maybeKind = getCombinerOpKind(reduceOp);
643  assert(maybeKind && "Failed precondition: could not get reduction kind");
644  return b.create<vector::MultiDimReductionOp>(
645  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
646 }
647 
648 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
649  return llvm::to_vector(
650  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
651 }
652 
653 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
654 /// reduction iterator.
655 static bool hasReductionIterator(LinalgOp &op) {
656  return isa<linalg::ReduceOp>(op) ||
657  (isa<linalg::GenericOp>(op) &&
658  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
659 }
660 
661 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
662 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
663 /// currently being vectorized. If `dest` has null rank, build a memref.store.
664 /// Return the produced value or null if no value is produced.
665 // Note: this is a true builder that notifies the OpBuilder listener.
666 // TODO: Consider moving as a static helper on the ReduceOp.
667 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
668  OpOperand *outputOperand,
669  VectorizationState &state) {
670  Location loc = value.getLoc();
671  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
672  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
673 
674  // Compute the vector type of the value to store. This type should be an
675  // identity or projection of the canonical vector type without any permutation
676  // applied, given that any permutation in a transfer write happens as part of
677  // the write itself.
678   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
679       opOperandMap.getContext(), opOperandMap.getNumInputs(),
680  [&](AffineDimExpr dimExpr) -> bool {
681  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
682  });
683  auto vectorType = state.getCanonicalVecType(
684  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
685 
686  Operation *write;
687  if (vectorType.getRank() > 0) {
688  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
689  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
690  rewriter.create<arith::ConstantIndexOp>(loc, 0));
691  value = broadcastIfNeeded(rewriter, value, vectorType);
692  assert(value.getType() == vectorType && "Incorrect type");
693  write = rewriter.create<vector::TransferWriteOp>(
694  loc, value, outputOperand->get(), indices, writeMap);
695  } else {
696  // 0-d case is still special: do not invert the reindexing writeMap.
697  if (!isa<VectorType>(value.getType()))
698  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
699  assert(value.getType() == vectorType && "Incorrect type");
700  write = rewriter.create<vector::TransferWriteOp>(
701  loc, value, outputOperand->get(), ValueRange{});
702  }
703 
704  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
705 
706  // If masked, set in-bounds to true. Masking guarantees that the access will
707  // be in-bounds.
708  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
709  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
710  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
711  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
712  }
713 
714  LDBG("vectorized op: " << *write << "\n");
715  if (!write->getResults().empty())
716  return write->getResult(0);
717  return Value();
718 }
719 
720 // Custom vectorization precondition function type. This is intended to be used
721 // with CustomVectorizationHook. Returns success if the corresponding custom
722 // hook can vectorize the op.
723 using CustomVectorizationPrecondition =
724     std::function<LogicalResult(Operation *, bool)>;
725 
726 // Custom vectorization function type. Produce a vector form of Operation*
727 // assuming all its vectorized operands are already in the IRMapping.
728 // Return nullptr if the Operation cannot be vectorized.
729 using CustomVectorizationHook =
730     std::function<VectorizationResult(Operation *, const IRMapping &)>;
731 
732 /// Helper function to vectorize the terminator of a `linalgOp`. New result
733 /// vector values are appended to `newResults`. Return
734 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
735 /// should not try to map produced operations and instead return the results
736 /// using the `newResults` vector making them available to the vectorization
737 /// algorithm for RAUW. This function is meant to be used as a
738 /// CustomVectorizationHook.
739 static VectorizationResult
740 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
741                      const IRMapping &bvm, VectorizationState &state,
742  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
743  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
744  if (!yieldOp)
745     return VectorizationResult{VectorizationStatus::Failure, nullptr};
746   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
747  // TODO: Scan for an opportunity for reuse.
748  // TODO: use a map.
749  Value vectorValue = bvm.lookup(output.value());
750  Value newResult =
751  buildVectorWrite(rewriter, vectorValue,
752  linalgOp.getDpsInitOperand(output.index()), state);
753  if (newResult)
754  newResults.push_back(newResult);
755  }
756 
757   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
758 }
759 
760 /// Helper function to vectorize the index operations of a `linalgOp`. Return
761 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
762 /// should map the produced operations. This function is meant to be used as a
763 /// CustomVectorizationHook.
764 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
765                                                 VectorizationState &state,
766  Operation *op,
767  LinalgOp linalgOp) {
768  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
769  if (!indexOp)
770     return VectorizationResult{VectorizationStatus::Failure, nullptr};
771   auto loc = indexOp.getLoc();
772  // Compute the static loop sizes of the index op.
773  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
774  auto dim = indexOp.getDim();
775  // Compute a one-dimensional index vector for the index op dimension.
776  auto indexVectorType =
777  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
778  state.getScalableVecDims()[dim]);
779  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
780  // Return the one-dimensional index vector if it lives in the trailing
781  // dimension of the iteration space since the vectorization algorithm in this
782  // case can handle the broadcast.
783  if (dim == targetShape.size() - 1)
784     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
785   // Otherwise permute the targetShape to move the index dimension last,
786  // broadcast the one-dimensional index vector to the permuted shape, and
787  // finally transpose the broadcasted index vector to undo the permutation.
788  auto permPattern =
789  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
790  std::swap(permPattern[dim], permPattern.back());
791  auto permMap =
792  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
793 
794  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
795  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
796  indexSteps);
797  SmallVector<int64_t> transposition =
798  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
799  std::swap(transposition.back(), transposition[dim]);
800  auto transposeOp =
801  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
802  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
803 }
804 
805 /// Helper function to check if the tensor.extract can be vectorized by the
806 /// custom hook vectorizeTensorExtract.
807 static LogicalResult
808 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
809  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
810  if (!extractOp)
811  return failure();
812 
813  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
814  return failure();
815 
816  // Check the index type, but only for non 0-d tensors (for which we do need
817  // access indices).
818  if (not extractOp.getIndices().empty()) {
819  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
820  return failure();
821  }
822 
823  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
824  return !VectorType::isValidElementType(type);
825  })) {
826  return failure();
827  }
828 
829  return success();
830 }
831 
832 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
833 /// generated from `tensor.extract`. The offset is calculated as follows
834 /// (example using scalar values):
835 ///
836 /// offset = extractOp.indices[0]
837 /// for (i = 1; i < numIndices; i++)
838 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
839 ///
840 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
841 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
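/// i.e. offset = (1 * 80 + 2) * 15 + 3 = 1233, which is the row-major
/// linearization of [1, 2, 3] in a 45x80x15 tensor (the leading dim size, 45,
/// never enters the computation).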
842 static Value calculateGatherOffset(RewriterBase &rewriter,
843                                    VectorizationState &state,
844  tensor::ExtractOp extractOp,
845  const IRMapping &bvm) {
846  // The vector of indices for GatherOp should be shaped as the output vector.
847  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
848  auto loc = extractOp.getLoc();
849 
850  Value offset = broadcastIfNeeded(
851  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
852 
853  const size_t numIndices = extractOp.getIndices().size();
854  for (size_t i = 1; i < numIndices; i++) {
855  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
856 
857  auto dimSize = broadcastIfNeeded(
858  rewriter,
859  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
860  indexVecType);
861 
862  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
863 
864  auto extractOpIndex = broadcastIfNeeded(
865  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
866 
867  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
868  }
869 
870  return offset;
871 }
872 
873 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
874 
875 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
876 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
877 /// represents a contiguous load operation.
878 ///
879 /// Note that when calling this hook, it is assumed that the output vector is
880 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
881 /// labelled as a gather load before entering this method.
882 ///
883 /// Following on from the above, it is assumed that:
884 /// * for statically shaped loops, when no masks are used, only one dim is !=
885 /// 1 (that's what the shape of the output vector is based on).
886 /// * for dynamically shaped loops, there might be more non-unit dims
887 /// as the output vector type is user-specified.
888 ///
889 /// TODO: Statically shaped loops + vector masking
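/// For example (illustrative): for static loop ranges (1, 8, 1) this returns 1,
/// i.e. the index of the trailing non-unit loop dimension.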
890 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
891  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
892  assert(
893  (linalgOp.hasDynamicShape() ||
894  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
895  "For statically shaped Linalg Ops, only one "
896  "non-unit loop dim is expected");
897  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
898 
899  size_t idx = loopRanges.size() - 1;
900  for (; idx != 0; idx--)
901  if (loopRanges[idx] != 1)
902  break;
903 
904  return idx;
905 }
906 
907 /// Checks whether `val` can be used for calculating a loop invariant index.
908 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
909  VectorType resType) {
910 
911  assert(((llvm::count_if(resType.getShape(),
912  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
913  "n-D vectors are not yet supported");
914 
915  // Blocks outside _this_ linalg.generic are effectively loop invariant.
916  // However, analysing block arguments for _this_ linalg.generic Op is a bit
917  // tricky. Just bail out in the latter case.
918  // TODO: We could try analysing the corresponding affine map here.
919  auto *block = linalgOp.getBlock();
920  if (isa<BlockArgument>(val))
921  return llvm::all_of(block->getArguments(),
922  [&val](Value v) { return (v != val); });
923 
924  Operation *defOp = val.getDefiningOp();
925  assert(defOp && "This is neither a block argument nor an operation result");
926 
927  // IndexOp is loop invariant as long as its result remains constant across
928  // iterations. Note that for dynamic shapes, the corresponding dim will also
929  // be conservatively treated as != 1.
930  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
931  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
932  }
933 
934  auto *ancestor = block->findAncestorOpInBlock(*defOp);
935 
936   // Values defined outside `linalgOp` are loop invariant.
937  if (!ancestor)
938  return true;
939 
940  // Values defined inside `linalgOp`, which are constant, are loop invariant.
941  if (isa<arith::ConstantOp>(ancestor))
942  return true;
943 
944  bool result = true;
945  for (auto op : ancestor->getOperands())
946  result &= isLoopInvariantIdx(linalgOp, op, resType);
947 
948  return result;
949 }
950 
951 /// Check whether `val` could be used for calculating the trailing index for a
952 /// contiguous load operation.
953 ///
954 /// There are currently 3 types of values that are allowed here:
955 /// 1. loop-invariant values,
956 /// 2. values that increment by 1 with every loop iteration,
957 /// 3. results of basic arithmetic operations (linear and continuous)
958 /// involving 1., 2. and 3.
959 /// This method returns True if indeed only such values are used in calculating
960 /// `val.`
961 ///
962 /// Additionally, the trailing index for a contiguous load operation should
963 /// increment by 1 with every loop iteration, i.e. be based on:
964 /// * `linalg.index <dim>` ,
965 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
966 /// `linalg.index <dim>` increments by 1 with every loop iteration).
967 /// `foundIndexOp` is updated to `true` when such Op is found.
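/// For example (illustrative): an index computed as
///   %i   = linalg.index 1 : index
///   %off = arith.addi %i, %c4 : index
/// qualifies (the stride of 1 is preserved), whereas scaling the index, e.g.
/// via arith.muli, is conservatively rejected.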
968 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
969  bool &foundIndexOp, VectorType resType) {
970 
971  assert(((llvm::count_if(resType.getShape(),
972  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
973  "n-D vectors are not yet supported");
974 
975  // Blocks outside _this_ linalg.generic are effectively loop invariant.
976  // However, analysing block arguments for _this_ linalg.generic Op is a bit
977  // tricky. Just bail out in the latter case.
978  // TODO: We could try analysing the corresponding affine map here.
979  auto *block = linalgOp.getBlock();
980  if (isa<BlockArgument>(val))
981  return llvm::all_of(block->getArguments(),
982  [&val](Value v) { return (v != val); });
983 
984  Operation *defOp = val.getDefiningOp();
985  assert(defOp && "This is neither a block argument nor an operation result");
986 
987  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
988  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
989 
990  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
991  return true;
992  }
993 
994  auto *ancestor = block->findAncestorOpInBlock(*defOp);
995 
996  if (!ancestor)
997  return false;
998 
999  // Conservatively reject Ops that could lead to indices with stride other
1000  // than 1.
1001  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1002  return false;
1003 
1004  bool result = false;
1005  for (auto op : ancestor->getOperands())
1006  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1007 
1008  return result;
1009 }
1010 
1011 /// Infer the memory access pattern for the input ExtractOp
1012 ///
1013 /// Based on the ExtractOp result shape and the access indices, decides whether
1014 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1015 /// or a gather load. When analysing the ExtractOp indices (to identify
1016 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1017 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1018 /// Op).
1019 ///
1020 /// Note that it is always safe to use gather load operations for contiguous
1021 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1022 /// that `extractOp` is a gather load.
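/// Illustrative example (hypothetical IR): with an output vector of type
/// vector<1x4xi32>, indices computed as
///   %c0 = arith.constant 0 : index   // loop-invariant leading index
///   %i  = linalg.index 1 : index     // trailing non-unit loop dim
///   %e  = tensor.extract %src[%c0, %i] : tensor<3x8xi32>
/// are classified as a contiguous load, whereas using %i as the leading index
/// would force a gather load.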
1023 static VectorMemoryAccessKind
1024 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1025  LinalgOp &linalgOp, VectorType resType) {
1026 
1027  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1028 
1029  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1030  if (inputShape.getShape().empty())
1031     return VectorMemoryAccessKind::ScalarBroadcast;
1032 
1033  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1034  // otherwise.
1035  bool isOutput1DVector =
1036  (llvm::count_if(resType.getShape(),
1037  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1038  // 1. Assume that it's a gather load when reading non-1D vector.
1039  if (!isOutput1DVector)
1040     return VectorMemoryAccessKind::Gather;
1041 
1042  bool leadingIdxsLoopInvariant = true;
1043 
1044  // 2. Analyze the leading indices of `extractOp`.
1045  // Look at the way each index is calculated and decide whether it is suitable
1046  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1047  // gather load.
1048  auto indices = extractOp.getIndices();
1049  auto leadIndices = indices.drop_back(1);
1050 
1051  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1052  if (inputShape.getShape()[i] == 1)
1053  continue;
1054 
1055  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1056  }
1057 
1058  if (!leadingIdxsLoopInvariant) {
1059  LDBG("Found gather load: " << extractOp);
1060     return VectorMemoryAccessKind::Gather;
1061   }
1062 
1063  // 3. Analyze the trailing index for `extractOp`.
1064  // At this point we know that the leading indices are loop invariant. This
1065  // means that is potentially a scalar or a contiguous load. We can decide
1066  // based on the trailing idx.
1067  auto extractOpTrailingIdx = indices.back();
1068 
1069  // 3a. Scalar broadcast load
1070  // If the trailing index is loop invariant then this is a scalar load.
1071  if (leadingIdxsLoopInvariant &&
1072  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1073  LDBG("Found scalar broadcast load: " << extractOp);
1074 
1075     return VectorMemoryAccessKind::ScalarBroadcast;
1076   }
1077 
1078  // 3b. Contiguous loads
1079  // The trailing `extractOp` index should increment with every loop iteration.
1080  // This effectively means that it must be based on the trailing loop index.
1081  // This is what the following bool captures.
1082  bool foundIndexOp = false;
1083  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1084  foundIndexOp, resType);
1085  // TODO: Support generating contiguous loads for column vectors - that will
1086   // require adding a permutation map to transfer_read Ops.
1087  bool isRowVector = resType.getShape().back() != 1;
1088  isContiguousLoad &= (foundIndexOp && isRowVector);
1089 
1090  if (isContiguousLoad) {
1091     LDBG("Found contiguous load: " << extractOp);
1092     return VectorMemoryAccessKind::Contiguous;
1093   }
1094 
1095  // 4. Fallback case - gather load.
1096  LDBG("Found gather load: " << extractOp);
1097   return VectorMemoryAccessKind::Gather;
1098 }
1099 
1100 /// Helper function to vectorize the tensor.extract operations. Returns
1101 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1102 /// should map the produced operations. This function is meant to be used as a
1103 /// CustomVectorizationHook.
1104 static VectorizationResult
1105 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1106                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1107  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1108  if (!extractOp)
1109     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1110   auto loc = extractOp.getLoc();
1111 
1112  // Compute the static loop sizes of the extract op.
1113  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1114  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1115  loc,
1116  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1117  /*value=*/true));
1118  auto passThruConstantOp =
1119  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1120 
1121  // Base indices are currently set to 0. We will need to re-visit if more
1122  // generic scenarios are to be supported.
1123  SmallVector<Value> baseIndices(
1124  extractOp.getIndices().size(),
1125  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1126 
1127  VectorMemoryAccessKind memAccessKind =
1128  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1129 
1130  // 1. Handle gather access
1131  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1132  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1133 
1134  // Generate the gather load
1135  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1136  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1137  maskConstantOp, passThruConstantOp);
1138  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1139 
1140  LDBG("Vectorised as gather load: " << extractOp << "\n");
1141     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1142   }
1143 
1144  // 2. Handle:
1145  // a. scalar loads + broadcast,
1146  // b. contiguous loads.
1147  // Both cases use vector.transfer_read.
1148 
1149  // Collect indices for `vector.transfer_read`. At this point, the indices will
1150  // either be scalars or would have been broadcast to vectors matching the
1151  // result type. For indices that are vectors, there are two options:
1152  // * for non-trailing indices, all elements are identical (contiguous
1153  // loads are identified by looking for non-trailing indices that are
1154  // invariant with respect to the corresponding linalg.generic), or
1155  // * for trailing indices, the index vector will contain values with stride
1156  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1157  // needed.
1158  // This means that
1159  // * for scalar indices - just re-use it,
1160  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1161  // (0th) element and use that.
1162  SmallVector<Value> transferReadIdxs;
1163  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1164  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1165  if (idx.getType().isIndex()) {
1166  transferReadIdxs.push_back(idx);
1167  continue;
1168  }
1169 
1170  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1171  loc,
1172  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1173  resultType.getScalableDims().back()),
1174  idx);
1175  transferReadIdxs.push_back(
1176  rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1177  }
1178 
1179  // `tensor.extract_element` is always in-bounds, hence the following holds.
1180  auto dstRank = resultType.getRank();
1181  auto srcRank = extractOp.getTensor().getType().getRank();
1182  SmallVector<bool> inBounds(dstRank, true);
1183 
1184  // 2a. Handle scalar broadcast access.
1185  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1186  MLIRContext *ctx = rewriter.getContext();
1187  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1188  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1189 
1190  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1191  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1192  permutationMap, inBounds);
1193 
1194  // Mask this broadcasting xfer_read here rather than relying on the generic
1195  // path (the generic path assumes identity masking map, which wouldn't be
1196  // valid here).
1197  SmallVector<int64_t> readMaskShape = {1};
1198  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1199  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
1200  loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1201  auto *maskedReadOp =
1202  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1203 
1204  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1205  return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1206  }
1207 
1208  // 2b. Handle contiguous access.
1209  auto permutationMap = AffineMap::getMinorIdentityMap(
1210  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1211 
1212  int32_t rankDiff = dstRank - srcRank;
1213  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1214  // dims so that the ranks match. This is done by extending the map with 0s.
1215  // For example, for dstRank = 3, srcRank = 2, the following map created
1216  // above:
1217  // (d0, d1) --> (d0, d1)
1218  // is extended as:
1219  // (d0, d1) --> (0, d0, d1)
1220  while (rankDiff > 0) {
1221  permutationMap = permutationMap.insertResult(
1222  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1223  rankDiff--;
1224  }
1225 
1226  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1227  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1228  inBounds);
1229 
1230  LDBG("Vectorised as contiguous load: " << extractOp);
1231  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1232 }
1233 
1234 /// Emit reduction operations if the shape of the value to reduce differs from
1235 /// the result shape.
1236 // Note: this is a true builder that notifies the OpBuilder listener.
1237 // TODO: Consider moving as a static helper on the ReduceOp.
1238 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1239  Value reduceValue, Value initialValue,
1240  const IRMapping &bvm) {
1241  Value reduceVec = bvm.lookup(reduceValue);
1242  Value outputVec = bvm.lookup(initialValue);
1243  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1244  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1245   // Reduce only if needed as the value may already have been reduced for
1246  // contraction vectorization.
1247  if (!reduceType ||
1248  (outputType && reduceType.getShape() == outputType.getShape()))
1249  return nullptr;
1250  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1251  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1252 }
1253 
1254 /// Generic vectorization for a single operation `op`, given already vectorized
1255 /// operands carried by `bvm`. Vectorization occurs as follows:
1256 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1257 /// result on success.
1258 /// 2. Clone any constant in the current scope without vectorization: each
1259 /// consumer of the constant will later determine the shape to which the
1260 /// constant needs to be broadcast to.
1261 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1262 /// of the `customVectorizationHooks` to cover such cases.
1263 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1264 /// operand of maximal rank. Other operands have smaller rank and are
1265 /// broadcast accordingly. It is assumed this broadcast is always legal,
1266 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1267 ///
1268 /// This function assumes all operands of `op` have been vectorized and are in
1269 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1270 /// a topologically-sorted list of ops.
1271 /// This function does not update `bvm` but returns a VectorizationStatus that
1272 /// instructs the caller what `bvm` update needs to occur.
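/// For example (illustrative): with a canonical vector shape of [8], an
/// `arith.mulf %a, %b : f32` inside the region is rebuilt by step 4 as an
/// `arith.mulf` on `vector<8xf32>` operands looked up in `bvm`.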
1273 static VectorizationResult
1274 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1275                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1276  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1277  LDBG("vectorize op " << *op << "\n");
1278 
1279  // 1. Try to apply any CustomVectorizationHook.
1280  if (!customVectorizationHooks.empty()) {
1281  for (auto &customFunc : customVectorizationHooks) {
1282  VectorizationResult result = customFunc(op, bvm);
1283  if (result.status == VectorizationStatus::Failure)
1284  continue;
1285  return result;
1286  }
1287  }
1288 
1289  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1290   // Clone so that the constant is not confined to the linalgOp block.
1291  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1292  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1293 
1294  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1295   if (!OpTrait::hasElementwiseMappableTraits(op))
1296     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1297 
1298   // 4. Check if the operation is a reduction.
1299  SmallVector<std::pair<Value, Value>> reductionOperands;
1300  for (Value operand : op->getOperands()) {
1301  auto blockArg = dyn_cast<BlockArgument>(operand);
1302  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1303  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1304  continue;
1305  SmallVector<Operation *> reductionOps;
1306  Value reduceValue = matchReduction(
1307  linalgOp.getRegionOutputArgs(),
1308  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1309  if (!reduceValue)
1310  continue;
1311  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1312  }
1313  if (!reductionOperands.empty()) {
1314  assert(reductionOperands.size() == 1);
1315  Operation *reduceOp =
1316  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1317  reductionOperands[0].second, bvm);
1318  if (reduceOp)
1319       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1320   }
1321 
1322  // 5. Generic vectorization path for ElementwiseMappable ops.
1323  // a. Get the first max ranked shape.
1324  VectorType firstMaxRankedType;
1325  for (Value operand : op->getOperands()) {
1326  auto vecOperand = bvm.lookup(operand);
1327  assert(vecOperand && "Vector operand couldn't be found");
1328 
1329  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1330  if (vecType && (!firstMaxRankedType ||
1331  firstMaxRankedType.getRank() < vecType.getRank()))
1332  firstMaxRankedType = vecType;
1333  }
1334  // b. Broadcast each op if needed.
1335  SmallVector<Value> vecOperands;
1336  for (Value scalarOperand : op->getOperands()) {
1337  Value vecOperand = bvm.lookup(scalarOperand);
1338  assert(vecOperand && "Vector operand couldn't be found");
1339 
1340  if (firstMaxRankedType) {
1341  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1342  getElementTypeOrSelf(vecOperand.getType()),
1343  firstMaxRankedType.getScalableDims());
1344  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1345  } else {
1346  vecOperands.push_back(vecOperand);
1347  }
1348  }
1349  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1350  SmallVector<Type> resultTypes;
1351  for (Type resultType : op->getResultTypes()) {
1352  resultTypes.push_back(
1353  firstMaxRankedType
1354  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1355  firstMaxRankedType.getScalableDims())
1356  : resultType);
1357  }
1358  // d. Build and return the new op.
1359  return VectorizationResult{
1360       VectorizationStatus::NewOp,
1361       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1362  resultTypes, op->getAttrs())};
1363 }
1364 
1365 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1366 /// vector form. Generic vectorization proceeds as follows:
1367 /// 1. Verify the `linalgOp` has one non-empty region.
1368 /// 2. Values defined above the region are mapped to themselves and will be
1369 /// broadcasted on a per-need basis by their consumers.
1370 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1371 /// load).
1372 /// TODO: Reuse opportunities for RAR dependencies.
1373 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1374 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1375 /// iteration indices.
1376 /// 5. Iteratively call vectorizeOneOp on the region operations.
1377 ///
1378 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1379 /// performed to the maximal common vector size implied by the `linalgOp`
1380 /// iteration space. This eager broadcasting is introduced in the
1381 /// permutation_map of the vector.transfer_read operations. The eager
1382 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1383 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1384 /// the absence of good canonicalizations, the amount of work increases.
1385 /// This is not deemed a problem as we expect canonicalizations and foldings to
1386 /// aggressively clean up the useless work.
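///
/// For illustration only, a hypothetical elementwise example (names such as
/// #elementwise_trait and the exact IR are assumptions, not taken from a test):
///   %0 = linalg.generic #elementwise_trait
///          ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///          outs(%init : tensor<8xf32>) {
///   ^bb0(%in0: f32, %in1: f32, %out: f32):
///     %s = arith.addf %in0, %in1 : f32
///     linalg.yield %s : f32
///   } -> tensor<8xf32>
/// is roughly rewritten to:
///   %ra = vector.transfer_read %a[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %rb = vector.transfer_read %b[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %s  = arith.addf %ra, %rb : vector<8xf32>
///   %w  = vector.transfer_write %s, %init[%c0] : vector<8xf32>, tensor<8xf32>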
1387 static LogicalResult
1388 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1389  LinalgOp linalgOp,
1390  SmallVectorImpl<Value> &newResults) {
1391  LDBG("Vectorizing operation as linalg generic\n");
1392  Block *block = linalgOp.getBlock();
1393 
1394  // 2. Values defined above the region can only be broadcast for now. Make them
1395  // map to themselves.
1396  IRMapping bvm;
1397  SetVector<Value> valuesSet;
1398  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1399  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1400 
1401  if (linalgOp.getNumDpsInits() == 0)
1402  return failure();
1403 
1404  // 3. Turn all BBArgs into vector.transfer_read / load.
1405  Location loc = linalgOp.getLoc();
1406  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1407  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1408  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1409  if (linalgOp.isScalar(opOperand)) {
1410  bvm.map(bbarg, opOperand->get());
1411  continue;
1412  }
1413 
1414  // 3.a. Convert the indexing map for this input/output to a transfer read
1415  // permutation map and masking map.
1416  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1417 
1418  AffineMap readMap;
1419  VectorType readType;
1420  Type elemType = getElementTypeOrSelf(opOperand->get());
1421  if (linalgOp.isDpsInput(opOperand)) {
1422  // 3.a.i. For input reads we use the canonical vector shape.
1423  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1424  readType = state.getCanonicalVecType(elemType);
1425  } else {
1426  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1427  // reductions), the vector shape is computed by mapping the canonical
1428  // vector shape to the output domain and back to the canonical domain.
1429  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1430  readType =
1431  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1432  }
1433 
1434  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1435 
1436  Operation *read = rewriter.create<vector::TransferReadOp>(
1437  loc, readType, opOperand->get(), indices, readMap);
1438  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1439  Value readValue = read->getResult(0);
1440 
1441  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1442  // will be in-bounds.
1443  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1444  SmallVector<bool> inBounds(readType.getRank(), true);
1445  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1446  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1447  }
1448 
1449  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1450  // TODO: remove this.
1451  if (readType.getRank() == 0)
1452  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1453  ArrayRef<int64_t>());
1454 
1455  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1456  << "\n");
1457  bvm.map(bbarg, readValue);
1458  bvm.map(opOperand->get(), readValue);
1459  }
1460 
1461  SmallVector<CustomVectorizationHook> hooks;
1462  // 4a. Register CustomVectorizationHook for yieldOp.
1463  CustomVectorizationHook vectorizeYield =
1464  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1465  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1466  };
1467  hooks.push_back(vectorizeYield);
1468 
1469  // 4b. Register CustomVectorizationHook for indexOp.
1470  CustomVectorizationHook vectorizeIndex =
1471  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1472  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1473  };
1474  hooks.push_back(vectorizeIndex);
1475 
1476  // 4c. Register CustomVectorizationHook for extractOp.
1477  CustomVectorizationHook vectorizeExtract =
1478  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1479  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1480  };
1481  hooks.push_back(vectorizeExtract);
1482 
1483  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1484  for (Operation &op : block->getOperations()) {
1485  VectorizationResult result =
1486  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1487  if (result.status == VectorizationStatus::Failure) {
1488  LDBG("failed to vectorize: " << op << "\n");
1489  return failure();
1490  }
1491  if (result.status == VectorizationStatus::NewOp) {
1492  Operation *maybeMaskedOp =
1493  state.maskOperation(rewriter, result.newOp, linalgOp);
1494  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1495  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1496  }
1497  }
1498 
1499  return success();
1500 }
1501 
1502 /// Given a linalg::PackOp, return the `dest` shape before any packing
1503 /// permutations.
1504 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1505  ArrayRef<int64_t> destShape) {
1506  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1507 }
1508 
1509 /// Creates an optionally masked TransferWriteOp
1510 ///
1511 /// Generates the following operation:
1512 /// %res = vector.transfer_write %vectorToStore into %dest
1513 ///
1514 /// If the leading N dimensions of the destination tensor do not match
1515 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
1516 /// masking is applied to ensure correctness:
1517 ///
1518 /// %mask = vector.create_mask(%destShape)
1519 /// %res = vector.mask %mask {
1520 /// vector.transfer_write %vectorToStore into %dest
1521 /// }
1522 ///
1523 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1524 /// is used instead of masking:
1525 ///
1526 /// in_bounds_flags = (...)
1527 /// %res = vector.transfer_write %vectorToStore into %dest
1528 /// {in_bounds = in_bounds_flags}
1530 ///
1531 /// NOTE: All write offsets are set to 0.
1532 /// TODO: Allow specifying write offsets.
1533 /// NOTE: When N < rank(input), the missing vector sizes are effectively
1534 /// extracted from the trailing sizes of `destSizes`. This means those sizes
1535 /// must be static.
1536 /// TODO: Support cases where an arbitrary dim is dynamic - this will require
1537 /// specifying all the vector sizes.
1538 static Operation *
1539 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
1540  Value dest,
1541  ArrayRef<int64_t> inputVecSizesForLeadingDims,
1542  bool useInBoundsInsteadOfMasking = false) {
1543 
1544  ShapedType destType = cast<ShapedType>(dest.getType());
1545  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
1546  static_cast<int64_t>(destType.getRank()) &&
1547  "Rank mismatch!");
1548  (void)destType;
1549 
1550  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1551  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1552 
1553  // Compute the in_bounds attribute
1554  SmallVector<bool> inBoundsVal(rank, true);
1555  if (useInBoundsInsteadOfMasking) {
1556  // In this case, assume that all the required vector sizes have been
1557  // provided.
1558  assert(inputVecSizesForLeadingDims.size() ==
1559  static_cast<size_t>(destType.getRank()) &&
1560  "Insufficient number of input vector sizes!");
1561  // Update the inBounds attribute.
1562  for (unsigned i = 0; i < rank; i++)
1563  inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
1564  !ShapedType::isDynamic(destShape[i]);
1565  }
1566 
1567  // Generate the xfer_write Op
1568  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1569  Operation *write = builder.create<vector::TransferWriteOp>(
1570  loc,
1571  /*vector=*/vectorToStore,
1572  /*source=*/dest,
1573  /*indices=*/SmallVector<Value>(rank, zero),
1574  /*inBounds=*/inBoundsVal);
1575  assert(llvm::none_of(
1576  destShape.drop_front(inputVecSizesForLeadingDims.size()),
1577  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1578  "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
1579 
1580  // If masking is disabled, exit.
1581  if (useInBoundsInsteadOfMasking)
1582  return write;
1583 
1584  // Check if masking is needed.
1585  bool needMaskForWrite =
1586  !llvm::equal(inputVecSizesForLeadingDims,
1587  destShape.take_front(inputVecSizesForLeadingDims.size()));
1588 
1589  // If masking is needed, generate the mask and mask the operation.
1590  if (needMaskForWrite) {
1591  SmallVector<int64_t> writeMaskShape;
1592  writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
1593  inputVecSizesForLeadingDims.end());
1594  writeMaskShape.append(destShape.begin() +
1595  inputVecSizesForLeadingDims.size(),
1596  destShape.end());
1597  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1598  Value maskForWrite = builder.create<vector::CreateMaskOp>(
1599  loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
1600  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1601  }
1602 
1603  return write;
1604 }
1605 
1606 /// Vectorize linalg::PackOp with (1) static inner_tiles, (2) constant
1607 /// padding value and (3) input vector sizes into:
1608 ///
1609 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1610 ///
1611 /// As in the following example:
1612 /// %pack = linalg.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1613 /// into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1614 ///
1615 /// This pack would be vectorized to:
1616 ///
1617 /// %load = vector.mask %mask {
1618 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1619 /// {in_bounds = [true, true, true]} :
1620 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1621 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1622 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1623 /// to vector<32x4x2x1x16xf32>
1624 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1625 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1626 /// %write = vector.transfer_write %transpose,
1627 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1628 /// {in_bounds = [true, true, true, true, true]}
1629 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1630 ///
1631 /// If the (3) input vector sizes are not provided, the vector sizes are
1632 /// determined by the result tensor shape and the `in_bounds`
1633 /// attribute is used instead of masking to mark out-of-bounds accesses.
1634 ///
1635 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1636 /// outer dimensions of the output tensor. The remaining dimensions are
1637 /// computed based on, e.g., the static inner tiles.
1638 /// Supporting dynamic inner tiles will require the user to specify the
1639 /// missing vector sizes. This is left as a TODO.
1640 static LogicalResult
1641 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1642  ArrayRef<int64_t> inputVectorSizes,
1643  SmallVectorImpl<Value> &newResults) {
1644  // TODO: Introduce a parent class that will handle the insertion point update.
1645  OpBuilder::InsertionGuard g(rewriter);
1646  rewriter.setInsertionPoint(packOp);
1647 
1648  Location loc = packOp.getLoc();
1649  auto padValue = packOp.getPaddingValue();
1650  if (!padValue) {
1651  padValue = rewriter.create<arith::ConstantOp>(
1652  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1653  }
1654  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1655  LogicalResult status =
1656  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1657  .reifyResultShapes(rewriter, reifiedReturnShapes);
1658  (void)status; // prevent unused variable warning on non-assert builds.
1659  assert(succeeded(status) && "failed to reify result shapes");
1660 
1661  // If the input vector sizes are not provided, then the vector sizes are
1662  // determined by the result tensor shape. In case the vector sizes aren't
1663  // provided, we update the inBounds attribute instead of masking.
1664  bool useInBoundsInsteadOfMasking = false;
1665  if (inputVectorSizes.empty()) {
1666  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1667  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1668  useInBoundsInsteadOfMasking = true;
1669  }
1670 
1671  // Create masked TransferReadOp.
1672  SmallVector<int64_t> inputShape(inputVectorSizes);
1673  auto innerTiles = packOp.getStaticInnerTiles();
1674  auto innerDimsPos = packOp.getInnerDimsPos();
1675  auto outerDimsPerm = packOp.getOuterDimsPerm();
1676  if (!outerDimsPerm.empty())
1677  applyPermutationToVector(inputShape,
1678  invertPermutationVector(outerDimsPerm));
1679  for (auto [idx, size] : enumerate(innerTiles))
1680  inputShape[innerDimsPos[idx]] *= size;
1681  auto maskedRead = vector::createReadOrMaskedRead(
1682  rewriter, loc, packOp.getSource(), inputShape, padValue,
1683  useInBoundsInsteadOfMasking);
1684 
1685  // Create ShapeCastOp.
1686  SmallVector<int64_t> destShape(inputVectorSizes);
1687  destShape.append(innerTiles.begin(), innerTiles.end());
1688  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1689  packOp.getDestType().getElementType());
1690  auto shapeCastOp =
1691  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1692 
1693  // Create TransposeOp.
1694  auto destPermutation =
1695  invertPermutationVector(getPackInverseDestPerm(packOp));
1696  auto transposeOp = rewriter.create<vector::TransposeOp>(
1697  loc, shapeCastOp.getResult(), destPermutation);
1698 
1699  // Create TransferWriteOp.
1700  Value dest = rewriter.create<tensor::EmptyOp>(
1701  loc, reifiedReturnShapes[0],
1702  transposeOp.getResult().getType().getElementType());
1703  Operation *write =
1704  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
1705  /*inputVecSizesForLeadingDims=*/inputVectorSizes,
1706  /*useInBoundsInsteadOfMasking=*/false);
1707  newResults.push_back(write->getResult(0));
1708  return success();
1709 }
1710 
1711 /// Vectorize a `linalg::UnPackOp` into these 4 ops:
1712 ///   vector::TransferReadOp  - reads a vector from the source tensor,
1713 ///   vector::TransposeOp     - transposes the read vector,
1714 ///   vector::ShapeCastOp     - reshapes the data based on the target,
1715 ///   vector::TransferWriteOp - writes the result vector back to the
1716 ///                             destination tensor.
1717 /// If the vector sizes are not provided:
1718 ///   * the vector sizes are determined by the input operand and attributes,
1719 ///   * the inBounds attribute is updated instead of masking.
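///
/// For illustration only (a sketch with hypothetical shapes; exact IR may
/// differ):
///   %unpack = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [32, 16]
///     into %dst : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
/// is roughly rewritten to:
///   %read  = vector.transfer_read %src[...] : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
///   %tr    = vector.transpose %read, [0, 2, 1, 3]
///              : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
///   %cast  = vector.shape_cast %tr : vector<8x32x8x16xf32> to vector<256x128xf32>
///   %write = vector.transfer_write %cast, %dst[...] : vector<256x128xf32>, tensor<256x128xf32>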
1720 static LogicalResult
1721 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1722  ArrayRef<int64_t> inputVectorSizes,
1723  SmallVectorImpl<Value> &newResults) {
1724 
1725  // TODO: Introduce a parent class that will handle the insertion point update.
1726  OpBuilder::InsertionGuard g(rewriter);
1727  rewriter.setInsertionPoint(unpackOp);
1728 
1729  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1730 
1731  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1732  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1733  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1734  bool useInBoundsInsteadOfMasking = false;
1735  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1736 
1737  auto destSize = unpackOp.getDestRank();
1738 
1739  if (!inputVectorSizes.empty())
1740  assert(inputVectorSizes.size() == destSize &&
1741  "Incorrect number of input vector sizes");
1742 
1743  // vectorSizes is the shape of the vector that will be used for the final
1744  // write on the destination tensor. It is set as follows: let the source
1745  // tensor have rank 'M' and the dest tensor rank 'N', where N <= M.
1746  // Then:
1747  // 1. vectorSizes = sourceShape.take_front(N)
1748  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1749  // 3. multiply each location in vectorSizes pointed to by innerDimPos by the
1750  // corresponding innerTiles attribute value.
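  // E.g. (illustrative only, when no input vector sizes are provided): for
  // sourceShape = <8x8x32x16>, dest rank N = 2, inner_dims_pos = [0, 1],
  // inner_tiles = [32, 16] and outer_dims_perm = [1, 0]:
  //   1. vectorSizes = [8, 8]
  //   2. after applying outer_dims_perm: [8, 8]
  //   3. after scaling by inner_tiles: [8 * 32, 8 * 16] = [256, 128]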
1751  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1752  if (vectorSizes.empty()) {
1753  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1754  if (!outerDimsPerm.empty())
1755  applyPermutationToVector(vectorSizes, outerDimsPerm);
1756  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1757  vectorSizes[pos] *= innerTiles[i];
1758 
1759  useInBoundsInsteadOfMasking = true;
1760  }
1761 
1762  // readVectorSizes is the shape of the vector used to read and apply the
1763  // mask. It is set as follows: let the vectorSizes (VS) array have size 'N',
1764  // the sourceShape (SS) have size 'M' (where M >= N), and the InnerTileSizes
1765  // (IT) have size M-N.
1766  // Then:
1767  // - initially: readVectorSizes = vectorSizes
1768  // - divide each readVectorSizes location pointed to by innerDimPos
1769  // by the corresponding innerTiles attribute value.
1770  // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1771  // - append the remaining shape from SS.
1772  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1773  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
1774  // 128] and outer_dims_perm = [1, 0]. Then the read shape is:
1775  // ReadVectorSizes(initial): [512, 128]
1776  // Final value (after innerDim adjustment): [512/32, 128/16]
1777  // = [16, 8]
1778  // After applying outer_dims_perm: [8, 16]
1779  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1780 
1781  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1782 
1783  for (auto [index, size] : enumerate(innerTiles)) {
1784  readVectorSizes[innerDimPos[index]] =
1785  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1786  }
1787  if (!outerDimsPerm.empty()) {
1788  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1789  }
1790  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1791  sourceShape.end());
1792 
1793  ReifiedRankedShapedTypeDims reifiedRetShapes;
1794  LogicalResult status =
1795  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1796  .reifyResultShapes(rewriter, reifiedRetShapes);
1797  if (status.failed()) {
1798  LDBG("Unable to reify result shapes of " << unpackOp);
1799  return failure();
1800  }
1801  Location loc = unpackOp->getLoc();
1802 
1803  auto padValue = rewriter.create<arith::ConstantOp>(
1804  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1805 
1806  // Read result, mask if necessary. If transferReadOp shape is not equal
1807  // to the shape of the source, then a mask is necessary.
1808  Value readResult = vector::createReadOrMaskedRead(
1809  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1810  /*useInBoundsInsteadOfMasking=*/false);
1811 
1812  PackingMetadata packMetadata;
1813  SmallVector<int64_t> lastDimToInsertPosPerm =
1814  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1815  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1816  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1817  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1818  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1819  RankedTensorType stripMineTensorType =
1820  RankedTensorType::get(stripMineShape, stripMineElemType);
1821  // Transpose the appropriate rows to match output.
1822  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1823  loc, readResult, lastDimToInsertPosPerm);
1824 
1825  // Collapse the vector to the size required by result.
1826  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1827  stripMineTensorType, packMetadata.reassociations);
1828  mlir::VectorType vecCollapsedType =
1829  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1830  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1831  loc, vecCollapsedType, transposeOp->getResult(0));
1832 
1833  // writeVectorSizes has to match the shape_cast result shape for dynamic
1834  // sizes; otherwise the verifier complains that the mask size is invalid.
1835  SmallVector<int64_t> writeVectorSizes(
1836  unpackOp.getDestType().hasStaticShape()
1837  ? vectorSizes
1838  : shapeCastOp.getResultVectorType().getShape());
1839  Value dest = rewriter.create<tensor::EmptyOp>(
1840  loc, reifiedRetShapes[0],
1841  shapeCastOp.getResult().getType().getElementType());
1842  Operation *write =
1843  createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
1844  /*inputVecSizesForLeadingDims=*/writeVectorSizes,
1845  useInBoundsInsteadOfMasking);
1846  newResults.push_back(write->getResult(0));
1847  return success();
1848 }
1849 
1850 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1851 /// and (3) all-zero lowPad to
1852 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
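///
/// For illustration only, a hypothetical example (exact IR may differ):
///   %pad = tensor.pad %src low[0, 0] high[2, 3] {
///            tensor.yield %cst : f32
///          } : tensor<5x6xf32> to tensor<7x9xf32>
/// is roughly rewritten to:
///   %read  = vector.mask %m { vector.transfer_read %src[%c0, %c0], %cst }
///              : vector<7x9xi1> -> vector<7x9xf32>
///   %write = vector.transfer_write %read, %init[%c0, %c0]
///              {in_bounds = [true, true]} : vector<7x9xf32>, tensor<7x9xf32>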
1853 static LogicalResult
1854 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1855  ArrayRef<int64_t> inputVectorSizes,
1856  SmallVectorImpl<Value> &newResults) {
1857  auto padValue = padOp.getConstantPaddingValue();
1858  Location loc = padOp.getLoc();
1859 
1860  // TODO: Introduce a parent class that will handle the insertion point update.
1861  OpBuilder::InsertionGuard g(rewriter);
1862  rewriter.setInsertionPoint(padOp);
1863 
1864  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1865  LogicalResult status =
1866  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1867  .reifyResultShapes(rewriter, reifiedReturnShapes);
1868  (void)status; // prevent unused variable warning on non-assert builds
1869  assert(succeeded(status) && "failed to reify result shapes");
1870  auto maskedRead = vector::createReadOrMaskedRead(
1871  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1872  /*useInBoundsInsteadOfMasking=*/false);
1873 
1874  // Create Xfer write Op
1875  Value dest = rewriter.create<tensor::EmptyOp>(
1876  loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1877  Operation *write =
1878  createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
1879  /*inputVecSizesForLeadingDims=*/inputVectorSizes,
1880  /*useInBoundsInsteadOfMasking=*/false);
1881  newResults.push_back(write->getResult(0));
1882  return success();
1883 }
1884 
1885 // TODO: probably need some extra checks for reduction followed by consumer
1886 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1887 static LogicalResult reductionPreconditions(LinalgOp op) {
1888  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1889  LDBG("reduction precondition failed: no reduction iterator\n");
1890  return failure();
1891  }
1892  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1893  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1894  if (indexingMap.isPermutation())
1895  continue;
1896 
1897  Operation *reduceOp = matchLinalgReduction(&opOperand);
1898  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1899  LDBG("reduction precondition failed: reduction detection failed\n");
1900  return failure();
1901  }
1902  }
1903  return success();
1904 }
1905 
1906 static LogicalResult
1907 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1908  bool flatten1DDepthwiseConv) {
1909  if (flatten1DDepthwiseConv) {
1910  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1911  "supported\n");
1912  return failure();
1913  }
1914 
1915  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1916  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1917  return failure();
1918  }
1919 
1920  // Support dynamic shapes in 1D depthwise convolution, but only in the
1921  // _channel_ dimension.
1922  Value lhs = conv.getDpsInputOperand(0)->get();
1923  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1924  auto shapeWithoutCh = lhsShape.drop_back(1);
1925  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1926  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1927  "channel dim can be dynamic\n");
1928  return failure();
1929  }
1930 
1931  return success();
1932 }
1933 
1934 static LogicalResult
1935 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1936  bool flatten1DDepthwiseConv) {
1937  if (isa<ConvolutionOpInterface>(op.getOperation()))
1938  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1939 
1940  if (hasReductionIterator(op))
1941  return reductionPreconditions(op);
1942 
1943  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1944  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1945  if (!isElementwise(op) &&
1946  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1947  op.getOperation()))
1948  return failure();
1949 
1950  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1951  return success();
1952 }
1953 
1954 /// Checks that the inner tiles are static/constant.
1955 static LogicalResult
1956 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
1957  ArrayRef<int64_t> inputVectorSizes) {
1958 
1959  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1960  return !getConstantIntValue(res).has_value();
1961  })) {
1962  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1963  return failure();
1964  }
1965  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1966  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1967  unpackOp.getDestType().hasStaticShape() &&
1968  unpackOp.getSourceType().hasStaticShape();
1969  if (!satisfyEmptyCond &&
1970  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1971  return failure();
1972 
1973  return success();
1974 }
1975 
1976 static LogicalResult
1977 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
1978  ArrayRef<int64_t> inputVectorSizes) {
1979 
1980  TypedValue<RankedTensorType> source = sliceOp.getSource();
1981  auto sourceType = source.getType();
1982  if (!VectorType::isValidElementType(sourceType.getElementType()))
1983  return failure();
1984 
1985  // Get the pad value.
1986  // TransferReadOp (which is used to vectorize InsertSliceOp) requires a
1987  // scalar padding value. Note that:
1988  // * for in-bounds accesses,
1989  // the value is actually irrelevant. There are 2 cases in which xfer.read
1990  // accesses are known to be in-bounds:
1991  // 1. The source shape is static (output vector sizes would be based on
1992  // the source shape and hence all memory accesses would be in-bounds),
1993  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
1994  // this case it is safe to assume that all memory accesses are in-bounds.
1995  //
1996  // When the value is not known and not needed, use 0. Otherwise, bail out.
1997  Value padValue = getStaticPadVal(sliceOp);
1998  bool isOutOfBoundsRead =
1999  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2000 
2001  if (!padValue && isOutOfBoundsRead) {
2002  LDBG("Failed to get a pad value for out-of-bounds read access\n");
2003  return failure();
2004  }
2005  return success();
2006 }
2007 
2008 namespace {
2009 enum class ConvOperationKind { Conv, Pool };
2010 } // namespace
2011 
2013  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2014  isa<BlockArgument>(op->getOperand(0));
2015 }
2016 
2017 // Returns the ConvOperationKind of the op using reduceOp of the generic
2018 // payload. If it is neither a convolution nor a pooling, it returns
2019 // std::nullopt.
2020 //
2021 // If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2022 // + yield) and the rhs is not used, then it is the body of a pooling.
2023 // If it is a conv, check for a single `mul` predecessor. The `mul` operands
2024 // must be block arguments or extensions of block arguments.
2025 // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
2026 // must be block arguments or extensions of block arguments.
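//
// For illustration only (hypothetical payloads; details may differ):
//   conv body:   %m = arith.mulf %in, %filter : f32
//                %a = arith.addf %out, %m : f32
//                linalg.yield %a : f32          --> ConvOperationKind::Conv
//   pool body:   %a = arith.addf %out, %in : f32
//                linalg.yield %a : f32          --> ConvOperationKind::Pool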
2027 static std::optional<ConvOperationKind>
2028 getConvOperationKind(Operation *reduceOp) {
2029  int numBlockArguments =
2030  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2031 
2032  switch (numBlockArguments) {
2033  case 1: {
2034  // Will be convolution if feeder is a MulOp.
2035  // A strength reduced version of MulOp for i1 type is AndOp which is also
2036  // supported. Otherwise, it can be pooling. This strength reduction logic
2037  // is in `buildBinaryFn` helper in the Linalg dialect.
2038  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2039  llvm::IsaPred<BlockArgument>);
2040  assert(feedValIt != reduceOp->operand_end() &&
2041  "Expected a non-block argument operand");
2042  Operation *feedOp = (*feedValIt).getDefiningOp();
2043  if (isCastOfBlockArgument(feedOp)) {
2044  return ConvOperationKind::Pool;
2045  }
2046 
2047  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2048  (isa<arith::AndIOp>(feedOp) &&
2049  feedOp->getResultTypes()[0].isInteger(1))) &&
2050  llvm::all_of(feedOp->getOperands(), [](Value v) {
2051  if (isa<BlockArgument>(v))
2052  return true;
2053  if (Operation *op = v.getDefiningOp())
2054  return isCastOfBlockArgument(op);
2055  return false;
2056  }))) {
2057  return std::nullopt;
2058  }
2059 
2060  return ConvOperationKind::Conv;
2061  }
2062  case 2:
2063  // Must be pooling
2064  return ConvOperationKind::Pool;
2065  default:
2066  return std::nullopt;
2067  }
2068 }
2069 
2070 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2071  switch (kind) {
2072  case vector::CombiningKind::ADD:
2073  case vector::CombiningKind::MAXNUMF:
2074  case vector::CombiningKind::MAXIMUMF:
2075  case vector::CombiningKind::MAXSI:
2076  case vector::CombiningKind::MAXUI:
2077  case vector::CombiningKind::MINNUMF:
2078  case vector::CombiningKind::MINIMUMF:
2079  case vector::CombiningKind::MINSI:
2080  case vector::CombiningKind::MINUI:
2081  return true;
2082  default:
2083  return false;
2084  }
2085 }
2086 
2087 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2088  auto getOperandType = [&](auto operand) {
2089  return dyn_cast<ShapedType>((operand->get()).getType());
2090  };
2091  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2092  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2093  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2094  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2095  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2096  // Note that this also ensures 2D and 3D convolutions are rejected.
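  // E.g. (illustrative only) linalg.conv_1d_nwc_wcf:
  //   lhs: tensor<1x8x3xf32> (NWC), rhs: tensor<2x3x4xf32> (WCF),
  //   res: tensor<1x7x4xf32> (NWF) -> both LHS and RES have rank 3.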
2097  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2098  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2099  return failure();
2100 
2101  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2102  if (!reduceOp)
2103  return failure();
2104 
2105  auto maybeOper = getConvOperationKind(reduceOp);
2106  if (!maybeOper.has_value())
2107  return failure();
2108 
2109  auto maybeKind = getCombinerOpKind(reduceOp);
2110  // Typically a convolution will have an `Add` CombiningKind but for i1 type it
2111  // can get strength reduced to `OR` which is also supported. This strength
2112  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2113  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2114  *maybeKind != vector::CombiningKind::OR) &&
2115  (*maybeOper != ConvOperationKind::Pool ||
2116  !isSupportedPoolKind(*maybeKind)))) {
2117  return failure();
2118  }
2119 
2120  auto rhsRank = rhsShapedType.getRank();
2121  if (*maybeOper == ConvOperationKind::Pool) {
2122  if (rhsRank != 1)
2123  return failure();
2124  } else {
2125  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2126  return failure();
2127  }
2128 
2129  return success();
2130 }
2131 
2132 static LogicalResult vectorizeLinalgOpPrecondition(
2133  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2134  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2135  // Tensors with a dimension of size 0 cannot be vectorized.
2136  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
2137  return failure();
2138  // Check API contract for input vector sizes.
2139  if (!inputVectorSizes.empty() &&
2140  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2141  inputVectorSizes)))
2142  return failure();
2143 
2144  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2145  linalgOp, flatten1DDepthwiseConv))) {
2146  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
2147  return failure();
2148  }
2149 
2150  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2151 
2152  // Register CustomVectorizationPrecondition for extractOp.
2153  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2154 
2155  // All types in the body should be a supported element type for VectorType.
2156  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2157  // Check if any custom hook can vectorize the inner op.
2158  if (llvm::any_of(
2159  customPreconditions,
2160  [&](const CustomVectorizationPrecondition &customPrecondition) {
2161  return succeeded(
2162  customPrecondition(&innerOp, vectorizeNDExtract));
2163  })) {
2164  continue;
2165  }
2166  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
2167  return !VectorType::isValidElementType(type);
2168  })) {
2169  return failure();
2170  }
2171  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
2172  return !VectorType::isValidElementType(type);
2173  })) {
2174  return failure();
2175  }
2176  }
2177  if (isElementwise(linalgOp))
2178  return success();
2179 
2180  // TODO: isaConvolutionOpInterface that can also infer from generic
2181  // features. But we will still need stride/dilation attributes that will be
2182  // annoying to reverse-engineer...
2183  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2184  return vectorizeConvOpPrecondition(linalgOp);
2185 
2186  // TODO: the common vector shape is equal to the static loop sizes only when
2187  // all indexing maps are projected permutations. For convs and stencils the
2188  // logic will need to evolve.
2189  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2190  LDBG("precondition failed: not projected permutations\n");
2191  return failure();
2192  }
2193  if (failed(reductionPreconditions(linalgOp))) {
2194  LDBG("precondition failed: reduction preconditions\n");
2195  return failure();
2196  }
2197  return success();
2198 }
2199 
2200 static LogicalResult
2201 vectorizePackOpPrecondition(linalg::PackOp packOp,
2202  ArrayRef<int64_t> inputVectorSizes) {
2203  auto padValue = packOp.getPaddingValue();
2204  Attribute cstAttr;
2205  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2206  LDBG("pad value is not constant: " << packOp << "\n");
2207  return failure();
2208  }
2209  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2210  bool satisfyEmptyCond = true;
2211  if (inputVectorSizes.empty()) {
2212  if (!packOp.getDestType().hasStaticShape() ||
2213  !packOp.getSourceType().hasStaticShape())
2214  satisfyEmptyCond = false;
2215  }
2216 
2217  if (!satisfyEmptyCond &&
2218  failed(vector::isValidMaskedInputVector(
2219  resultTensorShape.take_front(packOp.getSourceRank()),
2220  inputVectorSizes)))
2221  return failure();
2222 
2223  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2224  return !getConstantIntValue(v).has_value();
2225  })) {
2226  LDBG("inner_tiles must be constant: " << packOp << "\n");
2227  return failure();
2228  }
2229 
2230  return success();
2231 }
2232 
2233 static LogicalResult
2234 vectorizePadOpPrecondition(tensor::PadOp padOp,
2235  ArrayRef<int64_t> inputVectorSizes) {
2236  auto padValue = padOp.getConstantPaddingValue();
2237  if (!padValue) {
2238  LDBG("pad value is not constant: " << padOp << "\n");
2239  return failure();
2240  }
2241 
2242  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2243  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2244  inputVectorSizes)))
2245  return failure();
2246 
2247  // Padding with non-zero low pad values is not supported, unless the
2248  // corresponding result dim is 1; supporting it in general would require
2249  // shifting the results to the right for the low padded dims by the required
2250  // amount of low padding. Low padding is supported when the dims being low
2251  // padded have result sizes of 1: for a low pad on a unit result dim, the
2252  // input size of that dimension will be dynamically zero (as the sum of the
2253  // low pad and input dim size has to be one) and hence we will create a zero
2254  // mask, since the lowering logic simply makes the mask one for the input dim
2255  // size - which is zero here. Hence we will load the pad value, which is what
2256  // we want in this case. If the low pad is dynamically zero then the lowering
2257  // is correct as well, as no shifts are necessary.
2258  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2259  Value padValue = en.value();
2260  unsigned pos = en.index();
2261  std::optional<int64_t> pad = getConstantIntValue(padValue);
2262  return (!pad.has_value() || pad.value() != 0) &&
2263  resultTensorShape[pos] != 1;
2264  })) {
2265  LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
2266  return failure();
2267  }
2268 
2269  return success();
2270 }
2271 
2272 /// Preconditions for scalable vectors. This is quite restrictive - it models
2273 /// the fact that in practice we would only make selected dimensions scalable.
2274 static LogicalResult
2275 vectorizeScalableVectorPrecondition(Operation *op,
2276  ArrayRef<int64_t> inputVectorSizes,
2277  ArrayRef<bool> inputScalableVecDims) {
2278  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2279  "Number of input vector sizes and scalable dims doesn't match");
2280 
2281  size_t numOfScalableDims =
2282  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2283 
2284  if (numOfScalableDims == 0)
2285  return success();
2286 
2287  auto linalgOp = dyn_cast<LinalgOp>(op);
2288 
2289  // Cond 1: There's been no need for scalable vectorization of
2290  // non-linalg Ops so far.
2291  if (!linalgOp)
2292  return failure();
2293 
2294  // Cond 2: There's been no need for more than 2 scalable dims so far
2295  if (numOfScalableDims > 2)
2296  return failure();
2297 
2298  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2299  // it matches one of the supported cases:
2300  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2301  // (*).
2302  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2303  // parallel dims.
2304  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2305  // dim.
2306  // The 2nd restriction above means that only Matmul-like Ops are supported
2307  // when 2 dims are scalable, e.g. :
2308  // * iterators = [parallel, parallel, reduction]
2309  // * scalable flags = [true, true, false]
2310  //
2311  // (*) Non-unit dims get folded away in practice.
2312  // TODO: Relax these conditions as good motivating examples are identified.
2313 
2314  // Find the first scalable flag.
2315  bool seenNonUnitParallel = false;
2316  auto iterators = linalgOp.getIteratorTypesArray();
2317  SmallVector<bool> scalableFlags(inputScalableVecDims);
2318  int64_t idx = scalableFlags.size() - 1;
2319  while (!scalableFlags[idx]) {
2320  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2321  seenNonUnitParallel |=
2322  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2323 
2324  iterators.pop_back();
2325  scalableFlags.pop_back();
2326  --idx;
2327  }
2328 
2329  // Analyze the iterator corresponding to the first scalable dim.
2330  switch (iterators.back()) {
2331  case utils::IteratorType::reduction: {
2332  // Check 3. above is met.
2333  if (iterators.size() != inputVectorSizes.size()) {
2334  LDBG("Non-trailing reduction dim requested for scalable "
2335  "vectorization\n");
2336  return failure();
2337  }
2338  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2339  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2340  "is not supported\n");
2341  return failure();
2342  }
2343  break;
2344  }
2345  case utils::IteratorType::parallel: {
2346  // Check 1. and 2. above are met.
2347  if (seenNonUnitParallel) {
2348  LDBG("Inner parallel dim not requested for scalable "
2349  "vectorization\n");
2350  return failure();
2351  }
2352  break;
2353  }
2354  }
2355 
2356  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2357  // supported, for which we expect the following config:
2358  // * iterators = [parallel, parallel, reduction]
2359  // * scalable flags = [true, true, false]
2360  if (numOfScalableDims == 2) {
2361  // Disallow below case which breaks 3. above:
2362  // * iterators = [..., parallel, reduction]
2363  // * scalable flags = [..., true, true]
2364  if (iterators.back() == utils::IteratorType::reduction) {
2365  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2366  "vectorization\n");
2367  return failure();
2368  }
2369  scalableFlags.pop_back();
2370  iterators.pop_back();
2371 
2372  if (!scalableFlags.back() ||
2373  (iterators.back() != utils::IteratorType::parallel))
2374  return failure();
2375  }
2376 
2377  // Don't let matmuls with extended semantics (user-defined indexing maps)
2378  // through this transform.
2379  if (linalgOp.hasUserDefinedMaps())
2380  return failure();
2381 
2382  // Cond 4: Only the following ops are supported in the
2383  // presence of scalable vectors
2384  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2385  isa<linalg::MatmulTransposeAOp>(op) ||
2386  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2387  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2388 }
2389 
2390 static LogicalResult vectorizeOpPrecondition(
2391  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2392  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2393  bool flatten1DDepthwiseConv) {
2394 
2395  if (!hasVectorizationImpl(op))
2396  return failure();
2397 
2398  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2399  inputScalableVecDims)))
2400  return failure();
2401 
2402  return TypeSwitch<Operation *, LogicalResult>(op)
2403  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2404  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2405  vectorizeNDExtract,
2406  flatten1DDepthwiseConv);
2407  })
2408  .Case<tensor::PadOp>([&](auto padOp) {
2409  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2410  })
2411  .Case<linalg::PackOp>([&](auto packOp) {
2412  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2413  })
2414  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2415  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2416  })
2417  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2418  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2419  })
2420  .Default([](auto) { return failure(); });
2421 }
2422 
2423 /// Converts affine.apply Ops to arithmetic operations.
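///
/// E.g. (illustrative only; the exact expansion is produced by
/// affine::expandAffineExpr):
///   %r = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
/// becomes, roughly:
///   %r = arith.addi %i, %j : index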
2424 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2425  OpBuilder::InsertionGuard g(rewriter);
2426  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2427 
2428  for (auto op : make_early_inc_range(toReplace)) {
2429  rewriter.setInsertionPoint(op);
2430  auto expanded = affine::expandAffineExpr(
2431  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2432  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2433  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2434  rewriter.replaceOp(op, expanded);
2435  }
2436 }
2437 
2438 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2439  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2440  tensor::InsertSliceOp>(op);
2441 }
2442 
2443 /// Emit a suitable vector form for an operation. If provided,
2444 /// `inputVectorSizes` are used to vectorize this operation.
2445 /// `inputVectorSizes` must match the rank of the iteration space of the
2446 /// operation and the input vector sizes must be greater than or equal to
2447 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2448 /// also allows the vectorization of operations with dynamic shapes.
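///
/// A typical call site might look as follows (a sketch only; the variable
/// names and vector sizes are assumptions, not a prescribed usage):
///
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 32},
///                                /*inputScalableVecDims=*/{false, true})))
///     return failure();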
2449 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2450  ArrayRef<int64_t> inputVectorSizes,
2451  ArrayRef<bool> inputScalableVecDims,
2452  bool vectorizeNDExtract,
2453  bool flatten1DDepthwiseConv) {
2454  LDBG("Attempting to vectorize:\n" << *op << "\n");
2455  LDBG("Input vector sizes: ");
2456  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2457  LLVM_DEBUG(llvm::dbgs() << "\n");
2458  LDBG("Input scalable vector dims: ");
2459  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2460  LLVM_DEBUG(llvm::dbgs() << "\n");
2461 
2462  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2463  vectorizeNDExtract,
2464  flatten1DDepthwiseConv))) {
2465  LDBG("Vectorization pre-conditions failed\n");
2466  return failure();
2467  }
2468 
2469  // Initialize vectorization state.
2470  VectorizationState state(rewriter);
2471  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2472  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2473  inputScalableVecDims))) {
2474  LDBG("Vectorization state couldn't be initialized\n");
2475  return failure();
2476  }
2477  }
2478 
2479  SmallVector<Value> results;
2480  auto vectorizeResult =
2481  TypeSwitch<Operation *, LogicalResult>(op)
2482  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2483  // TODO: isaConvolutionOpInterface that can also infer from
2484  // generic features. Will require stride/dilation attributes
2485  // inference.
2486  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2487  FailureOr<Operation *> convOr = vectorizeConvolution(
2488  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2489  flatten1DDepthwiseConv);
2490  if (succeeded(convOr)) {
2491  llvm::append_range(results, (*convOr)->getResults());
2492  return success();
2493  }
2494 
2495  LDBG("Unsupported convolution can't be vectorized.\n");
2496  return failure();
2497  }
2498 
2499  LDBG("Vectorize generic by broadcasting to the canonical vector "
2500  "shape\n");
2501 
2502  // Pre-process before proceeding.
2503  convertAffineApply(rewriter, linalgOp);
2504 
2505  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2506  // to 'OpBuilder' when it is passed over to some methods like
2507  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2508  // erase an op within these methods, the actual rewriter won't be
2509  // notified and we will end up with read-after-free issues!
2510  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2511  })
2512  .Case<tensor::PadOp>([&](auto padOp) {
2513  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2514  results);
2515  })
2516  .Case<linalg::PackOp>([&](auto packOp) {
2517  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2518  results);
2519  })
2520  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2521  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2522  inputVectorSizes, results);
2523  })
2524  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2525  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2526  results);
2527  })
2528  .Default([](auto) { return failure(); });
2529 
2530  if (failed(vectorizeResult)) {
2531  LDBG("Vectorization failed\n");
2532  return failure();
2533  }
2534 
2535  if (!results.empty())
2536  rewriter.replaceOp(op, results);
2537  else
2538  rewriter.eraseOp(op);
2539 
2540  return success();
2541 }
2542 
2543 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2544  memref::CopyOp copyOp) {
2545  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2546  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2547  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2548  return failure();
2549 
2550  auto srcElementType = getElementTypeOrSelf(srcType);
2551  auto dstElementType = getElementTypeOrSelf(dstType);
2552  if (!VectorType::isValidElementType(srcElementType) ||
2553  !VectorType::isValidElementType(dstElementType))
2554  return failure();
2555 
2556  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2557  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2558 
2559  Location loc = copyOp->getLoc();
2560  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2561  SmallVector<Value> indices(srcType.getRank(), zero);
2562 
2563  Value readValue = rewriter.create<vector::TransferReadOp>(
2564  loc, readType, copyOp.getSource(), indices,
2565  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2566  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2567  readValue =
2568  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2569  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2570  }
2571  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2572  loc, readValue, copyOp.getTarget(), indices,
2573  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2574  rewriter.replaceOp(copyOp, writeValue->getResults());
2575  return success();
2576 }
2577 
2578 //----------------------------------------------------------------------------//
2579 // Misc. vectorization patterns.
2580 //----------------------------------------------------------------------------//
2581 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2582 /// given operation type OpTy.
2583 template <typename OpTy>
2584 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2585  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2586 
2587  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2588  PatternRewriter &rewriter) const final {
2589  bool changed = false;
2590  // Collect users in a vector first, because some users may be replaced/removed.
2591  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2592  if (auto op = dyn_cast<OpTy>(user))
2593  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2594  return success(changed);
2595  }
2596 
2597 protected:
2598  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2599  tensor::PadOp padOp, OpTy op) const = 0;
2600 };
2601 
2602 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2603 /// ```
2604 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2605 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2606 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2607 /// ```
2608 /// is rewritten to:
2609 /// ```
2610 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2611 /// {in_bounds = [true, true]}
2612 /// : tensor<?x?xf32>, vector<17x5xf32>
2613 /// ```
2614 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2615 /// sure that the original padding value %cst was never used.
2616 ///
2617 /// This rewrite is possible if:
2618 /// - `xferOp` has no out-of-bounds dims or mask.
2619 /// - Low padding is static 0.
2620 /// - Single, scalar padding value.
2621 struct PadOpVectorizationWithTransferReadPattern
2622  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2623  using VectorizePadOpUserPattern<
2624  vector::TransferReadOp>::VectorizePadOpUserPattern;
2625 
2626  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2627  vector::TransferReadOp xferOp) const override {
2628  // Low padding must be static 0.
2629  if (!padOp.hasZeroLowPad())
2630  return failure();
2631  // Pad value must be a constant.
2632  auto padValue = padOp.getConstantPaddingValue();
2633  if (!padValue)
2634  return failure();
2635  // Padding value of existing `xferOp` is unused.
2636  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2637  return failure();
2638 
2639  rewriter.modifyOpInPlace(xferOp, [&]() {
2640  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2641  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2642  rewriter.getBoolArrayAttr(inBounds));
2643  xferOp.getBaseMutable().assign(padOp.getSource());
2644  xferOp.getPaddingMutable().assign(padValue);
2645  });
2646 
2647  return success();
2648  }
2649 };
2650 
2651 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2652 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2653 /// value, where the same amount of padding is immediately removed again after
2654 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2655 /// tensor value and apply out-of-bounds masking. E.g.:
2656 /// ```
2657 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2658 /// : tensor<...> to tensor<?x?xf32>
2659 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2660 /// %2 = vector.transfer_write %vec, %1[...]
2661 /// : vector<17x5xf32>, tensor<17x5xf32>
2662 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2663 /// : tensor<17x5xf32> to tensor<?x?xf32>
2664 /// ```
2665 /// is rewritten to:
2666 /// ```
2667 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2668 /// : tensor<...> to tensor<?x?xf32>
2669 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2670 /// tensor<?x?xf32>
2671 /// ```
2672 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2673 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2674 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2675 /// from %r's old dimensions.
2676 ///
2677 /// This rewrite is possible if:
2678 /// - Low padding is static 0.
2679 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2680 /// ExtractSliceOp trims the same amount of padding that was added
2681 /// beforehand.
2682 /// - Single, scalar padding value.
2683 struct PadOpVectorizationWithTransferWritePattern
2684  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2685  using VectorizePadOpUserPattern<
2686  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2687 
2688  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2689  vector::TransferWriteOp xferOp) const override {
2690  // TODO: support 0-d corner case.
2691  if (xferOp.getTransferRank() == 0)
2692  return failure();
2693 
2694  // Low padding must be static 0.
2695  if (!padOp.hasZeroLowPad())
2696  return failure();
2697  // Pad value must be a constant.
2698  auto padValue = padOp.getConstantPaddingValue();
2699  if (!padValue)
2700  return failure();
2701  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2702  if (!xferOp->hasOneUse())
2703  return failure();
2704  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2705  if (!trimPadding)
2706  return failure();
2707  // Only static zero offsets supported when trimming padding.
2708  if (!trimPadding.hasZeroOffset())
2709  return failure();
2710  // trimPadding must remove the amount of padding that was added earlier.
2711  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2712  return failure();
2713 
2714  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2715  rewriter.setInsertionPoint(xferOp);
2716 
2717  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2718  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2719  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2720  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2721  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2722  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2723 
2724  return success();
2725  }
2726 
2727  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2728  /// i.e., same dimensions.
2729  ///
2730  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2731  /// dimensions, this function tries to infer the (static) tensor size by
2732  /// looking at the defining op and utilizing op-specific knowledge.
2733  ///
2734  /// This is a conservative analysis. In case equal tensor sizes cannot be
2735  /// proven statically, this analysis returns `false` even though the tensor
2736  /// sizes may turn out to be equal at runtime.
2737  bool hasSameTensorSize(Value beforePadding,
2738  tensor::ExtractSliceOp afterTrimming) const {
2739  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2740  // result and CastOp operand.
2741  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2742  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2743  return true;
2744 
2745  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2746  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2747  // Only RankedTensorType supported.
2748  if (!t1 || !t2)
2749  return false;
2750  // Rank of both values must be the same.
2751  if (t1.getRank() != t2.getRank())
2752  return false;
2753 
2754  // All static dimensions must be the same. Mixed cases (e.g., dimension
2755  // static in `t1` but dynamic in `t2`) are not supported.
2756  for (unsigned i = 0; i < t1.getRank(); ++i) {
2757  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2758  return false;
2759  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2760  return false;
2761  }
2762 
2763  // Nothing more to check if all dimensions are static.
2764  if (t1.getNumDynamicDims() == 0)
2765  return true;
2766 
2767  // All dynamic sizes must be the same. The only supported case at the
2768  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2769  // thereof).
2770 
2771  // Apart from CastOp, only ExtractSliceOp is supported.
2772  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2773  if (!beforeSlice)
2774  return false;
2775 
2776  assert(static_cast<size_t>(t1.getRank()) ==
2777  beforeSlice.getMixedSizes().size());
2778  assert(static_cast<size_t>(t2.getRank()) ==
2779  afterTrimming.getMixedSizes().size());
2780 
2781  for (unsigned i = 0; i < t1.getRank(); ++i) {
2782  // Skip static dimensions.
2783  if (!t1.isDynamicDim(i))
2784  continue;
2785  auto size1 = beforeSlice.getMixedSizes()[i];
2786  auto size2 = afterTrimming.getMixedSizes()[i];
2787 
2788  // Case 1: Same value or same constant int.
2789  if (isEqualConstantIntOrValue(size1, size2))
2790  continue;
2791 
2792  // Other cases: Take a deeper look at defining ops of values.
2793  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2794  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2795  if (!v1 || !v2)
2796  return false;
2797 
2798  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2799  // CSE is run.)
2800  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2801  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2802  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2803  minOp1.getOperands() == minOp2.getOperands())
2804  continue;
2805 
2806  // Add additional cases as needed.
2807  }
2808 
2809  // All tests passed.
2810  return true;
2811  }
2812 };
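To make the shape comparison in `hasSameTensorSize` above easier to follow, here is a minimal standalone sketch of its static-dimension check using plain integer vectors. It is illustrative only: `kDynamic` stands in for `ShapedType::kDynamic` and the function name `sameStaticTensorSize` is assumed, not part of this file.
```
// Sketch of the static part of hasSameTensorSize(): two shapes match only if
// every dimension is dynamic in both, or static and equal in both.
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

bool sameStaticTensorSize(const std::vector<int64_t> &before,
                          const std::vector<int64_t> &after) {
  if (before.size() != after.size())
    return false; // ranks must match
  for (size_t i = 0; i < before.size(); ++i) {
    bool dyn1 = before[i] == kDynamic, dyn2 = after[i] == kDynamic;
    if (dyn1 != dyn2)
      return false; // mixed static/dynamic dims are not supported
    if (!dyn1 && before[i] != after[i])
      return false; // static dims must be equal
  }
  // Dynamic dims still need the op-specific analysis performed above.
  return true;
}
```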
2813 
2814 /// Returns the effective Pad value for the input op, provided it's a scalar.
2815 ///
2816 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2817 /// this Op performs padding, retrieve the padding value provided that it's
2818 /// a scalar and static/fixed for all the padded values. Returns an empty value
2819 /// otherwise.
2820 ///
2821 /// TODO: This is used twice (when checking vectorization pre-conditions and
2822 /// when vectorizing). Cache results instead of re-running.
2823 static Value getStaticPadVal(Operation *op) {
2824  if (!op)
2825  return {};
2826 
2827  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
2828  // being broadcast, provided that it's a scalar.
2829  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2830  auto source = bcast.getSource();
2831  if (llvm::dyn_cast<VectorType>(source.getType()))
2832  return {};
2833 
2834  return source;
2835  }
2836 
2837  // 2. linalg.fill - use the scalar input value that used to fill the output
2838  // tensor.
2839  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2840  return fill.getInputs()[0];
2841  }
2842 
2843  // 3. tensor.generate - the value cannot be guaranteed to be fixed without
2844  // further analysis, so bail out.
2845  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2846  return {};
2847  }
2848 
2849  // 4. vector.transfer_write - inspect the vector that is being written. If it
2850  // contains a single value that has been broadcast (e.g. via
2851  // vector.broadcast), extract it; fail otherwise.
2852  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2853  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2854 
2855  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2856  // than the input tensor, then, provided it's constant, we'll extract the
2857  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2858  // TODO: Clarify the semantics when the input tensor is larger than the
2859  // destination.
2860  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2861  return getStaticPadVal(slice.getDest().getDefiningOp());
2862 
2863  return {};
2864 }
2865 
2866 static LogicalResult
2867 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2868  ArrayRef<int64_t> inputVectorSizes,
2869  SmallVectorImpl<Value> &newResults) {
2870  // TODO: Introduce a parent class that will handle the insertion point update.
2871  OpBuilder::InsertionGuard g(rewriter);
2872  rewriter.setInsertionPoint(sliceOp);
2873 
2874  TypedValue<RankedTensorType> source = sliceOp.getSource();
2875  auto sourceType = source.getType();
2876  auto resultType = sliceOp.getResultType();
2877 
2878  Value padValue = getStaticPadVal(sliceOp);
2879 
2880  if (!padValue) {
2881  auto elemType = sourceType.getElementType();
2882  padValue = rewriter.create<arith::ConstantOp>(
2883  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2884  }
2885 
2886  // 2. Get the vector shape and in-bounds attributes
2887  SmallVector<int64_t> vecShape;
2888  SmallVector<bool> readInBounds;
2889  SmallVector<bool> writeInBounds;
2890  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2891  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2892  if (!inputVectorSizes.empty()) {
2893  vecShape.push_back(inputVectorSizes[i]);
2894  readInBounds.push_back(false);
2895  writeInBounds.push_back(false);
2896  } else if (!sourceType.isDynamicDim(i)) {
2897  vecShape.push_back(sourceType.getDimSize(i));
2898  // Source shape is statically known: Neither read nor write are
2899  // out-of-bounds.
2900  readInBounds.push_back(true);
2901  writeInBounds.push_back(true);
2902  } else if (!resultType.isDynamicDim(i)) {
2903  // Source shape is not statically known, but result shape is.
2904  // Vectorize with size of result shape. This may be larger than the
2905  // source size.
2906  // FIXME: Using rankDiff implies that the source tensor is inserted at
2907  // the end of the destination tensor. However, that's not required.
2908  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2909  // Read may be out-of-bounds because the result size could be larger
2910  // than the source size.
2911  readInBounds.push_back(false);
2912  // Write will be in-bounds provided that the corresponding write idx is 0.
2913  // To keep this logic simple, conservatively mark as out-of-bounds.
2914  writeInBounds.push_back(false);
2915  } else {
2916  // Neither source nor result dim of padOp is static. Cannot vectorize
2917  // the copy.
2918  // TODO: Add support for masking
2919  return failure();
2920  }
2921  }
2922  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2923 
2924  // 3. Generate TransferReadOp + TransferWriteOp
2925  ReifiedRankedShapedTypeDims reifiedSrcSizes;
2926  Value maskOp;
2927 
2928  // If vector sizes are user provided, make sure to mask. First, generate the
2929  // mask.
2930  if (!inputVectorSizes.empty()) {
2931  auto *srcDefOp = source.getDefiningOp();
2932  if (!srcDefOp) {
2933  LDBG("Unable to get the defining Op of " << sliceOp);
2934  return failure();
2935  }
2936 
2937  LogicalResult status =
2938  cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
2939  rewriter, reifiedSrcSizes);
2940  if (status.failed()) {
2941  LDBG("Unable to reify result shapes of " << srcDefOp);
2942  return failure();
2943  }
2944 
2945  // Create the mask
2946  auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
2947  maskOp = rewriter.create<vector::CreateMaskOp>(
2948  sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
2949  }
2950 
2951  SmallVector<Value> readIndices(
2952  vecType.getRank(),
2953  rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2954  Operation *read = rewriter.create<vector::TransferReadOp>(
2955  sliceOp.getLoc(), vecType, source, readIndices, padValue,
2956  ArrayRef<bool>{readInBounds});
2957 
2958  if (maskOp) {
2959  read = mlir::vector::maskOperation(rewriter, read, maskOp);
2960  }
2961 
2962  auto writeIndices = getValueOrCreateConstantIndexOp(
2963  rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2964 
2965  Operation *write = rewriter.create<vector::TransferWriteOp>(
2966  sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
2967  ArrayRef<bool>{writeInBounds});
2968 
2969  if (maskOp) {
2970  write = mlir::vector::maskOperation(rewriter, write, maskOp);
2971  }
2972 
2973  // 4. Finalize
2974  newResults.push_back(write->getResult(0));
2975 
2976  return success();
2977 }
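The per-dimension case analysis above (user-provided sizes, static source dim, static result dim, or bail out) can be summarised in a small standalone sketch. It uses plain std types rather than MLIR types; `kDynamic` and the function name `inferVecShapeSketch` are assumed names for illustration only.
```
// Sketch of the vector-shape inference loop in vectorizeAsInsertSliceOp.
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

std::optional<std::vector<int64_t>>
inferVecShapeSketch(const std::vector<int64_t> &srcShape,
                    const std::vector<int64_t> &resShape,
                    const std::vector<int64_t> &inputVectorSizes) {
  std::vector<int64_t> vecShape;
  size_t rankDiff = resShape.size() - srcShape.size();
  for (size_t i = 0; i < srcShape.size(); ++i) {
    if (!inputVectorSizes.empty())
      vecShape.push_back(inputVectorSizes[i]);      // user-provided, masked later
    else if (srcShape[i] != kDynamic)
      vecShape.push_back(srcShape[i]);              // static source dim: in-bounds
    else if (resShape[rankDiff + i] != kDynamic)
      vecShape.push_back(resShape[rankDiff + i]);   // fall back to the result dim
    else
      return std::nullopt;                          // both dynamic: bail out
  }
  return vecShape;
}
```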
2978 
2979 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2980 /// ```
2981 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2982 /// %r = tensor.insert_slice %0
2983 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2984 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2985 /// ```
2986 /// is rewritten to:
2987 /// ```
2988 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2989 /// : tensor<?x?xf32>, vector<17x5xf32>
2990 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2991 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2992 /// ```
2993 ///
2994 /// This rewrite is possible if:
2995 /// - Low padding is static 0.
2996 /// - `padOp` result shape is static.
2997 /// - The entire padded tensor is inserted.
2998 /// (Implies that sizes of `insertOp` are all static.)
2999 /// - Only unit strides in `insertOp`.
3000 /// - Single, scalar padding value.
3001 /// - `padOp` result not used as destination.
3002 struct PadOpVectorizationWithInsertSlicePattern
3003  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3004  using VectorizePadOpUserPattern<
3005  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3006 
3007  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3008  tensor::InsertSliceOp insertOp) const override {
3009  // Low padding must be static 0.
3010  if (!padOp.hasZeroLowPad())
3011  return failure();
3012  // Only unit stride supported.
3013  if (!insertOp.hasUnitStride())
3014  return failure();
3015  // Pad value must be a constant.
3016  auto padValue = padOp.getConstantPaddingValue();
3017  if (!padValue)
3018  return failure();
3019  // Dynamic shapes not supported.
3020  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3021  return failure();
3022  // Pad result not used as destination.
3023  if (insertOp.getDest() == padOp.getResult())
3024  return failure();
3025 
3026  auto vecType = VectorType::get(padOp.getType().getShape(),
3027  padOp.getType().getElementType());
3028  unsigned vecRank = vecType.getRank();
3029  unsigned tensorRank = insertOp.getType().getRank();
3030 
3031  // Check if sizes match: Insert the entire tensor into the most minor dims.
3032  // (No permutations allowed.)
3033  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3034  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3035  if (!llvm::all_of(
3036  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3037  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3038  }))
3039  return failure();
3040 
3041  // Insert the TransferReadOp and TransferWriteOp at the position of the
3042  // InsertSliceOp.
3043  rewriter.setInsertionPoint(insertOp);
3044 
3045  // Generate TransferReadOp: Read entire source tensor and add high
3046  // padding.
3047  SmallVector<Value> readIndices(
3048  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
3049  auto read = rewriter.create<vector::TransferReadOp>(
3050  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
3051 
3052  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3053  // specified offsets. The write is fully in-bounds because an InsertSliceOp's
3054  // source must fit into the destination at the specified offsets.
3055  auto writeIndices = getValueOrCreateConstantIndexOp(
3056  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3057  SmallVector<bool> inBounds(vecRank, true);
3058  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3059  insertOp, read, insertOp.getDest(), writeIndices,
3060  ArrayRef<bool>{inBounds});
3061 
3062  return success();
3063  }
3064 };
3065 
3066 void mlir::linalg::populatePadOpVectorizationPatterns(
3067  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3068  patterns.add<PadOpVectorizationWithTransferReadPattern,
3069  PadOpVectorizationWithTransferWritePattern,
3070  PadOpVectorizationWithInsertSlicePattern>(
3071  patterns.getContext(), baseBenefit.getBenefit() + 1);
3072 }
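For context, a minimal sketch of how client code might register and run the pad-vectorization patterns populated above. The wrapper name `applyPadVectorizationSketch` is assumed, and the greedy driver entry point is spelled `applyPatternsGreedily` in recent MLIR trees (older trees use `applyPatternsAndFoldGreedily`).
```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical wrapper, not part of this file: apply the pad-vectorization
// patterns populated above until a fixpoint is reached.
static mlir::LogicalResult applyPadVectorizationSketch(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Uses the default baseBenefit; the three PadOp user patterns above are
  // registered with baseBenefit + 1 so they are tried first.
  mlir::linalg::populatePadOpVectorizationPatterns(patterns);
  // Assumed greedy-driver entry point (see note above about older spellings).
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```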
3073 
3074 //----------------------------------------------------------------------------//
3075 // Forwarding patterns
3076 //----------------------------------------------------------------------------//
3077 
3078 /// Check whether there is any interleaved use of any `values` between
3079 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3080 /// is in a different block.
3081 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3082  ValueRange values) {
3083  if (firstOp->getBlock() != secondOp->getBlock() ||
3084  !firstOp->isBeforeInBlock(secondOp)) {
3085  LDBG("interleavedUses precondition failed, firstOp: "
3086  << *firstOp << ", second op: " << *secondOp << "\n");
3087  return true;
3088  }
3089  for (auto v : values) {
3090  for (auto &u : v.getUses()) {
3091  Operation *owner = u.getOwner();
3092  if (owner == firstOp || owner == secondOp)
3093  continue;
3094  // TODO: this is too conservative, use dominance info in the future.
3095  if (owner->getBlock() == firstOp->getBlock() &&
3096  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3097  continue;
3098  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3099  << ", second op: " << *secondOp << "\n");
3100  return true;
3101  }
3102  }
3103  return false;
3104 }
3105 
3106 /// Return the unique subview use of `v` if it is indeed unique, null
3107 /// otherwise.
3108 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3109  memref::SubViewOp subViewOp;
3110  for (auto &u : v.getUses()) {
3111  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3112  if (subViewOp)
3113  return memref::SubViewOp();
3114  subViewOp = newSubViewOp;
3115  }
3116  }
3117  return subViewOp;
3118 }
3119 
3120 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3121 /// when available.
3122 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3123  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3124 
3125  // TODO: support mask.
3126  if (xferOp.getMask())
3127  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3128 
3129  // Transfer into `view`.
3130  Value viewOrAlloc = xferOp.getBase();
3131  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3132  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3133  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3134 
3135  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3136  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3137  if (!subViewOp)
3138  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3139  Value subView = subViewOp.getResult();
3140 
3141  // Find the copy into `subView` without interleaved uses.
3142  memref::CopyOp copyOp;
3143  for (auto &u : subView.getUses()) {
3144  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3145  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3146  if (newCopyOp.getTarget() != subView)
3147  continue;
3148  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3149  continue;
3150  copyOp = newCopyOp;
3151  break;
3152  }
3153  }
3154  if (!copyOp)
3155  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3156 
3157  // Find the fill into `viewOrAlloc` without interleaved uses before the
3158  // copy.
3159  FillOp maybeFillOp;
3160  for (auto &u : viewOrAlloc.getUses()) {
3161  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3162  assert(isa<MemRefType>(newFillOp.output().getType()));
3163  if (newFillOp.output() != viewOrAlloc)
3164  continue;
3165  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3166  continue;
3167  maybeFillOp = newFillOp;
3168  break;
3169  }
3170  }
3171  // Ensure padding matches.
3172  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3173  return rewriter.notifyMatchFailure(xferOp,
3174  "padding value does not match fill");
3175 
3176  // `in` is the subview that memref.copy reads. Replace it.
3177  Value in = copyOp.getSource();
3178 
3179  // memref.copy + linalg.fill can be used to create a padded local buffer.
3180  // The `masked` attribute is only valid on this padded buffer.
3181  // When forwarding to vector.transfer_read, the attribute must be reset
3182  // conservatively.
3183  auto vectorType = xferOp.getVectorType();
3184  Value res = rewriter.create<vector::TransferReadOp>(
3185  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3186  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3187  rewriter.getBoolArrayAttr(
3188  SmallVector<bool>(vectorType.getRank(), false)));
3189 
3190  if (maybeFillOp)
3191  rewriter.eraseOp(maybeFillOp);
3192  rewriter.eraseOp(copyOp);
3193  rewriter.replaceOp(xferOp, res);
3194 
3195  return success();
3196 }
3197 
3198 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3199 /// when available.
3200 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3201  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3202  // TODO: support mask.
3203  if (xferOp.getMask())
3204  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3205 
3206  // Transfer into `viewOrAlloc`.
3207  Value viewOrAlloc = xferOp.getBase();
3208  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3209  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3210  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3211 
3212  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3213  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3214  if (!subViewOp)
3215  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3216  Value subView = subViewOp.getResult();
3217 
3218  // Find the copy from `subView` without interleaved uses.
3219  memref::CopyOp copyOp;
3220  for (auto &u : subViewOp.getResult().getUses()) {
3221  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3222  if (newCopyOp.getSource() != subView)
3223  continue;
3224  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3225  continue;
3226  copyOp = newCopyOp;
3227  break;
3228  }
3229  }
3230  if (!copyOp)
3231  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3232 
3233  // `out` is the subview that is copied into; it is the value we replace.
3234  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3235  Value out = copyOp.getTarget();
3236 
3237  // Forward vector.transfer into copy.
3238  // memref.copy + linalg.fill can be used to create a padded local buffer.
3239  // The `masked` attribute is only valid on this padded buffer.
3240  // When forwarding to vector.transfer_write, the attribute must be reset
3241  // conservatively.
3242  auto vector = xferOp.getVector();
3243  rewriter.create<vector::TransferWriteOp>(
3244  xferOp.getLoc(), vector, out, xferOp.getIndices(),
3245  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3246  rewriter.getBoolArrayAttr(SmallVector<bool>(
3247  dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3248 
3249  rewriter.eraseOp(copyOp);
3250  rewriter.eraseOp(xferOp);
3251 
3252  return success();
3253 }
3254 
3255 //===----------------------------------------------------------------------===//
3256 // Convolution vectorization patterns
3257 //===----------------------------------------------------------------------===//
3258 
3259 template <int N>
3260 static void bindShapeDims(ShapedType shapedType) {}
3261 
3262 template <int N, typename IntTy, typename... IntTy2>
3263 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3264  val = shapedType.getShape()[N];
3265  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3266 }
3267 
3268 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3269 template <typename... IntTy>
3270 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3271  bindShapeDims<0>(shapedType, vals...);
3272 }
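As a quick illustration of the variadic binding idiom above, the standalone sketch below replaces `ShapedType` with a trivial stand-in so that it compiles on its own; the names ending in `Sketch` and `StandIn` are assumptions introduced only for this example.
```
#include <array>
#include <cassert>
#include <cstdint>

// Trivial stand-in for ShapedType exposing getShape().
struct ShapeStandIn {
  std::array<int64_t, 3> shape;
  const std::array<int64_t, 3> &getShape() const { return shape; }
};

template <int N>
static void bindShapeDimsSketch(const ShapeStandIn &) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDimsSketch(const ShapeStandIn &s, IntTy &val,
                                IntTy2 &...vals) {
  val = s.getShape()[N];                                // bind dimension N
  bindShapeDimsSketch<N + 1, IntTy2 &...>(s, vals...);  // recurse on the rest
}

int main() {
  ShapeStandIn res{{4, 16, 8}}; // e.g. an {n, w, f} output shape
  int64_t n, w, f;
  bindShapeDimsSketch<0>(res, n, w, f);
  assert(n == 4 && w == 16 && f == 8);
  return 0;
}
```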
3273 
3274 namespace {
3275 /// Generate a vector implementation for either:
3276 /// ```
3277 /// Op def: ( w, kw )
3278 /// Iters: ({Par(), Red()})
3279 /// Layout: {{w + kw}, {kw}, {w}}
3280 /// ```
3281 /// kw is unrolled.
3282 ///
3283 /// or
3284 ///
3285 /// ```
3286 /// Op def: ( n, w, c, kw, f )
3287 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3288 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3289 /// ```
3290 /// kw is unrolled, w is unrolled iff dilationW > 1.
3291 ///
3292 /// or
3293 ///
3294 /// ```
3295 /// Op def: ( n, c, w, f, kw )
3296 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3297 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3298 /// ```
3299 /// kw is unrolled, w is unrolled iff dilationW > 1.
3300 ///
3301 /// or
3302 ///
3303 /// ```
3304 /// Op def: ( n, w, c, kw )
3305 /// Iters: ({Par(), Par(), Par(), Red()})
3306 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3307 /// ```
3308 /// kw is unrolled, w is unrolled iff dilationW > 1.
3309 struct Conv1DGenerator
3310  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3311  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3312  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3313 
3314  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3315  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3316  resShaped = linalgOp.getDpsInitOperand(0)->get();
3317  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3318  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3319  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3320 
3321  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3322  redOp = reduceOp->getName().getIdentifier();
3323 
3324  setConvOperationKind(reduceOp);
3325 
3326  auto maybeKind = getCombinerOpKind(reduceOp);
3327  reductionKind = maybeKind.value();
3328 
3329  // The ConvolutionOpInterface gives us guarantees of existence for
3330  // strides/dilations. However, we do not need to rely on those; we simply
3331  // use them if present, otherwise fall back to the defaults and let the
3332  // generic conv matcher in the ConvGenerator succeed or fail.
3333  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3334  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3335  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3336  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3337  }
3338 
3339  /// Generate a vector implementation for:
3340  /// ```
3341  /// Op def: ( w, kw )
3342  /// Iters: ({Par(), Red()})
3343  /// Layout: {{w + kw}, {kw}, {w}}
3344  /// ```
3345  /// kw is always unrolled.
3346  ///
3347  /// or
3348  ///
3349  /// ```
3350  /// Op def: ( n, w, c, kw, f )
3351  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3352  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3353  /// ```
3354  /// kw is always unrolled.
3355  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3356  /// > 1.
3357  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3358  int64_t nSize, wSize, cSize, kwSize, fSize;
3359  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3360  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3361  switch (conv1DOpOrder) {
3362  case Conv1DOpOrder::W:
3363  // Initialize unused dimensions
3364  nSize = fSize = cSize = 0;
3365  // out{W}
3366  bindShapeDims(resShapedType, wSize);
3367  // kernel{kw}
3368  bindShapeDims(rhsShapedType, kwSize);
3369  lhsShape = {// iw = ow + kw - 1
3370  // (i.e. 16 convolved with 3 -> 14)
3371  (wSize + kwSize - 1)};
3372  rhsShape = {kwSize};
3373  resShape = {wSize};
3374  break;
3375  case Conv1DOpOrder::Nwc:
3376  // out{n, w, f}
3377  bindShapeDims(resShapedType, nSize, wSize, fSize);
3378  switch (oper) {
3379  case ConvOperationKind::Conv:
3380  // kernel{kw, c, f}
3381  bindShapeDims(rhsShapedType, kwSize, cSize);
3382  break;
3383  case ConvOperationKind::Pool:
3384  // kernel{kw}
3385  bindShapeDims(rhsShapedType, kwSize);
3386  cSize = fSize;
3387  break;
3388  }
3389  lhsShape = {nSize,
3390  // iw = ow * sw + kw * dw - 1
3391  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3392  // Perform the proper inclusive -> exclusive -> inclusive.
3393  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3394  1,
3395  cSize};
3396  switch (oper) {
3397  case ConvOperationKind::Conv:
3398  rhsShape = {kwSize, cSize, fSize};
3399  break;
3400  case ConvOperationKind::Pool:
3401  rhsShape = {kwSize};
3402  break;
3403  }
3404  resShape = {nSize, wSize, fSize};
3405  break;
3406  case Conv1DOpOrder::Ncw:
3407  // out{n, f, w}
3408  bindShapeDims(resShapedType, nSize, fSize, wSize);
3409  switch (oper) {
3410  case ConvOperationKind::Conv:
3411  // kernel{f, c, kw}
3412  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3413  break;
3414  case ConvOperationKind::Pool:
3415  // kernel{kw}
3416  bindShapeDims(rhsShapedType, kwSize);
3417  cSize = fSize;
3418  break;
3419  }
3420  lhsShape = {nSize, cSize,
3421  // iw = ow * sw + kw * dw - 1
3422  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3423  // Perform the proper inclusive -> exclusive -> inclusive.
3424  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3425  1};
3426  switch (oper) {
3427  case ConvOperationKind::Conv:
3428  rhsShape = {fSize, cSize, kwSize};
3429  break;
3430  case ConvOperationKind::Pool:
3431  rhsShape = {kwSize};
3432  break;
3433  }
3434  resShape = {nSize, fSize, wSize};
3435  break;
3436  }
3437 
3438  vector::TransferWriteOp write;
3439  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3440 
3441  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3442  // When strideW == 1, we can batch the contiguous loads and avoid
3443  // unrolling
3444  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3445 
3446  Type lhsEltType = lhsShapedType.getElementType();
3447  Type rhsEltType = rhsShapedType.getElementType();
3448  Type resEltType = resShapedType.getElementType();
3449  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3450  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3451  auto resType = VectorType::get(resShape, resEltType);
3452  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3453  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3454  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3455  SmallVector<Value> resPadding(resShape.size(), zero);
3456 
3457  // Read the whole lhs, rhs and res in one shot (with zero padding).
3458  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3459  lhsPadding);
3460  // This is needed only for Conv.
3461  Value rhs = nullptr;
3462  if (oper == ConvOperationKind::Conv)
3463  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3464  rhsPadding);
3465  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3466  resPadding);
3467 
3468  // The base vectorization case for channeled convolution is input:
3469  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base-case
3470  // vectorization pattern, we pre-transpose the input, weight, and output.
3471  switch (conv1DOpOrder) {
3472  case Conv1DOpOrder::W:
3473  case Conv1DOpOrder::Nwc:
3474  // Base case, so no transposes necessary.
3475  break;
3476  case Conv1DOpOrder::Ncw: {
3477  // To match base vectorization case, we pre-transpose current case.
3478  // ncw -> nwc
3479  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3480  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3481  // fcw -> wcf
3482  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3483 
3484  // This is needed only for Conv.
3485  if (oper == ConvOperationKind::Conv)
3486  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3487  // nfw -> nwf
3488  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3489  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3490  break;
3491  }
3492  }
3493 
3494  //===------------------------------------------------------------------===//
3495  // Begin vector-only rewrite part
3496  //===------------------------------------------------------------------===//
3497  // Unroll along kw and read slices of lhs and rhs.
3498  SmallVector<Value> lhsVals, rhsVals, resVals;
3499  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3500  kwSize, strideW, dilationW, wSizeStep,
3501  isSingleChanneled);
3502  // Do not do for pooling.
3503  if (oper == ConvOperationKind::Conv)
3504  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3505  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3506  wSizeStep, isSingleChanneled);
3507 
3508  auto linearIndex = [&](int64_t kw, int64_t w) {
3509  return kw * (wSize / wSizeStep) + w;
3510  };
3511 
3512  // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3513  // or an outer product for non-channeled convolution, or a simple arith
3514  // operation for pooling.
3515  for (int64_t kw = 0; kw < kwSize; ++kw) {
3516  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3517  switch (oper) {
3518  case ConvOperationKind::Conv:
3519  if (isSingleChanneled) {
3520  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3521  lhsVals[linearIndex(kw, w)],
3522  rhsVals[kw], resVals[w]);
3523  } else {
3524  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3525  lhsVals[linearIndex(kw, w)],
3526  rhsVals[kw], resVals[w]);
3527  }
3528  break;
3529  case ConvOperationKind::Pool:
3530  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3531  resVals[w]);
3532  break;
3533  }
3534  }
3535  }
3536 
3537  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3538  isSingleChanneled);
3539  //===------------------------------------------------------------------===//
3540  // End vector-only rewrite part
3541  //===------------------------------------------------------------------===//
3542 
3543  // The base vectorization case for channeled convolution produces output
3544  // {n,w,f}. To reuse the result from the base-case vectorization pattern,
3545  // we post-transpose the base-case result.
3546  switch (conv1DOpOrder) {
3547  case Conv1DOpOrder::W:
3548  case Conv1DOpOrder::Nwc:
3549  // Base case, so no transposes necessary.
3550  break;
3551  case Conv1DOpOrder::Ncw: {
3552  // nwf -> nfw
3553  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3554  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3555  break;
3556  }
3557  }
3558 
3559  return rewriter
3560  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3561  .getOperation();
3562  }
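The kw/w unrolling above addresses the extracted lhs slices through `linearIndex(kw, w) = kw * (wSize / wSizeStep) + w`. The standalone sketch below prints that mapping for made-up sizes (illustration only): with strideW == 1 the whole w extent is covered by one batched slice per kw, while with strideW > 1 each (kw, w) pair gets its own slice.
```
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t wSize = 4, kwSize = 3; // made-up illustration values
  for (int64_t strideW : {1, 2}) {
    // w is unrolled (wSizeStep == 1) iff strideW > 1; otherwise the whole
    // w extent is handled by a single batched slice (wSizeStep == wSize).
    int64_t wSizeStep = strideW == 1 ? wSize : 1;
    std::printf("strideW=%lld:\n", (long long)strideW);
    for (int64_t kw = 0; kw < kwSize; ++kw)
      for (int64_t w = 0; w < wSize; w += wSizeStep)
        std::printf("  kw=%lld w=%lld -> lhsVals[%lld]\n", (long long)kw,
                    (long long)w, (long long)(kw * (wSize / wSizeStep) + w));
  }
  return 0;
}
```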
3563 
3564  // Take a value and widen to have the same element type as `ty`.
3565  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3566  const Type srcElementType = getElementTypeOrSelf(val.getType());
3567  const Type dstElementType = getElementTypeOrSelf(ty);
3568  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3569  if (srcElementType == dstElementType)
3570  return val;
3571 
3572  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3573  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3574  const Type dstType =
3575  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3576 
3577  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3578  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3579  }
3580 
3581  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3582  srcWidth < dstWidth)
3583  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3584 
3585  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3586  srcWidth < dstWidth)
3587  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3588 
3589  assert(false && "unhandled promotion case");
3590  return nullptr;
3591  }
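The widening rules in `promote` above reduce to a small decision table; the standalone sketch below restates them, approximating exact type equality by matching int/float kind and bit width. The enum and function names are assumptions introduced for this example.
```
#include <cstdint>

enum class PromoteOp { None, SIToFP, ExtF, ExtSI, Unsupported };

PromoteOp pickPromotion(bool srcIsInt, bool dstIsInt, int64_t srcWidth,
                        int64_t dstWidth) {
  if (srcIsInt == dstIsInt && srcWidth == dstWidth)
    return PromoteOp::None;      // already the accumulator element type
  if (srcIsInt && !dstIsInt)
    return PromoteOp::SIToFP;    // int -> float
  if (!srcIsInt && !dstIsInt && srcWidth < dstWidth)
    return PromoteOp::ExtF;      // float -> wider float
  if (srcIsInt && dstIsInt && srcWidth < dstWidth)
    return PromoteOp::ExtSI;     // int -> wider int (sign extension)
  return PromoteOp::Unsupported; // e.g. narrowing; asserts in promote()
}
```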
3592 
3593  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3594  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3595  Value lhs, Value rhs, Value res) {
3596  vector::IteratorType par = vector::IteratorType::parallel;
3597  vector::IteratorType red = vector::IteratorType::reduction;
3598  AffineExpr n, w, f, c;
3599  bindDims(ctx, n, w, f, c);
3600  lhs = promote(rewriter, loc, lhs, res.getType());
3601  rhs = promote(rewriter, loc, rhs, res.getType());
3602  auto contractionOp = rewriter.create<vector::ContractionOp>(
3603  loc, lhs, rhs, res,
3604  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3605  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3606  contractionOp.setKind(reductionKind);
3607  return contractionOp;
3608  }
3609 
3610  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3611  // convolution.
3612  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3613  Value lhs, Value rhs, Value res) {
3614  return rewriter.create<vector::OuterProductOp>(
3615  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3616  }
3617 
3618  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3619  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3620  Value res) {
3621  if (isPoolExt)
3622  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3623  return rewriter
3624  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3625  ->getResult(0);
3626  }
3627 
3628  /// Generate a vector implementation for:
3629  /// ```
3630  /// Op def: ( n, w, c, kw)
3631  /// Iters: ({Par(), Par(), Par(), Red()})
3632  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3633  /// ```
3634  /// kw is always unrolled.
3635  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3636  /// > 1.
3637  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3638  bool channelDimScalableFlag,
3639  bool flatten) {
3640  bool scalableChDim = false;
3641  bool useMasking = false;
3642  int64_t nSize, wSize, cSize, kwSize;
3643  // kernel{kw, c}
3644  bindShapeDims(rhsShapedType, kwSize, cSize);
3645  if (ShapedType::isDynamic(cSize)) {
3646  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3647  cSize = channelDimVecSize;
3648  // Scalable vectors are only used when both conditions are met:
3649  // 1. channel dim is dynamic
3650  // 2. channelDimScalableFlag is set
3651  scalableChDim = channelDimScalableFlag;
3652  useMasking = true;
3653  }
3654 
3655  assert(!(useMasking && flatten) &&
3656  "Unsupported flattened conv with dynamic shapes");
3657 
3658  // out{n, w, c}
3659  bindShapeDims(resShapedType, nSize, wSize);
3660 
3661  vector::TransferWriteOp write;
3662  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3663 
3664  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3665  // When strideW == 1, we can batch the contiguous loads and avoid
3666  // unrolling
3667  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3668 
3669  Type lhsEltType = lhsShapedType.getElementType();
3670  Type rhsEltType = rhsShapedType.getElementType();
3671  Type resEltType = resShapedType.getElementType();
3672  VectorType lhsType = VectorType::get(
3673  {nSize,
3674  // iw = ow * sw + kw * dw - 1
3675  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3676  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3677  cSize},
3678  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3679  VectorType rhsType =
3680  VectorType::get({kwSize, cSize}, rhsEltType,
3681  /*scalableDims=*/{false, scalableChDim});
3682  VectorType resType =
3683  VectorType::get({nSize, wSize, cSize}, resEltType,
3684  /*scalableDims=*/{false, false, scalableChDim});
3685 
3686  // Masks the input xfer Op along the channel dim whenever masking is
3687  // required, i.e. when the channel dim is dynamic.
3688  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3689  ArrayRef<bool> scalableDims,
3690  Operation *opToMask) {
3691  if (!useMasking)
3692  return opToMask;
3693  auto maskType =
3694  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3695 
3696  SmallVector<bool> inBounds(maskShape.size(), true);
3697  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3698  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3699  rewriter.getBoolArrayAttr(inBounds));
3700 
3701  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3702  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3703 
3704  Value maskOp =
3705  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3706 
3707  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3708  };
3709 
3710  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3711  // 0].
3712  Value lhs = rewriter.create<vector::TransferReadOp>(
3713  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3714  auto maybeMaskedLhs = maybeMaskXferOp(
3715  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3716 
3717  // Read rhs slice of size {kw, c} @ [0, 0].
3718  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3719  ValueRange{zero, zero});
3720  auto maybeMaskedRhs = maybeMaskXferOp(
3721  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3722 
3723  // Read res slice of size {n, w, c} @ [0, 0, 0].
3724  Value res = rewriter.create<vector::TransferReadOp>(
3725  loc, resType, resShaped, ValueRange{zero, zero, zero});
3726  auto maybeMaskedRes = maybeMaskXferOp(
3727  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3728 
3729  //===------------------------------------------------------------------===//
3730  // Begin vector-only rewrite part
3731  //===------------------------------------------------------------------===//
3732  // Unroll along kw and read slices of lhs and rhs.
3733  SmallVector<Value> lhsVals, rhsVals, resVals;
3734  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3735  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3736 
3737  // Extract lhs slice of size {n, wSizeStep, c}
3738  // @ [0, sw * w + dw * kw, 0].
3739  for (int64_t kw = 0; kw < kwSize; ++kw) {
3740  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3741  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3742  loc, maybeMaskedLhs->getResult(0),
3743  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3744  inOutSliceSizes, inOutStrides));
3745  }
3746  }
3747  // Extract rhs slice of size {c} @ [kw].
3748  for (int64_t kw = 0; kw < kwSize; ++kw) {
3749  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3750  loc, maybeMaskedRhs->getResult(0),
3751  /*offsets=*/ArrayRef<int64_t>{kw}));
3752  }
3753  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3754  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3755  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3756  loc, maybeMaskedRes->getResult(0),
3757  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3758  inOutStrides));
3759  }
3760 
3761  auto linearIndex = [&](int64_t kw, int64_t w) {
3762  return kw * (wSize / wSizeStep) + w;
3763  };
3764 
3765  // Note - the scalable flags are ignored as flattening combined with
3766  // scalable vectorization is not supported.
3767  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3768  auto lhsTypeAfterFlattening =
3769  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3770  auto resTypeAfterFlattening =
3771  VectorType::get(inOutFlattenSliceSizes, resEltType);
3772 
3773  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3774  for (int64_t kw = 0; kw < kwSize; ++kw) {
3775  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3776  Value lhsVal = lhsVals[linearIndex(kw, w)];
3777  Value resVal = resVals[w];
3778  if (flatten) {
3779  // Flatten the input and output vectors (collapse the channel
3780  // dimension)
3781  lhsVal = rewriter.create<vector::ShapeCastOp>(
3782  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3783  resVal = rewriter.create<vector::ShapeCastOp>(
3784  loc, resTypeAfterFlattening, resVals[w]);
3785  }
3786  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3787  rhsVals[kw], resVal, flatten);
3788  if (flatten) {
3789  // Un-flatten the output vector (restore the channel dimension)
3790  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3791  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3792  }
3793  }
3794  }
3795 
3796  // It's possible that we failed to create the FMA.
3797  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3798  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3799  for (auto &collection :
3800  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3801  for (Value v : collection)
3802  rewriter.eraseOp(v.getDefiningOp());
3803  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3804  }
3805 
3806  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3807  // This does not depend on kw.
3808  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3809  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3810  loc, resVals[w], maybeMaskedRes->getResult(0),
3811  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3812  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3813  }
3814  //===------------------------------------------------------------------===//
3815  // End vector-only rewrite part
3816  //===------------------------------------------------------------------===//
3817 
3818  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3819  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3820  loc, maybeMaskedRes->getResult(0), resShaped,
3821  ValueRange{zero, zero, zero});
3822  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3823  resOut);
3824  }
3825 
3826  /// Lower:
3827  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3828  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3829  /// to MulAcc.
3830  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3831  Value lhs, Value rhs, Value res,
3832  bool flatten) {
3833  auto rhsTy = cast<ShapedType>(rhs.getType());
3834  auto resTy = cast<ShapedType>(res.getType());
3835 
3836  // TODO(suderman): Change this to use a vector.fma intrinsic.
3837  lhs = promote(rewriter, loc, lhs, resTy);
3838 
3839  if (flatten) {
3840  // NOTE: The following logic won't work for scalable vectors. For this
3841  // reason, "flattening" is not supported when shapes are dynamic (this
3842  // should be captured by one of the pre-conditions).
3843 
3844  // There are two options for handling the filter:
3845  // * shape_cast(broadcast(filter))
3846  // * broadcast(shuffle(filter))
3847  // Opt for the option without shape_cast to simplify the codegen.
3848  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3849  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3850 
3851  SmallVector<int64_t, 16> indices;
3852  for (int i = 0; i < resSize / rhsSize; ++i) {
3853  for (int j = 0; j < rhsSize; ++j)
3854  indices.push_back(j);
3855  }
3856 
3857  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3858  }
3859  // Broadcast the filter to match the output vector
3860  rhs = rewriter.create<vector::BroadcastOp>(
3861  loc, resTy.clone(rhsTy.getElementType()), rhs);
3862 
3863  rhs = promote(rewriter, loc, rhs, resTy);
3864 
3865  if (!lhs || !rhs)
3866  return nullptr;
3867 
3868  if (isa<FloatType>(resTy.getElementType()))
3869  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3870 
3871  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3872  return rewriter.create<arith::AddIOp>(loc, mul, res);
3873  }
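When `flatten` is true, the filter shuffle above repeats the {c}-sized filter once per w position so that it lines up with the flattened {w * c} operand. Below is a standalone sketch with made-up sizes; the helper name `buildFlattenedFilterIndices` is assumed.
```
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> buildFlattenedFilterIndices(int64_t rhsSize,
                                                 int64_t resSize) {
  std::vector<int64_t> indices;
  for (int64_t i = 0; i < resSize / rhsSize; ++i)
    for (int64_t j = 0; j < rhsSize; ++j)
      indices.push_back(j); // repeat the filter lanes for each w position
  return indices;
}

int main() {
  // rhs = filter with 4 channels, res = flattened vector of w * c = 3 * 4 lanes.
  auto idx = buildFlattenedFilterIndices(/*rhsSize=*/4, /*resSize=*/12);
  // Produces [0,1,2,3, 0,1,2,3, 0,1,2,3] as the vector.shuffle index list.
  assert(idx.size() == 12 && idx[4] == 0 && idx[7] == 3);
  return 0;
}
```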
3874 
3875  /// Entry point for non-channeled convolution:
3876  /// {{w + kw}, {kw}, {w}}
3877  FailureOr<Operation *> generateNonChanneledConv() {
3878  AffineExpr w, kw;
3879  bindDims(ctx, w, kw);
3880  if (!iters({Par(), Red()}))
3881  return rewriter.notifyMatchFailure(op,
3882  "failed to match conv::W 1-par 1-red");
3883 
3884  // No transposition needed.
3885  if (layout({/*lhsIndex*/ {w + kw},
3886  /*rhsIndex*/ {kw},
3887  /*resIndex*/ {w}}))
3888  return conv(Conv1DOpOrder::W);
3889 
3890  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3891  }
3892 
3893  /// Entry point that transposes into the common form:
3894  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3895  FailureOr<Operation *> generateNwcConv() {
3896  AffineExpr n, w, f, kw, c;
3897  bindDims(ctx, n, w, f, kw, c);
3898  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3899  return rewriter.notifyMatchFailure(
3900  op, "failed to match conv::Nwc 3-par 2-red");
3901 
3902  // No transposition needed.
3903  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3904  /*rhsIndex*/ {kw, c, f},
3905  /*resIndex*/ {n, w, f}}))
3906  return conv(Conv1DOpOrder::Nwc);
3907 
3908  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3909  }
3910 
3911  /// Entry point that transposes into the common form:
3912  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3913  FailureOr<Operation *> generateNcwConv() {
3914  AffineExpr n, w, f, kw, c;
3915  bindDims(ctx, n, f, w, c, kw);
3916  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3917  return rewriter.notifyMatchFailure(
3918  op, "failed to match conv::Ncw 3-par 2-red");
3919 
3920  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3921  /*rhsIndex*/ {f, c, kw},
3922  /*resIndex*/ {n, f, w}}))
3923  return conv(Conv1DOpOrder::Ncw);
3924 
3925  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3926  }
3927 
3928  /// Entry point that transposes into the common form:
3929  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3930  FailureOr<Operation *> generateNwcPooling() {
3931  AffineExpr n, w, c, kw;
3932  bindDims(ctx, n, w, c, kw);
3933  if (!iters({Par(), Par(), Par(), Red()}))
3934  return rewriter.notifyMatchFailure(op,
3935  "failed to match pooling 3-par 1-red");
3936 
3937  // No transposition needed.
3938  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3939  /*rhsIndex*/ {kw},
3940  /*resIndex*/ {n, w, c}}))
3941  return conv(Conv1DOpOrder::Nwc);
3942 
3943  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3944  }
3945 
3946  /// Entry point that transposes into the common form:
3947  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3948  FailureOr<Operation *> generateNcwPooling() {
3949  AffineExpr n, w, c, kw;
3950  bindDims(ctx, n, c, w, kw);
3951  if (!iters({Par(), Par(), Par(), Red()}))
3952  return rewriter.notifyMatchFailure(op,
3953  "failed to match pooling 3-par 1-red");
3954 
3955  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3956  /*rhsIndex*/ {kw},
3957  /*resIndex*/ {n, c, w}}))
3958  return conv(Conv1DOpOrder::Ncw);
3959 
3960  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3961  }
3962 
3963  /// Entry point that transposes into the common form:
3964  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3965  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3966  bool vecChDimScalableFlag = false,
3967  bool flatten = false) {
3968  AffineExpr n, w, c, kw;
3969  bindDims(ctx, n, w, c, kw);
3970  if (!iters({Par(), Par(), Par(), Red()}))
3971  return rewriter.notifyMatchFailure(
3972  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3973 
3974  // No transposition needed.
3975  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3976  /*rhsIndex*/ {kw, c},
3977  /*resIndex*/ {n, w, c}}))
3978  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3979 
3980  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3981  }
3982 
3983 private:
3984  ConvOperationKind oper = ConvOperationKind::Conv;
3985  StringAttr redOp;
3986  StringAttr poolExtOp;
3987  bool isPoolExt = false;
3988  int strideW, dilationW;
3989  Value lhsShaped, rhsShaped, resShaped;
3990  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3991  vector::CombiningKind reductionKind;
3992 
3993  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3994  void setConvOperationKind(Operation *reduceOp) {
3995  int numBlockArguments =
3996  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3997  if (numBlockArguments == 1) {
3998  // Will be convolution if feeder is a MulOp.
3999  // A strength reduced version of MulOp for i1 type is AndOp which is also
4000  // supported. Otherwise, it can be pooling. This strength reduction logic
4001  // is in `buildBinaryFn` helper in the Linalg dialect.
4002  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4003  llvm::IsaPred<BlockArgument>);
4004  Operation *feedOp = (*feedValIt).getDefiningOp();
4005  if (isCastOfBlockArgument(feedOp)) {
4006  oper = ConvOperationKind::Pool;
4007  isPoolExt = true;
4008  poolExtOp = feedOp->getName().getIdentifier();
4009  return;
4010  }
4011  oper = ConvOperationKind::Conv;
4012  return;
4013  }
4014  // numBlockArguments == 2 and this is a pooling op.
4015  oper = ConvOperationKind::Pool;
4016  isPoolExt = false;
4017  }
4018 };
4019 } // namespace
4020 
4021 /// Helper function to vectorize a LinalgOp with convolution semantics.
4022 // TODO: extend the generic vectorization to support windows and drop this.
4023 static FailureOr<Operation *> vectorizeConvolution(
4024  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4025  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4026  Conv1DGenerator conv1dGen(rewriter, op);
4027  auto res = conv1dGen.generateNonChanneledConv();
4028  if (succeeded(res))
4029  return res;
4030  res = conv1dGen.generateNwcConv();
4031  if (succeeded(res))
4032  return res;
4033  res = conv1dGen.generateNcwConv();
4034  if (succeeded(res))
4035  return res;
4036  res = conv1dGen.generateNwcPooling();
4037  if (succeeded(res))
4038  return res;
4039  res = conv1dGen.generateNcwPooling();
4040  if (succeeded(res))
4041  return res;
4042 
4043  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4044  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4045  // masked/scalable) is the channel dim (i.e. the trailing dim).
4046  uint64_t vecChDimSize = ShapedType::kDynamic;
4047  bool vecChDimScalableFlag = false;
4048  if (!inputVecSizes.empty()) {
4049  // Only use the input vector size corresponding to the channel dim. Other
4050  // vector dims will be inferred from the Ops.
4051  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4052  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4053  "Not a 1D depthwise conv!");
4054  size_t chDimIdx =
4055  TypeSwitch<Operation *, size_t>(op)
4056  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4057  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4058 
4059  vecChDimSize = inputVecSizes[chDimIdx];
4060  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4061  }
4062  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4063  flatten1DDepthwiseConv);
4064 }
4065 
4066 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4067  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4068 
4069  LogicalResult matchAndRewrite(LinalgOp op,
4070  PatternRewriter &rewriter) const override {
4071  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4072  if (failed(resultOrFail))
4073  return failure();
4074  Operation *newOp = *resultOrFail;
4075  if (newOp->getNumResults() == 0) {
4076  rewriter.eraseOp(op.getOperation());
4077  return success();
4078  }
4079  assert(newOp->getNumResults() == 1 && "expected single result");
4080  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4081  return success();
4082  }
4083 };
4084 
4085 void mlir::linalg::populateConvolutionVectorizationPatterns(
4086  RewritePatternSet &patterns, PatternBenefit benefit) {
4087  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4088 }
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore, Value dest, ArrayRef< int64_t > inputVecSizesForLeadingDims, bool useInBoundsInsteadOfMasking=false)
Creates an optionally masked TransferWriteOp.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:615
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:604
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:645
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:266
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:442
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:666
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1395
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:219
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:203
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:242
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:651
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2664
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:812
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:788
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:70
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:719
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:332
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.