Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Support/LLVM.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/ADT/iterator_range.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 #include <type_traits>
47 
48 using namespace mlir;
49 using namespace mlir::linalg;
50 
51 #define DEBUG_TYPE "linalg-vectorization"
52 
53 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
54 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
55 
56 /// Try to vectorize `convOp` as a convolution.
57 static FailureOr<Operation *>
58 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
59  ArrayRef<int64_t> inputVecSizes = {},
60  ArrayRef<bool> inputVecScalableFlags = {},
61  bool flatten1DDepthwiseConv = false);
62 
63 /// Vectorize tensor::InsertSliceOp with:
64 /// * vector::TransferReadOp + vector::TransferWriteOp
65 /// The vector sizes are either:
66 /// * user-provided in `inputVectorSizes`, or
67 /// * inferred from the static dims in the input and output tensors.
68 /// Bails out if:
69 /// * vector sizes are not user-provided, and
70 /// * at least one dim is dynamic (in both the input and output tensors).
71 ///
72 /// Before:
73 /// !t_in_type = tensor<1x2x3xf32>
74 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
75 /// !v_type = vector<1x2x3xf32>
76 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
77 /// into !t_out_type
78 /// After:
79 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
80 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
81 static LogicalResult
82 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
83  ArrayRef<int64_t> inputVectorSizes,
84  SmallVectorImpl<Value> &newResults);
85 
86 /// Returns the effective Pad value for the input op, provided it's a scalar.
87 ///
88 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
89 /// this Op performs padding, retrieve the padding value provided that it's
90 /// a scalar and static/fixed for all the padded values. Returns an empty value
91 /// otherwise.
92 static Value getStaticPadVal(Operation *op);
93 
94 /// Return the unique instance of OpType in `block` if it is indeed unique.
95 /// Return null if none or more than one instance exists.
96 template <typename OpType>
97 static OpType getSingleOpOfType(Block &block) {
98  OpType res;
99  block.walk([&](OpType op) {
100  if (res) {
101  res = nullptr;
102  return WalkResult::interrupt();
103  }
104  res = op;
105  return WalkResult::advance();
106  });
107  return res;
108 }
109 
110 /// Helper function to extract the input slices after filter is unrolled along
111 /// kw.
112 static SmallVector<Value>
113 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
114  int64_t nSize, int64_t wSize, int64_t cSize,
115  int64_t kwSize, int strideW, int dilationW,
116  int64_t wSizeStep, bool isSingleChanneled) {
117  SmallVector<Value> result;
118  if (isSingleChanneled) {
119  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
120  // convolution.
121  SmallVector<int64_t> sizes = {wSizeStep};
122  SmallVector<int64_t> strides = {1};
123  for (int64_t kw = 0; kw < kwSize; ++kw) {
124  for (int64_t w = 0; w < wSize; w += wSizeStep) {
125  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
126  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
127  }
128  }
129  } else {
130  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
131  // for channeled convolution.
132  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
133  SmallVector<int64_t> strides = {1, 1, 1};
134  for (int64_t kw = 0; kw < kwSize; ++kw) {
135  for (int64_t w = 0; w < wSize; w += wSizeStep) {
136  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
137  loc, input,
138  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
139  sizes, strides));
140  }
141  }
142  }
143  return result;
144 }
145 
146 /// Helper function to extract the filter slices after filter is unrolled along
147 /// kw.
148 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
149  Location loc, Value filter,
150  int64_t kwSize) {
151  SmallVector<Value> result;
152  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
153  // non-channeled convolution] @ [kw].
154  for (int64_t kw = 0; kw < kwSize; ++kw) {
155  result.push_back(rewriter.create<vector::ExtractOp>(
156  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
157  }
158  return result;
159 }
160 
161 /// Helper function to extract the result slices after filter is unrolled along
162 /// kw.
163 static SmallVector<Value>
164 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
165  int64_t nSize, int64_t wSize, int64_t fSize,
166  int64_t wSizeStep, bool isSingleChanneled) {
167  SmallVector<Value> result;
168  if (isSingleChanneled) {
169  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
170  SmallVector<int64_t> sizes = {wSizeStep};
171  SmallVector<int64_t> strides = {1};
172  for (int64_t w = 0; w < wSize; w += wSizeStep) {
173  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
174  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
175  }
176  } else {
177  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
178  // convolution.
179  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
180  SmallVector<int64_t> strides = {1, 1, 1};
181  for (int64_t w = 0; w < wSize; w += wSizeStep) {
182  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
183  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
184  }
185  }
186  return result;
187 }
188 
189 /// Helper function to insert the computed result slices.
190 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
191  Value res, int64_t wSize, int64_t wSizeStep,
192  SmallVectorImpl<Value> &resVals,
193  bool isSingleChanneled) {
194 
195  if (isSingleChanneled) {
196  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
197  // This does not depend on kw.
198  SmallVector<int64_t> strides = {1};
199  for (int64_t w = 0; w < wSize; w += wSizeStep) {
200  res = rewriter.create<vector::InsertStridedSliceOp>(
201  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
202  }
203  } else {
204  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205  // convolution. This does not depend on kw.
206  SmallVector<int64_t> strides = {1, 1, 1};
207  for (int64_t w = 0; w < wSize; w += wSizeStep) {
208  res = rewriter.create<vector::InsertStridedSliceOp>(
209  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
210  strides);
211  }
212  }
213  return res;
214 }
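
// Illustrative sketch (assumed sizes: nSize = 1, wSize = 2, cSize = fSize = 3,
// kwSize = 2, strideW = dilationW = wSizeStep = 1) of what the extract/insert
// helpers above produce for a channeled 1-D convolution:
//
//   %lhs_k0_w0 = vector.extract_strided_slice %input
//       {offsets = [0, 0, 0], sizes = [1, 1, 3], strides = [1, 1, 1]}
//       : vector<1x3x3xf32> to vector<1x1x3xf32>   // kw = 0, w = 0
//   %lhs_k1_w0 = vector.extract_strided_slice %input
//       {offsets = [0, 1, 0], sizes = [1, 1, 3], strides = [1, 1, 1]}
//       : vector<1x3x3xf32> to vector<1x1x3xf32>   // kw = 1, w = 0
//   ...
//   %res = vector.insert_strided_slice %resVal_w0, %res
//       {offsets = [0, 0, 0], strides = [1, 1, 1]}
//       : vector<1x1x3xf32> into vector<1x2x3xf32> // write-back @ w = 0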
215 
216 /// Contains the vectorization state and related methods used across the
217 /// vectorization process of a given operation.
218 struct VectorizationState {
219  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220 
221  /// Initializes the vectorization state, including the computation of the
222  /// canonical vector shape for vectorization.
223  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224  ArrayRef<int64_t> inputVectorSizes,
225  ArrayRef<bool> inputScalableVecDims);
226 
227  /// Returns the canonical vector shape used to vectorize the iteration space.
228  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
229 
230  /// Returns the vector dimensions that are scalable in the canonical vector
231  /// shape.
232  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
233 
234  /// Returns a vector type of the provided `elementType` with the canonical
235  /// vector shape and the corresponding fixed/scalable dimensions bit. If
236  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
237  /// accordingly.
238  VectorType getCanonicalVecType(
239  Type elementType,
240  std::optional<AffineMap> dimPermutation = std::nullopt) const {
241  SmallVector<int64_t> vectorShape;
242  SmallVector<bool> scalableDims;
243  if (dimPermutation.has_value()) {
244  vectorShape =
245  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
246  scalableDims =
247  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
248  } else {
249  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
250  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
251  }
252 
253  return VectorType::get(vectorShape, elementType, scalableDims);
254  }
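
  // Illustrative example (assumed state): with canonicalVecShape = (4, [8], 16),
  // where dim 1 is scalable, getCanonicalVecType(f32) yields
  // vector<4x[8]x16xf32>; with dimPermutation = (d0, d1, d2) -> (d2, d0) it
  // yields vector<16x4xf32>.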
255 
256  /// Masks an operation with the canonical vector mask if the operation needs
257  /// masking. Returns the masked operation or the original operation if masking
258  /// is not needed. If provided, the canonical mask for this operation is
259  /// permuted using `maybeIndexingMap`.
260  Operation *
261  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
262  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
263 
264 private:
265  /// Initializes the iteration space static sizes using the Linalg op
266  /// information. This may become more complicated in the future.
267  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
268  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
269  }
270 
271  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
272  /// all the static and dynamic dimensions of the iteration space to be
273  /// vectorized and store them in `iterSpaceValueSizes`.
274  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
275  LinalgOp linalgOp);
276 
277  /// Create or retrieve an existing mask value to mask `opToMask` in the
278  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
279  /// permuted using that permutation map. If a new mask is created, it will be
280  /// cached for future users.
281  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
282  LinalgOp linalgOp,
283  std::optional<AffineMap> maybeMaskingMap);
284 
285  /// Check whether this permutation map can be used for masking. At the
286  /// moment we only make sure that there are no broadcast dimensions, but this
287  /// might change if indexing maps evolve.
288  bool isValidMaskingMap(AffineMap maskingMap) {
289  return maskingMap.getBroadcastDims().size() == 0;
290  }
291 
292  /// Turn the input indexing map into a valid masking map.
293  ///
294  /// The input indexing map may contain "zero" results, e.g.:
295  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
296  /// Applying such maps to canonical vector shapes like this one:
297  /// (1, 16, 16, 4)
298  /// would yield an invalid vector shape like this:
299  /// (16, 16, 1, 0)
300  /// Instead, drop the broadcasting dims that make no sense for masking perm.
301  /// maps:
302  /// (d0, d1, d2, d3) -> (d2, d1, d0)
303  /// This way, the corresponding vector/mask type will be:
304  /// vector<16x16x1xty>
305  /// rather than this invalid Vector type:
306  /// vector<16x16x1x0xty>
307  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
308  return indexingMap.dropZeroResults();
309  }
310 
311  // Holds the compile-time static sizes of the iteration space to vectorize.
312  // Dynamic dimensions are represented using ShapedType::kDynamic.
313  SmallVector<int64_t> iterSpaceStaticSizes;
314 
315  /// Holds the value sizes of the iteration space to vectorize. Static
316  /// dimensions are represented by 'arith.constant' and dynamic
317  /// dimensions by 'tensor/memref.dim'.
318  SmallVector<Value> iterSpaceValueSizes;
319 
320  /// Holds the canonical vector shape used to vectorize the iteration space.
321  SmallVector<int64_t> canonicalVecShape;
322 
323  /// Holds the vector dimensions that are scalable in the canonical vector
324  /// shape.
325  SmallVector<bool> scalableVecDims;
326 
327  /// Holds the active masks for permutations of the canonical vector iteration
328  /// space.
329  DenseMap<AffineMap, Value> activeMaskCache;
330 
331  /// Global vectorization guard for the incoming rewriter. It's initialized
332  /// when the vectorization state is initialized.
333  OpBuilder::InsertionGuard rewriterGuard;
334 };
335 
336 LogicalResult
337 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
338  LinalgOp linalgOp) {
339  // TODO: Support 0-d vectors.
340  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
341  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
342  // Create constant index op for static dimensions.
343  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
344  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
345  continue;
346  }
347 
348  // Find an operand defined on this dimension of the iteration space to
349  // extract the runtime dimension size.
350  Value operand;
351  unsigned operandDimPos;
352  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
353  operandDimPos)))
354  return failure();
355 
356  Value dynamicDim = linalgOp.hasPureTensorSemantics()
357  ? (Value)rewriter.create<tensor::DimOp>(
358  linalgOp.getLoc(), operand, operandDimPos)
359  : (Value)rewriter.create<memref::DimOp>(
360  linalgOp.getLoc(), operand, operandDimPos);
361  iterSpaceValueSizes.push_back(dynamicDim);
362  }
363 
364  return success();
365 }
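
// Illustrative sketch (assumed iteration space (8, ?) with tensor semantics) of
// the size values materialized above:
//
//   %c8 = arith.constant 8 : index                      // static dim 0
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %operand, %c1 : tensor<8x?xf32>    // dynamic dim 1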
366 
367 /// Initializes the vectorization state, including the computation of the
368 /// canonical vector shape for vectorization.
369 // TODO: Move this to the constructor when we can remove the failure cases.
370 LogicalResult
371 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
372  ArrayRef<int64_t> inputVectorSizes,
373  ArrayRef<bool> inputScalableVecDims) {
374  // Initialize the insertion point.
375  rewriter.setInsertionPoint(linalgOp);
376 
377  if (!inputVectorSizes.empty()) {
378  // Get the canonical vector shape from the input vector sizes provided. This
379  // path should be taken to vectorize code with dynamic shapes and when using
380  // vector sizes greater than the iteration space sizes.
381  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
382  scalableVecDims.append(inputScalableVecDims.begin(),
383  inputScalableVecDims.end());
384  } else {
385  // Compute the canonical vector shape from the operation shape. If there are
386  // dynamic shapes, the operation won't be vectorized. We assume all the
387  // vector dimensions are fixed.
388  canonicalVecShape = linalgOp.getStaticLoopRanges();
389  scalableVecDims.append(linalgOp.getNumLoops(), false);
390  }
391 
392  LDBG("Canonical vector shape: ");
393  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
394  LLVM_DEBUG(llvm::dbgs() << "\n");
395  LDBG("Scalable vector dims: ");
396  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
397  LLVM_DEBUG(llvm::dbgs() << "\n");
398 
399  if (ShapedType::isDynamicShape(canonicalVecShape))
400  return failure();
401 
402  // Initialize iteration space static sizes.
403  initIterSpaceStaticSizes(linalgOp);
404 
405  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
406  // all the static and dynamic dimensions of the iteration space, needed to
407  // compute a mask during vectorization.
408  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
409  return failure();
410 
411  return success();
412 }
413 
414 /// Create or retrieve an existing mask value to mask `opToMask` in the
415 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
416 /// using that permutation map. If a new mask is created, it will be cached for
417 /// future users.
418 Value VectorizationState::getOrCreateMaskFor(
419  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
420  std::optional<AffineMap> maybeMaskingMap) {
421 
422  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
423  "Ill-formed masking map.");
424 
425  // No mask is needed if the operation is not maskable.
426  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
427  if (!maskableOp)
428  return Value();
429 
430  assert(!maskableOp.isMasked() &&
431  "Masking an operation that is already masked");
432 
433  // If no masking map was provided, use an identity map with the loop dims.
434  assert((!maybeMaskingMap || *maybeMaskingMap) &&
435  "Unexpected null mask permutation map");
436  AffineMap maskingMap =
437  maybeMaskingMap ? *maybeMaskingMap
438  : AffineMap::getMultiDimIdentityMap(
439  linalgOp.getNumLoops(), rewriter.getContext());
440 
441  LDBG("Masking map: " << maskingMap << "\n");
442 
443  // Return the active mask for the masking map of this operation if it was
444  // already created.
445  auto activeMaskIt = activeMaskCache.find(maskingMap);
446  if (activeMaskIt != activeMaskCache.end()) {
447  Value mask = activeMaskIt->second;
448  LDBG("Reusing mask: " << mask << "\n");
449  return mask;
450  }
451 
452  // Compute permuted projection of the iteration space to be masked and the
453  // corresponding mask shape. If the resulting iteration space dimensions are
454  // static and identical to the mask shape, masking is not needed for this
455  // operation.
456  // TODO: Improve this check. Only projected permutation indexing maps are
457  // supported.
458  SmallVector<int64_t> permutedStaticSizes =
459  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
460  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
461  auto maskShape = maskType.getShape();
462 
463  LDBG("Mask shape: ");
464  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
465  LLVM_DEBUG(llvm::dbgs() << "\n");
466 
467  if (permutedStaticSizes == maskShape) {
468  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
469  activeMaskCache[maskingMap] = Value();
470  return Value();
471  }
472 
473  // Permute the iteration space value sizes to compute the mask upper bounds.
474  SmallVector<Value> upperBounds =
475  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
476  assert(!maskShape.empty() && !upperBounds.empty() &&
477  "Masked 0-d vectors are not supported yet");
478 
479  // Create the mask based on the dimension values.
480  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
481  maskType, upperBounds);
482  LDBG("Creating new mask: " << mask << "\n");
483  activeMaskCache[maskingMap] = mask;
484  return mask;
485 }
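
// Illustrative sketch (assumed canonical vector shape (4, 8) with dynamic upper
// bounds %d0, %d1) of the mask created above:
//
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>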
486 
487 Operation *
488 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
489  LinalgOp linalgOp,
490  std::optional<AffineMap> maybeIndexingMap) {
491  LDBG("Trying to mask: " << *opToMask << "\n");
492 
493  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
494  if (maybeIndexingMap)
495  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
496 
497  // Create or retrieve mask for this operation.
498  Value mask =
499  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
500 
501  if (!mask) {
502  LDBG("No mask required\n");
503  return opToMask;
504  }
505 
506  // Wrap the operation with a new `vector.mask` and update D-U chain.
507  assert(opToMask && "Expected a valid operation to mask");
508  auto maskOp = cast<vector::MaskOp>(
509  mlir::vector::maskOperation(rewriter, opToMask, mask));
510  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
511 
512  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
513  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
514  maskOpTerminator);
515 
516  LDBG("Masked operation: " << *maskOp << "\n");
517  return maskOp;
518 }
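
// Illustrative sketch (assumed operands) of the masking performed above for a
// maskable transfer_read:
//
//   %read = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>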
519 
520 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
521 /// projectedPermutation, compress the unused dimensions to serve as a
522 /// permutation_map for a vector transfer operation.
523 /// For example, given a linalg op such as:
524 ///
525 /// ```
526 /// %0 = linalg.generic {
527 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
528 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
529 /// }
530 /// ins(%0 : tensor<2x3x4xf32>)
531 /// outs(%1 : tensor<5x6xf32>)
532 /// ```
533 ///
534 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
535 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
536 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
537 static AffineMap reindexIndexingMap(AffineMap map) {
538  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
539  "expected projected permutation");
540  auto res = compressUnusedDims(map);
541  assert(res.getNumDims() ==
542  (res.getNumResults() - res.getNumOfZeroResults()) &&
543  "expected reindexed map with same number of dims and results");
544  return res;
545 }
546 
547 /// Helper enum to represent conv1d input traversal order.
548 enum class Conv1DOpOrder {
549  W, // Corresponds to non-channeled 1D convolution operation.
550  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
551  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552 };
553 
554 /// Helper data structure to represent the result of vectorization for a single
555 /// operation. In certain specific cases, like terminators, we do not want to
556 /// propagate.
557 enum VectorizationHookStatus {
558  /// Op failed to vectorize.
559  Failure = 0,
560  /// Op vectorized and custom function took care of replacement logic.
561  NoReplace,
562  /// Op vectorized into a new Op whose results will replace original Op's
563  /// results.
564  NewOp
565  // TODO: support values if Op vectorized to Many-Ops whose results we need to
566  // aggregate for replacement.
567 };
568 /// VectorizationHookResult contains the vectorized op returned from a
569 /// CustomVectorizationHook. This is an internal implementation detail of
570 /// linalg vectorization, not to be confused with VectorizationResult.
571 struct VectorizationHookResult {
572  /// Return status from vectorizing the current op.
573  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
574  /// New vectorized operation to replace the current op.
575  /// Replacement behavior is specified by `status`.
576  Operation *newOp = nullptr;
577 };
578 
579 std::optional<vector::CombiningKind>
580 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
581  using ::mlir::vector::CombiningKind;
582 
583  if (!combinerOp)
584  return std::nullopt;
585  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
586  .Case<arith::AddIOp, arith::AddFOp>(
587  [&](auto op) { return CombiningKind::ADD; })
588  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
589  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
590  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
591  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
592  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
593  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
594  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
595  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
596  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
597  .Case<arith::MulIOp, arith::MulFOp>(
598  [&](auto op) { return CombiningKind::MUL; })
599  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
600  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
601  .Default([&](auto op) { return std::nullopt; });
602 }
603 
604 /// Check whether `outputOperand` is a reduction with a single combiner
605 /// operation. Return the combiner operation of the reduction. Return
606 /// nullptr otherwise. Multiple reduction operations would impose an
607 /// ordering between reduction dimensions and is currently unsupported in
608 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
609 /// max(min(X))
610 // TODO: use in LinalgOp verification, there is a circular dependency atm.
611 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
612  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
613  unsigned outputPos =
614  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
615  // Only single combiner operations are supported for now.
616  SmallVector<Operation *, 4> combinerOps;
617  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
618  combinerOps.size() != 1)
619  return nullptr;
620 
621  // Return the combiner operation.
622  return combinerOps[0];
623 }
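
// Illustrative example (assumed op): in the generic below, the combiner matched
// for the single init operand is the arith.addf:
//
//   linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//       ins(%in : tensor<4x8xf32>) outs(%init : tensor<4xf32>) {
//     ^bb0(%a: f32, %acc: f32):
//       %sum = arith.addf %a, %acc : f32
//       linalg.yield %sum : f32
//   }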
624 
625 /// Broadcast `value` to a vector of `shape` if possible. Return value
626 /// otherwise.
627 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
628  auto dstVecType = dyn_cast<VectorType>(dstType);
629  // If no shape to broadcast to, just return `value`.
630  if (dstVecType.getRank() == 0)
631  return value;
632  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
633  vector::BroadcastableToResult::Success)
634  return value;
635  Location loc = b.getInsertionPoint()->getLoc();
636  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
637 }
638 
639 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
640 /// assumes that `reductionOp` has two operands and one of them is the reduction
641 /// initial value.
642 // Note: this is a true builder that notifies the OpBuilder listener.
643 // TODO: Consider moving as a static helper on the ReduceOp.
644 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
645  Value valueToReduce, Value acc,
646  ArrayRef<bool> dimsToMask) {
647  auto maybeKind = getCombinerOpKind(reduceOp);
648  assert(maybeKind && "Failed precondition: could not get reduction kind");
649  return b.create<vector::MultiDimReductionOp>(
650  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
651 }
652 
653 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
654  return llvm::to_vector(
655  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
656 }
657 
658 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
659 /// reduction iterator.
660 static bool hasReductionIterator(LinalgOp &op) {
661  return isa<linalg::ReduceOp>(op) ||
662  (isa<linalg::GenericOp>(op) &&
663  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
664 }
665 
666 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
667 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
668 /// currently being vectorized. If `dest` has null rank, build a memref.store.
669 /// Return the produced value or null if no value is produced.
670 // Note: this is a true builder that notifies the OpBuilder listener.
671 // TODO: Consider moving as a static helper on the ReduceOp.
672 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
673  OpOperand *outputOperand,
674  VectorizationState &state) {
675  Location loc = value.getLoc();
676  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
677  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
678 
679  // Compute the vector type of the value to store. This type should be an
680  // identity or projection of the canonical vector type without any permutation
681  // applied, given that any permutation in a transfer write happens as part of
682  // the write itself.
683  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
684  opOperandMap.getContext(), opOperandMap.getNumInputs(),
685  [&](AffineDimExpr dimExpr) -> bool {
686  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
687  });
688  auto vectorType = state.getCanonicalVecType(
689  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
690 
691  Operation *write;
692  if (vectorType.getRank() > 0) {
693  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
694  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
695  rewriter.create<arith::ConstantIndexOp>(loc, 0));
696  value = broadcastIfNeeded(rewriter, value, vectorType);
697  assert(value.getType() == vectorType && "Incorrect type");
698  write = rewriter.create<vector::TransferWriteOp>(
699  loc, value, outputOperand->get(), indices, writeMap);
700  } else {
701  // 0-d case is still special: do not invert the reindexing writeMap.
702  if (!isa<VectorType>(value.getType()))
703  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
704  assert(value.getType() == vectorType && "Incorrect type");
705  write = rewriter.create<vector::TransferWriteOp>(
706  loc, value, outputOperand->get(), ValueRange{});
707  }
708 
709  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
710 
711  // If masked, set in-bounds to true. Masking guarantees that the access will
712  // be in-bounds.
713  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
714  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
715  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
716  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
717  }
718 
719  LDBG("vectorized op: " << *write << "\n");
720  if (!write->getResults().empty())
721  return write->getResult(0);
722  return Value();
723 }
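
// Illustrative sketch (assumed types) of the write built above for a ranked
// output, with in_bounds set to true once the write is masked:
//
//   %w = vector.transfer_write %vec, %init[%c0, %c0]
//       {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x8xf32>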
724 
725 // Custom vectorization precondition function type. This is intended to be used
726 // with CustomVectorizationHook. Returns success if the corresponding custom
727 // hook can vectorize the op.
728 using CustomVectorizationPrecondition =
729  std::function<LogicalResult(Operation *, bool)>;
730 
731 // Custom vectorization function type. Produce a vector form of Operation*
732 // assuming all its vectorized operands are already in the IRMapping.
733 // Return nullptr if the Operation cannot be vectorized.
734 using CustomVectorizationHook =
735  std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
736 
737 /// Helper function to vectorize the terminator of a `linalgOp`. New result
738 /// vector values are appended to `newResults`. Return
739 /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
740 /// that it should not try to map produced operations and instead return the
741 /// results using the `newResults` vector making them available to the
742 /// vectorization algorithm for RAUW. This function is meant to be used as a
743 /// CustomVectorizationHook.
744 static VectorizationHookResult
745 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
746  const IRMapping &bvm, VectorizationState &state,
747  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
748  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
749  if (!yieldOp)
750  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
751  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
752  // TODO: Scan for an opportunity for reuse.
753  // TODO: use a map.
754  Value vectorValue = bvm.lookup(output.value());
755  Value newResult =
756  buildVectorWrite(rewriter, vectorValue,
757  linalgOp.getDpsInitOperand(output.index()), state);
758  if (newResult)
759  newResults.push_back(newResult);
760  }
761 
762  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
763 }
764 
765 /// Helper function to vectorize the index operations of a `linalgOp`. Return
766 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
767 /// should map the produced operations. This function is meant to be used as a
768 /// CustomVectorizationHook.
769 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
770  VectorizationState &state,
771  Operation *op,
772  LinalgOp linalgOp) {
773  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
774  if (!indexOp)
775  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
776  auto loc = indexOp.getLoc();
777  // Compute the static loop sizes of the index op.
778  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
779  auto dim = indexOp.getDim();
780  // Compute a one-dimensional index vector for the index op dimension.
781  auto indexVectorType =
782  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
783  state.getScalableVecDims()[dim]);
784  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
785  // Return the one-dimensional index vector if it lives in the trailing
786  // dimension of the iteration space since the vectorization algorithm in this
787  // case can handle the broadcast.
788  if (dim == targetShape.size() - 1)
789  return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
790  // Otherwise permute the targetShape to move the index dimension last,
791  // broadcast the one-dimensional index vector to the permuted shape, and
792  // finally transpose the broadcasted index vector to undo the permutation.
793  auto permPattern =
794  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
795  std::swap(permPattern[dim], permPattern.back());
796  auto permMap =
797  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
798 
799  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
800  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
801  indexSteps);
802  SmallVector<int64_t> transposition =
803  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
804  std::swap(transposition.back(), transposition[dim]);
805  auto transposeOp =
806  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
807  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
808 }
809 
810 /// Helper function to check if the tensor.extract can be vectorized by the
811 /// custom hook vectorizeTensorExtract.
812 static LogicalResult
813 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
814  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
815  if (!extractOp)
816  return failure();
817 
818  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
819  return failure();
820 
821  // Check the index type, but only for non 0-d tensors (for which we do need
822  // access indices).
823  if (not extractOp.getIndices().empty()) {
824  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
825  return failure();
826  }
827 
828  if (!llvm::all_of(extractOp->getResultTypes(),
829  VectorType::isValidElementType)) {
830  return failure();
831  }
832 
833  return success();
834 }
835 
836 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
837 /// generated from `tensor.extract`. The offset is calculated as follows
838 /// (example using scalar values):
839 ///
840 /// offset = extractOp.indices[0]
841 /// for (i = 1; i < numIndices; i++)
842 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
843 ///
844 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
845 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
846 static Value calculateGatherOffset(RewriterBase &rewriter,
847  VectorizationState &state,
848  tensor::ExtractOp extractOp,
849  const IRMapping &bvm) {
850  // The vector of indices for GatherOp should be shaped as the output vector.
851  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
852  auto loc = extractOp.getLoc();
853 
854  Value offset = broadcastIfNeeded(
855  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
856 
857  const size_t numIndices = extractOp.getIndices().size();
858  for (size_t i = 1; i < numIndices; i++) {
859  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
860 
861  auto dimSize = broadcastIfNeeded(
862  rewriter,
863  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
864  indexVecType);
865 
866  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
867 
868  auto extractOpIndex = broadcastIfNeeded(
869  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
870 
871  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
872  }
873 
874  return offset;
875 }
876 
876 
877 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
878 
879 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
880 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
881 /// represents a contiguous load operation.
882 ///
883 /// Note that when calling this hook, it is assumed that the output vector is
884 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
885 /// labelled as a gather load before entering this method.
886 ///
887 /// Following on from the above, it is assumed that:
888 /// * for statically shaped loops, when no masks are used, only one dim is !=
889 /// 1 (that's what the shape of the output vector is based on).
890 /// * for dynamically shaped loops, there might be more non-unit dims
891 /// as the output vector type is user-specified.
892 ///
893 /// TODO: Statically shaped loops + vector masking
894 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
895  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
896  assert(
897  (linalgOp.hasDynamicShape() ||
898  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
899  "For statically shaped Linalg Ops, only one "
900  "non-unit loop dim is expected");
901  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
902 
903  size_t idx = loopRanges.size() - 1;
904  for (; idx != 0; idx--)
905  if (loopRanges[idx] != 1)
906  break;
907 
908  return idx;
909 }
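
// Illustrative example (assumed static loop ranges): for loopRanges = (1, 32, 1)
// the trailing non-unit loop dim index returned above is 1.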
910 
911 /// Checks whether `val` can be used for calculating a loop invariant index.
912 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
913  VectorType resType) {
914 
915  assert(((llvm::count_if(resType.getShape(),
916  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
917  "n-D vectors are not yet supported");
918 
919  // Blocks outside _this_ linalg.generic are effectively loop invariant.
920  // However, analysing block arguments for _this_ linalg.generic Op is a bit
921  // tricky. Just bail out in the latter case.
922  // TODO: We could try analysing the corresponding affine map here.
923  auto *block = linalgOp.getBlock();
924  if (isa<BlockArgument>(val))
925  return !llvm::is_contained(block->getArguments(), val);
926 
927  Operation *defOp = val.getDefiningOp();
928  assert(defOp && "This is neither a block argument nor an operation result");
929 
930  // IndexOp is loop invariant as long as its result remains constant across
931  // iterations. Note that for dynamic shapes, the corresponding dim will also
932  // be conservatively treated as != 1.
933  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
934  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
935  }
936 
937  auto *ancestor = block->findAncestorOpInBlock(*defOp);
938 
939  // Values defined outside `linalgOp` are loop invariant.
940  if (!ancestor)
941  return true;
942 
943  // Values defined inside `linalgOp`, which are constant, are loop invariant.
944  if (isa<arith::ConstantOp>(ancestor))
945  return true;
946 
947  bool result = true;
948  for (auto op : ancestor->getOperands())
949  result &= isLoopInvariantIdx(linalgOp, op, resType);
950 
951  return result;
952 }
953 
954 /// Check whether `val` could be used for calculating the trailing index for a
955 /// contiguous load operation.
956 ///
957 /// There are currently 3 types of values that are allowed here:
958 /// 1. loop-invariant values,
959 /// 2. values that increment by 1 with every loop iteration,
960 /// 3. results of basic arithmetic operations (linear and continuous)
961 /// involving 1., 2. and 3.
962 /// This method returns True if indeed only such values are used in calculating
963 /// `val.`
964 ///
965 /// Additionally, the trailing index for a contiguous load operation should
966 /// increment by 1 with every loop iteration, i.e. be based on:
967 /// * `linalg.index <dim>` ,
968 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
969 /// `linalg.index <dim>` increments by 1 with every loop iteration).
970 /// `foundIndexOp` is updated to `true` when such Op is found.
971 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
972  bool &foundIndexOp, VectorType resType) {
973 
974  assert(((llvm::count_if(resType.getShape(),
975  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
976  "n-D vectors are not yet supported");
977 
978  // Blocks outside _this_ linalg.generic are effectively loop invariant.
979  // However, analysing block arguments for _this_ linalg.generic Op is a bit
980  // tricky. Just bail out in the latter case.
981  // TODO: We could try analysing the corresponding affine map here.
982  auto *block = linalgOp.getBlock();
983  if (isa<BlockArgument>(val))
984  return !llvm::is_contained(block->getArguments(), val);
985 
986  Operation *defOp = val.getDefiningOp();
987  assert(defOp && "This is neither a block argument nor an operation result");
988 
989  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
990  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
991 
992  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
993  return true;
994  }
995 
996  auto *ancestor = block->findAncestorOpInBlock(*defOp);
997 
998  if (!ancestor)
999  return false;
1000 
1001  // Conservatively reject Ops that could lead to indices with stride other
1002  // than 1.
1003  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1004  return false;
1005 
1006  bool result = false;
1007  for (auto op : ancestor->getOperands())
1008  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1009 
1010  return result;
1011 }
1012 
1013 /// Infer the memory access pattern for the input ExtractOp
1014 ///
1015 /// Based on the ExtractOp result shape and the access indices, decides whether
1016 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1017 /// or a gather load. When analysing the ExtractOp indices (to identify
1018 /// contiguous laods), this method looks for "loop" invariant indices (e.g.
1019 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1020 /// Op).
1021 ///
1022 /// Note that it is always safe to use gather load operations for contiguous
1023 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1024 /// that `extractOp` is a gather load.
1025 static VectorMemoryAccessKind
1026 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1027  LinalgOp &linalgOp, VectorType resType) {
1028 
1029  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1030 
1031  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1032  if (inputShape.getShape().empty())
1033  return VectorMemoryAccessKind::ScalarBroadcast;
1034 
1035  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1036  // otherwise.
1037  bool isOutput1DVector =
1038  (llvm::count_if(resType.getShape(),
1039  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1040  // 1. Assume that it's a gather load when reading non-1D vector.
1041  if (!isOutput1DVector)
1042  return VectorMemoryAccessKind::Gather;
1043 
1044  bool leadingIdxsLoopInvariant = true;
1045 
1046  // 2. Analyze the leading indices of `extractOp`.
1047  // Look at the way each index is calculated and decide whether it is suitable
1048  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1049  // gather load.
1050  auto indices = extractOp.getIndices();
1051  auto leadIndices = indices.drop_back(1);
1052 
1053  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1054  if (inputShape.getShape()[i] == 1)
1055  continue;
1056 
1057  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1058  }
1059 
1060  if (!leadingIdxsLoopInvariant) {
1061  LDBG("Found gather load: " << extractOp);
1062  return VectorMemoryAccessKind::Gather;
1063  }
1064 
1065  // 3. Analyze the trailing index for `extractOp`.
1066  // At this point we know that the leading indices are loop invariant. This
1067  // means that it is potentially a scalar or a contiguous load. We can decide
1068  // based on the trailing idx.
1069  auto extractOpTrailingIdx = indices.back();
1070 
1071  // 3a. Scalar broadcast load
1072  // If the trailing index is loop invariant then this is a scalar load.
1073  if (leadingIdxsLoopInvariant &&
1074  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1075  LDBG("Found scalar broadcast load: " << extractOp);
1076 
1077  return VectorMemoryAccessKind::ScalarBroadcast;
1078  }
1079 
1080  // 3b. Contiguous loads
1081  // The trailing `extractOp` index should increment with every loop iteration.
1082  // This effectively means that it must be based on the trailing loop index.
1083  // This is what the following bool captures.
1084  bool foundIndexOp = false;
1085  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1086  foundIndexOp, resType);
1087  // TODO: Support generating contiguous loads for column vectors - that will
1088  // require adding a permutation map to transfer_read Ops.
1089  bool isRowVector = resType.getShape().back() != 1;
1090  isContiguousLoad &= (foundIndexOp && isRowVector);
1091 
1092  if (isContiguousLoad) {
1093  LDBG("Found contigous load: " << extractOp);
1095  }
1096 
1097  // 4. Fallback case - gather load.
1098  LDBG("Found gather load: " << extractOp);
1099  return VectorMemoryAccessKind::Gather;
1100 }
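
// Illustrative examples (assumed 1-D output vector, %src : tensor<3x8xf32>):
//   * tensor.extract %src[%c0, %i] with %i = linalg.index <trailing non-unit dim>
//       -> contiguous load (trailing index increments by 1 per iteration).
//   * tensor.extract %src[%c0, %c0] with loop-invariant indices only
//       -> scalar broadcast load.
//   * tensor.extract %src[%i, %j] with a data-dependent %j
//       -> gather load (fallback).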
1101 
1102 /// Helper function to vectorize the tensor.extract operations. Returns
1103 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1104 /// should map the produced operations. This function is meant to be used as a
1105 /// CustomVectorizationHook.
1106 static VectorizationHookResult
1107 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1108  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1109  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1110  if (!extractOp)
1111  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1112  auto loc = extractOp.getLoc();
1113 
1114  // Compute the static loop sizes of the extract op.
1115  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1116  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1117  loc,
1118  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1119  /*value=*/true));
1120  auto passThruConstantOp =
1121  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1122 
1123  // Base indices are currently set to 0. We will need to re-visit if more
1124  // generic scenarios are to be supported.
1125  SmallVector<Value> baseIndices(
1126  extractOp.getIndices().size(),
1127  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1128 
1129  VectorMemoryAccessKind memAccessKind =
1130  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1131 
1132  // 1. Handle gather access
1133  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1134  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1135 
1136  // Generate the gather load
1137  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1138  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1139  maskConstantOp, passThruConstantOp);
1140  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1141 
1142  LDBG("Vectorised as gather load: " << extractOp << "\n");
1143  return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1144  }
1145 
1146  // 2. Handle:
1147  // a. scalar loads + broadcast,
1148  // b. contiguous loads.
1149  // Both cases use vector.transfer_read.
1150 
1151  // Collect indices for `vector.transfer_read`. At this point, the indices will
1152  // either be scalars or would have been broadcast to vectors matching the
1153  // result type. For indices that are vectors, there are two options:
1154  // * for non-trailing indices, all elements are identical (contiguous
1155  // loads are identified by looking for non-trailing indices that are
1156  // invariant with respect to the corresponding linalg.generic), or
1157  // * for trailing indices, the index vector will contain values with stride
1158  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1159  // needed.
1160  // This means that
1161  // * for scalar indices - just re-use it,
1162  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1163  // (0th) element and use that.
1164  SmallVector<Value> transferReadIdxs;
1165  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1166  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1167  if (idx.getType().isIndex()) {
1168  transferReadIdxs.push_back(idx);
1169  continue;
1170  }
1171 
1172  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1173  loc,
1174  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1175  resultType.getScalableDims().back()),
1176  idx);
1177  transferReadIdxs.push_back(
1178  rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1179  }
1180 
1181  // `tensor.extract_element` is always in-bounds, hence the following holds.
1182  auto dstRank = resultType.getRank();
1183  auto srcRank = extractOp.getTensor().getType().getRank();
1184  SmallVector<bool> inBounds(dstRank, true);
1185 
1186  // 2a. Handle scalar broadcast access.
1187  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1188  MLIRContext *ctx = rewriter.getContext();
1189  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1190  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1191 
1192  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1193  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1194  /*padding=*/std::nullopt, permutationMap, inBounds);
1195 
1196  // Mask this broadcasting xfer_read here rather than relying on the generic
1197  // path (the generic path assumes identity masking map, which wouldn't be
1198  // valid here).
1199  SmallVector<int64_t> readMaskShape = {1};
1200  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1201  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
1202  loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1203  auto *maskedReadOp =
1204  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1205 
1206  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1207  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1208  maskedReadOp};
1209  }
1210 
1211  // 2b. Handle contiguous access.
1212  auto permutationMap = AffineMap::getMinorIdentityMap(
1213  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1214 
1215  int32_t rankDiff = dstRank - srcRank;
1216  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1217  // dims so that the ranks match. This is done by extending the map with 0s.
1218  // For example, for dstRank = 3, srcRank = 2, the following map created
1219  // above:
1220  // (d0, d1) --> (d0, d1)
1221  // is extended as:
1222  // (d0, d1) --> (0, d0, d1)
1223  while (rankDiff > 0) {
1224  permutationMap = permutationMap.insertResult(
1225  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1226  rankDiff--;
1227  }
1228 
1229  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1230  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1231  /*padding=*/std::nullopt, permutationMap, inBounds);
1232 
1233  LDBG("Vectorised as contiguous load: " << extractOp);
1234  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1235  transferReadOp};
1236 }
1237 
1238 /// Emit reduction operations if the shape of the value to reduce is different
1239 /// from the result shape.
1240 // Note: this is a true builder that notifies the OpBuilder listener.
1241 // TODO: Consider moving as a static helper on the ReduceOp.
1242 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1243  Value reduceValue, Value initialValue,
1244  const IRMapping &bvm) {
1245  Value reduceVec = bvm.lookup(reduceValue);
1246  Value outputVec = bvm.lookup(initialValue);
1247  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1248  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1249  // Reduce only if needed as the value may already have been reduced for
1250  // contraction vectorization.
1251  if (!reduceType ||
1252  (outputType && reduceType.getShape() == outputType.getShape()))
1253  return nullptr;
1254  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1255  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1256 }
1257 
1258 /// Generic vectorization for a single operation `op`, given already vectorized
1259 /// operands carried by `bvm`. Vectorization occurs as follows:
1260 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1261 /// result on success.
1262 /// 2. Clone any constant in the current scope without vectorization: each
1263 /// consumer of the constant will later determine the shape to which the
1264 /// constant needs to be broadcast to.
1265 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1266 /// of the `customVectorizationHooks` to cover such cases.
1267 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1268 /// operand of maximal rank. Other operands have smaller rank and are
1269 /// broadcast accordingly. It is assumed this broadcast is always legal,
1270 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1271 ///
1272 /// This function assumes all operands of `op` have been vectorized and are in
1273 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1274 /// a topologically-sorted list of ops.
1275 /// This function does not update `bvm` but returns a VectorizationHookStatus
1276 /// that instructs the caller what `bvm` update needs to occur.
1277 static VectorizationHookResult
1278 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1279  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1280  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1281  LDBG("vectorize op " << *op << "\n");
1282 
1283  // 1. Try to apply any CustomVectorizationHook.
1284  if (!customVectorizationHooks.empty()) {
1285  for (auto &customFunc : customVectorizationHooks) {
1286  VectorizationHookResult result = customFunc(op, bvm);
1287  if (result.status == VectorizationHookStatus::Failure)
1288  continue;
1289  return result;
1290  }
1291  }
1292 
1293  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1294  // Clone so that the constant is not confined to the linalgOp block.
1295  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1296  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1297  rewriter.clone(*op)};
1298 
1299  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1300  if (!OpTrait::hasElementwiseMappableTraits(op))
1301  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1302 
1303  // 4. Check if the operation is a reduction.
1304  SmallVector<std::pair<Value, Value>> reductionOperands;
1305  for (Value operand : op->getOperands()) {
1306  auto blockArg = dyn_cast<BlockArgument>(operand);
1307  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1308  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1309  continue;
1310  SmallVector<Operation *> reductionOps;
1311  Value reduceValue = matchReduction(
1312  linalgOp.getRegionOutputArgs(),
1313  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1314  if (!reduceValue)
1315  continue;
1316  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1317  }
1318  if (!reductionOperands.empty()) {
1319  assert(reductionOperands.size() == 1);
1320  Operation *reduceOp =
1321  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1322  reductionOperands[0].second, bvm);
1323  if (reduceOp)
1324  return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1325  }
1326 
1327  // 5. Generic vectorization path for ElementwiseMappable ops.
1328  // a. Get the first max ranked shape.
1329  VectorType firstMaxRankedType;
1330  for (Value operand : op->getOperands()) {
1331  auto vecOperand = bvm.lookup(operand);
1332  assert(vecOperand && "Vector operand couldn't be found");
1333 
1334  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1335  if (vecType && (!firstMaxRankedType ||
1336  firstMaxRankedType.getRank() < vecType.getRank()))
1337  firstMaxRankedType = vecType;
1338  }
1339  // b. Broadcast each op if needed.
1340  SmallVector<Value> vecOperands;
1341  for (Value scalarOperand : op->getOperands()) {
1342  Value vecOperand = bvm.lookup(scalarOperand);
1343  assert(vecOperand && "Vector operand couldn't be found");
1344 
1345  if (firstMaxRankedType) {
1346  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1347  getElementTypeOrSelf(vecOperand.getType()),
1348  firstMaxRankedType.getScalableDims());
1349  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1350  } else {
1351  vecOperands.push_back(vecOperand);
1352  }
1353  }
1354  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1355  SmallVector<Type> resultTypes;
1356  for (Type resultType : op->getResultTypes()) {
1357  resultTypes.push_back(
1358  firstMaxRankedType
1359  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1360  firstMaxRankedType.getScalableDims())
1361  : resultType);
1362  }
1363  // d. Build and return the new op.
1364  return VectorizationHookResult{
1365  VectorizationHookStatus::NewOp,
1366  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1367  resultTypes, op->getAttrs())};
1368 }
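
// Illustrative sketch (assumed shapes) of step 5 above: a scalar arith.addf in
// the linalg body whose operands were vectorized to vector<4x8xf32> is
// re-created elementwise as:
//
//   %sum = arith.addf %lhsVec, %rhsVec : vector<4x8xf32>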
1369 
1370 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1371 /// vector form. Generic vectorization proceeds as follows:
1372 /// 1. Verify the `linalgOp` has one non-empty region.
1373 /// 2. Values defined above the region are mapped to themselves and will be
1374 /// broadcasted on a per-need basis by their consumers.
1375 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1376 /// load).
1377 /// TODO: Reuse opportunities for RAR dependencies.
1378 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1379 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1380 /// iteration indices.
1381 /// 5. Iteratively call vectorizeOneOp on the region operations.
1382 ///
1383 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1384 /// performed to the maximal common vector size implied by the `linalgOp`
1385 /// iteration space. This eager broadcasting is introduced in the
1386 /// permutation_map of the vector.transfer_read operations. The eager
1387 /// broadcasting makes it trivial to determine where broadcast, transposes and
1388 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1389 /// the absence of good canonicalizations, the amount of work increases.
1390 /// This is not deemed a problem as we expect canonicalizations and foldings to
1391 /// aggressively clean up the useless work.
1392 static LogicalResult
1393 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1394  LinalgOp linalgOp,
1395  SmallVectorImpl<Value> &newResults) {
1396  LDBG("Vectorizing operation as linalg generic\n");
1397  Block *block = linalgOp.getBlock();
1398 
1399  // 2. Values defined above the region can only be broadcast for now. Make them
1400  // map to themselves.
1401  IRMapping bvm;
1402  SetVector<Value> valuesSet;
1403  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1404  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1405 
1406  if (linalgOp.getNumDpsInits() == 0)
1407  return failure();
1408 
1409  // 3. Turn all BBArgs into vector.transfer_read / load.
1410  Location loc = linalgOp.getLoc();
1411  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1412  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1413  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1414  if (linalgOp.isScalar(opOperand)) {
1415  bvm.map(bbarg, opOperand->get());
1416  continue;
1417  }
1418 
1419  // 3.a. Convert the indexing map for this input/output to a transfer read
1420  // permutation map and masking map.
1421  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1422 
1423  AffineMap readMap;
1424  VectorType readType;
1425  Type elemType = getElementTypeOrSelf(opOperand->get());
1426  if (linalgOp.isDpsInput(opOperand)) {
1427  // 3.a.i. For input reads we use the canonical vector shape.
1428  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1429  readType = state.getCanonicalVecType(elemType);
1430  } else {
1431  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1432  // reductions), the vector shape is computed by mapping the canonical
1433  // vector shape to the output domain and back to the canonical domain.
1434  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1435  readType =
1436  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1437  }
1438 
1439  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1440 
1441  Operation *read = rewriter.create<vector::TransferReadOp>(
1442  loc, readType, opOperand->get(), indices,
1443  /*padding=*/std::nullopt, readMap);
1444  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1445  Value readValue = read->getResult(0);
1446 
1447  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1448  // will be in-bounds.
1449  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1450  SmallVector<bool> inBounds(readType.getRank(), true);
1451  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1452  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1453  }
1454 
1455  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1456  // TODO: remove this.
1457  if (readType.getRank() == 0)
1458  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1459  ArrayRef<int64_t>());
1460 
1461  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1462  << "\n");
1463  bvm.map(bbarg, readValue);
1464  bvm.map(opOperand->get(), readValue);
1465  }
1466 
1467  SmallVector<CustomVectorizationHook> hooks;
1468  // 4a. Register CustomVectorizationHook for yieldOp.
1469  CustomVectorizationHook vectorizeYield =
1470  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1471  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1472  };
1473  hooks.push_back(vectorizeYield);
1474 
1475  // 4b. Register CustomVectorizationHook for indexOp.
1476  CustomVectorizationHook vectorizeIndex =
1477  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1478  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1479  };
1480  hooks.push_back(vectorizeIndex);
1481 
1482  // 4c. Register CustomVectorizationHook for extractOp.
1483  CustomVectorizationHook vectorizeExtract =
1484  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1485  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1486  };
1487  hooks.push_back(vectorizeExtract);
1488 
1489  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1490  for (Operation &op : block->getOperations()) {
1491  VectorizationHookResult result =
1492  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1493  if (result.status == VectorizationHookStatus::Failure) {
1494  LDBG("failed to vectorize: " << op << "\n");
1495  return failure();
1496  }
1497  if (result.status == VectorizationHookStatus::NewOp) {
1498  Operation *maybeMaskedOp =
1499  state.maskOperation(rewriter, result.newOp, linalgOp);
1500  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1501  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1502  }
1503  }
1504 
1505  return success();
1506 }
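// Editorial sketch (hypothetical IR, not part of the upstream source) of the
// flow above for a 1-D elementwise add written as a linalg.generic over
// tensor<8xf32> operands: each block argument becomes a transfer_read, the
// body op is vectorized 1:1, and the yield becomes a transfer_write, roughly:
//   %va = vector.transfer_read %a[%c0], %pad : tensor<8xf32>, vector<8xf32>
//   %vb = vector.transfer_read %b[%c0], %pad : tensor<8xf32>, vector<8xf32>
//   %vr = arith.addf %va, %vb : vector<8xf32>
//   %res = vector.transfer_write %vr, %init[%c0] : vector<8xf32>, tensor<8xf32>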
1507 
1508 /// Given a linalg::PackOp, return the `dest` shape before any packing
1509 /// permutations.
1510 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1511  ArrayRef<int64_t> destShape) {
1512  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1513 }
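// Editorial note (illustrative, derived from the pack example further below
// and not part of the upstream source): for a pack with inner_dims_pos =
// [2, 1], inner_tiles = [16, 2] and dest tensor<32x4x1x16x2xf32>, destShape
// is [32, 4, 1, 16, 2] and the inverse dest permutation yields the
// pre-transpose ("tiled") shape [32, 4, 2, 1, 16], i.e. the vector.shape_cast
// result type used by vectorizeAsTensorPackOp.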
1514 
1515 /// Determines whether a mask for xfer_write is trivially "all true"
1516 ///
1517 /// Given all the inputs required to generate a mask (mask sizes and shapes),
1518 /// and an xfer_write operation (write indices and the destination tensor
1519 /// shape), determines whether the corresponding mask would be trivially
1520 /// foldable (i.e., trivially "all true").
1521 ///
1522 /// Use this method to avoid generating spurious masks and relying on
1523 /// vectorization post-processing to remove them.
1524 ///
1525 /// Pre-conditions for a mask to be trivially foldable:
1526 /// * All involved shapes (mask + destination tensor) are static.
1527 /// * All write indices are constant.
1528 /// * All mask sizes are constant (including `arith.constant`).
1529 ///
1530 /// If the pre-conditions are met, the method checks for each destination
1531 /// dimension `d`:
1532 /// (1) destDimSize[rankDiff + d] <= maskShape[d]
1533 /// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1534 ///
1535 /// rankDiff = rank(dest) - rank(mask).
1536 ///
1537 /// This method takes a conservative view: it may return false even if the mask
1538 /// is technically foldable.
1539 ///
1540 /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1541 /// of the dest tensor):
1542 /// %c0 = arith.constant 0 : index
1543 /// %mask = vector.create_mask 5, 1
1544 /// vector.mask %mask {
1545 /// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1546 /// {in_bounds = [true, true]}
1547 /// : vector<5x1xi32>, tensor<5x1xi32>
1548 /// }
1549 ///
1550 /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1551 /// mask is required to avoid out-of-bounds write):
1552 /// %c0 = arith.constant 0 : index
1553 /// %mask = vector.create_mask 5, 1
1554 /// vector.mask %mask {
1555 /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1556 /// {in_bounds = [true, true]}
1557 /// : vector<8x1xi32>, tensor<5x1xi32>
1558 /// }
1559 ///
1560 /// TODO: Re-use in createReadOrMaskedRead
1561 static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1562  SmallVector<Value> &writeIdxs,
1563  ArrayRef<int64_t> destShape,
1564  ArrayRef<int64_t> maskShape) {
1565  // Masking is unavoidable in the case of dynamic tensors.
1566  if (ShapedType::isDynamicShape(destShape))
1567  return false;
1568 
1569  // Collect all constant mask sizes.
1570  SmallVector<int64_t, 4> cstMaskSizes;
1571  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1572  if (auto intSize = getConstantIntValue(dimSize)) {
1573  cstMaskSizes.push_back(*intSize);
1574  }
1575  }
1576 
1577  // If any of the mask sizes is non-constant, bail out.
1578  if (cstMaskSizes.size() != maskShape.size())
1579  return false;
1580 
1581  // Collect all constant write indices.
1582  SmallVector<int64_t, 4> cstWriteIdxs;
1583  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1584  APSInt intVal;
1585  if (matchPattern(idx, m_ConstantInt(&intVal))) {
1586  cstWriteIdxs.push_back(intVal.getSExtValue());
1587  }
1588  }
1589 
1590  // If any of the write indices is non-constant, bail out.
1591  if (cstWriteIdxs.size() != destShape.size())
1592  return false;
1593 
1594  // Go over all destination dims and check (1) and (2). Take into account that:
1595  // * The number of mask sizes will match the rank of the vector to store.
1596  // This could be lower than the rank of the destination tensor.
1597  // * Mask sizes could be larger than the corresponding mask shape (hence
1598  // `clamp`).
1599  // TODO: The 2nd item should be rejected by the verifier.
1600  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1601  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1602  if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1603  /*(2)*/ destShape[rankDiff + i] <
1604  (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1605  cstWriteIdxs[i]))
1606  return false;
1607  }
1608 
1609  return true;
1610 }
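// Editorial walk-through (derived from the two examples in the documentation
// above, not part of the upstream source): in EXAMPLE 1, maskShape and
// destShape are both [5, 1], the mask sizes are [5, 1] and the write indices
// are [0, 0], so checks (1) and (2) pass for every dim and the function
// returns true (the mask folds away). In EXAMPLE 2, maskShape = [8, 1]
// exceeds destShape = [5, 1] in dim 0, so the function returns false and the
// mask is kept.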
1611 
1612 /// Creates an optionally masked TransferWriteOp
1613 ///
1614 /// Generates the following operation:
1615 /// %res = vector.transfer_write %vecToStore into %dest
1616 ///
1617 /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1618 ///
1619 /// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1620 /// %res = vector.mask %mask {
1621 /// vector.transfer_write %vecToStore into %dest
1622 /// }
1623 ///
1624 /// The mask shape is identical to `vecToStore` (with the element type ==
1625 /// i1), and the mask values are based on the shape of the `dest` tensor.
1626 ///
1627 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1628 /// is used instead of masking:
1629 ///
1630 /// in_bounds_flags = (...)
1631 /// %res = vector.transfer_write %vecToStore into %dest
1632 /// {in_bounds = in_bounds_flags}
1633 ///
1634 ///
1635 /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1636 /// are set to 0.
1637 static Operation *
1638 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1639  Value dest, SmallVector<Value> writeIndices = {},
1640  bool useInBoundsInsteadOfMasking = false) {
1641 
1642  ShapedType destType = cast<ShapedType>(dest.getType());
1643  int64_t destRank = destType.getRank();
1644  auto destShape = destType.getShape();
1645 
1646  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1647  int64_t vecToStoreRank = vecToStoreType.getRank();
1648  auto vecToStoreShape = vecToStoreType.getShape();
1649 
1650  // Compute the in_bounds attribute
1651  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1652  if (useInBoundsInsteadOfMasking) {
1653  // Update the inBounds attribute.
1654  // FIXME: This computation is too weak - it ignores the write indices.
1655  for (unsigned i = 0; i < vecToStoreRank; i++)
1656  inBoundsVal[i] =
1657  (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1658  !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
1659  }
1660 
1661  // If missing, initialize the write indices to 0.
1662  assert((writeIndices.empty() ||
1663  writeIndices.size() == static_cast<size_t>(destRank)) &&
1664  "Invalid number of write indices!");
1665  if (writeIndices.empty()) {
1666  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1667  writeIndices.assign(destRank, zero);
1668  }
1669 
1670  // Generate the xfer_write Op
1671  Operation *write =
1672  builder.create<vector::TransferWriteOp>(loc,
1673  /*vector=*/vecToStore,
1674  /*source=*/dest,
1675  /*indices=*/writeIndices,
1676  /*inBounds=*/inBoundsVal);
1677 
1678  // If masking is disabled, exit.
1679  if (useInBoundsInsteadOfMasking)
1680  return write;
1681 
1682  // Check if masking is needed. If not, exit.
1683  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1684  return write;
1685 
1686  // Compute the mask and mask the write Op.
1687  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1688 
1689  SmallVector<OpFoldResult> destSizes =
1690  tensor::getMixedSizes(builder, loc, dest);
1691  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1692  destSizes.end());
1693 
1694  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1695  vecToStoreShape))
1696  return write;
1697 
1698  Value maskForWrite =
1699  builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1700  return mlir::vector::maskOperation(builder, write, maskForWrite);
1701 }
1702 
1703 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1704 /// padding value and (3) input vector sizes into:
1705 ///
1706 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1707 ///
1708 /// As in the following example:
1709 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1710 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1711 ///
1712 /// This pack would be vectorized to:
1713 ///
1714 /// %load = vector.mask %mask {
1715 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1716 /// {in_bounds = [true, true, true]} :
1717 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1718 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1719 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1720 /// to vector<32x4x2x1x16xf32>
1721 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1722 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1723 /// %write = vector.transfer_write %transpose,
1724 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1725 /// {in_bounds = [true, true, true, true, true]}
1726 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1727 ///
1728 /// If the (3) input vector sizes are not provided, the vector sizes are
1729 /// determined by the result tensor shape and the `in_bounds`
1730 /// attribute is used instead of masking to mark out-of-bounds accesses.
1731 ///
1732 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1733 /// outer dimensions of the output tensor. The remaining dimensions are
1734 /// computed based on, e.g., the static inner tiles.
1735 /// Supporting dynamic inner tiles will require the user to specify the
1736 /// missing vector sizes. This is left as a TODO.
1737 static LogicalResult
1738 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1739  ArrayRef<int64_t> inputVectorSizes,
1740  SmallVectorImpl<Value> &newResults) {
1741  // TODO: Introduce a parent class that will handle the insertion point update.
1742  OpBuilder::InsertionGuard g(rewriter);
1743  rewriter.setInsertionPoint(packOp);
1744 
1745  Location loc = packOp.getLoc();
1746  auto padValue = packOp.getPaddingValue();
1747  if (!padValue) {
1748  padValue = rewriter.create<arith::ConstantOp>(
1749  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1750  }
1751  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1752  LogicalResult status =
1753  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1754  .reifyResultShapes(rewriter, reifiedReturnShapes);
1755  (void)status; // prevent unused variable warning on non-assert builds.
1756  assert(succeeded(status) && "failed to reify result shapes");
1757 
1758  // If the input vector sizes are not provided, then the vector sizes are
1759  // determined by the result tensor shape. In case the vector sizes aren't
1760  // provided, we update the inBounds attribute instead of masking.
1761  bool useInBoundsInsteadOfMasking = false;
1762  if (inputVectorSizes.empty()) {
1763  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1764  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1765  useInBoundsInsteadOfMasking = true;
1766  }
1767 
1768  // Create masked TransferReadOp.
1769  SmallVector<int64_t> inputShape(inputVectorSizes);
1770  auto innerTiles = packOp.getStaticInnerTiles();
1771  auto innerDimsPos = packOp.getInnerDimsPos();
1772  auto outerDimsPerm = packOp.getOuterDimsPerm();
1773  if (!outerDimsPerm.empty())
1774  applyPermutationToVector(inputShape,
1775  invertPermutationVector(outerDimsPerm));
1776  for (auto [idx, size] : enumerate(innerTiles))
1777  inputShape[innerDimsPos[idx]] *= size;
1778  auto maskedRead = vector::createReadOrMaskedRead(
1779  rewriter, loc, packOp.getSource(), inputShape, padValue,
1780  useInBoundsInsteadOfMasking);
1781 
1782  // Create ShapeCastOp.
1783  SmallVector<int64_t> destShape(inputVectorSizes);
1784  destShape.append(innerTiles.begin(), innerTiles.end());
1785  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1786  packOp.getDestType().getElementType());
1787  auto shapeCastOp =
1788  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1789 
1790  // Create TransposeOp.
1791  auto destPermutation =
1792  invertPermutationVector(getPackInverseDestPerm(packOp));
1793  auto transposeOp = rewriter.create<vector::TransposeOp>(
1794  loc, shapeCastOp.getResult(), destPermutation);
1795 
1796  // Create TransferWriteOp.
1797  Value dest = rewriter.create<tensor::EmptyOp>(
1798  loc, reifiedReturnShapes[0],
1799  transposeOp.getResult().getType().getElementType());
1800  Operation *write =
1801  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
1802  newResults.push_back(write->getResult(0));
1803  return success();
1804 }
1805 
1806 /// Vectorize a `linalg::UnPackOp` to these 4 Ops:
1807 /// Vector::TransferReadOp - Reads a vector from the source tensor
1808 /// vector::TransposeOp - Transpose the Source tensor
1809 /// ShapeCastOp - Reshape the data based on the target.
1810 /// vector::TransferWriteOp. - Write the result vector back to the destination
1811 /// tensor.
1812 /// If the vector sizes are not provided:
1813 /// * the vector sizes are determined by the input operand and attributes,
1814 /// * update the inBounds attribute instead of masking.
1815 static LogicalResult
1816 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1817  ArrayRef<int64_t> inputVectorSizes,
1818  SmallVectorImpl<Value> &newResults) {
1819 
1820  // TODO: Introduce a parent class that will handle the insertion point update.
1821  OpBuilder::InsertionGuard g(rewriter);
1822  rewriter.setInsertionPoint(unpackOp);
1823 
1824  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1825 
1826  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1827  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1828  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1829  bool useInBoundsInsteadOfMasking = false;
1830  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1831 
1832  auto destSize = unpackOp.getDestRank();
1833 
1834  if (!inputVectorSizes.empty())
1835  assert(inputVectorSizes.size() == destSize &&
1836  "Incorrect number of input vector sizes");
1837 
1838  // vectorSizes is the shape of the vector that will be used for the final
1839  // write on the destination tensor. It is set like this: Let's say the
1840  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1841  // Thus:
1842  // 1. vectorSizes = sourceShape.take_front(N)
1843  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1844  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1845  // innerTiles attribute value.
1846  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1847  if (vectorSizes.empty()) {
1848  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1849  if (!outerDimsPerm.empty())
1850  applyPermutationToVector(vectorSizes, outerDimsPerm);
1851  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1852  vectorSizes[pos] *= innerTiles[i];
1853 
1854  useInBoundsInsteadOfMasking = true;
1855  }
1856 
1857  // readVectorSizes is the shape of the vector used for the read (and for the
1858  // mask, if any). It is set like this: Let's say the vectorSizes (VS) array
1859  // has size 'N', the sourceShape (SS) has size 'M' where M >= N, and the
1860  // InnerTileSizes (IT) have size M-N.
1861  // Thus:
1862  // - initially: readVectorSizes = vectorSizes
1863  // - Divide all the readMaskShape locations pointed by innerDimPos
1864  // by the innerTileSize attribute value.
1865  // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1866  // - Append the remaining shape from SS
1867  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
1868  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1869  // 128] and outer_dims_perm is [1, 0] then read shape is:
1870  // ReadVectorSizes(initial): [512, 128]
1871  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1872  // = [16, 8]
1873  // After applying outer_dims_perm: [8, 16]
1874  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1875 
1876  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1877 
1878  for (auto [index, size] : enumerate(innerTiles)) {
1879  readVectorSizes[innerDimPos[index]] =
1880  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1881  }
1882  if (!outerDimsPerm.empty()) {
1883  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1884  }
1885  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1886  sourceShape.end());
1887 
1888  ReifiedRankedShapedTypeDims reifiedRetShapes;
1889  LogicalResult status =
1890  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1891  .reifyResultShapes(rewriter, reifiedRetShapes);
1892  if (status.failed()) {
1893  LDBG("Unable to reify result shapes of " << unpackOp);
1894  return failure();
1895  }
1896  Location loc = unpackOp->getLoc();
1897 
1898  auto padValue = rewriter.create<arith::ConstantOp>(
1899  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1900 
1901  // Read result, mask if necessary. If transferReadOp shape is not equal
1902  // to shape of source, then a mask is necessary.
1903  Value readResult = vector::createReadOrMaskedRead(
1904  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1905  /*useInBoundsInsteadOfMasking=*/false);
1906 
1907  PackingMetadata packMetadata;
1908  SmallVector<int64_t> lastDimToInsertPosPerm =
1909  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1910  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1911  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1912  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1913  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1914  RankedTensorType stripMineTensorType =
1915  RankedTensorType::get(stripMineShape, stripMineElemType);
1916  // Transpose the appropriate rows to match output.
1917  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1918  loc, readResult, lastDimToInsertPosPerm);
1919 
1920  // Collapse the vector to the size required by result.
1921  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1922  stripMineTensorType, packMetadata.reassociations);
1923  mlir::VectorType vecCollapsedType =
1924  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1925  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1926  loc, vecCollapsedType, transposeOp->getResult(0));
1927 
1928  // writeVectorSizes has to match the shape_cast result shape for dynamic
1929  // sizes, otherwise the verifier complains that the mask size is invalid.
1930  SmallVector<int64_t> writeVectorSizes(
1931  unpackOp.getDestType().hasStaticShape()
1932  ? vectorSizes
1933  : shapeCastOp.getResultVectorType().getShape());
1934  Value dest = rewriter.create<tensor::EmptyOp>(
1935  loc, reifiedRetShapes[0],
1936  shapeCastOp.getResult().getType().getElementType());
1937  Operation *write = createWriteOrMaskedWrite(
1938  rewriter, loc, shapeCastOp.getResult(), dest,
1939  /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1940  newResults.push_back(write->getResult(0));
1941  return success();
1942 }
1943 
1944 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1945 /// and (3) all-zero lowPad to
1946 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
1947 static LogicalResult
1948 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1949  ArrayRef<int64_t> inputVectorSizes,
1950  SmallVectorImpl<Value> &newResults) {
1951  auto padValue = padOp.getConstantPaddingValue();
1952  Location loc = padOp.getLoc();
1953 
1954  // TODO: Introduce a parent class that will handle the insertion point update.
1955  OpBuilder::InsertionGuard g(rewriter);
1956  rewriter.setInsertionPoint(padOp);
1957 
1958  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1959  LogicalResult status =
1960  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1961  .reifyResultShapes(rewriter, reifiedReturnShapes);
1962  (void)status; // prevent unused variable warning on non-assert builds
1963  assert(succeeded(status) && "failed to reify result shapes");
1964  auto maskedRead = vector::createReadOrMaskedRead(
1965  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1966  /*useInBoundsInsteadOfMasking=*/false);
1967 
1968  // Create Xfer write Op
1969  Value dest = rewriter.create<tensor::EmptyOp>(
1970  loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1971  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
1972  newResults.push_back(write->getResult(0));
1973  return success();
1974 }
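// Editorial sketch (hypothetical shapes, not part of the upstream source):
// with inputVectorSizes = [8, 16], a pad such as
//   %p = tensor.pad %src low[0, 0] high[...] { ... tensor.yield %cst ... }
//       : tensor<?x?xf32> to tensor<8x16xf32>
// is rewritten into a masked vector.transfer_read of %src (using %cst as the
// padding value) followed by a vector.transfer_write of the resulting
// vector<8x16xf32> into a tensor.empty of the reified result shape.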
1975 
1976 // TODO: probably need some extra checks for reduction followed by consumer
1977 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1978 static LogicalResult reductionPreconditions(LinalgOp op) {
1979  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1980  LDBG("reduction precondition failed: no reduction iterator\n");
1981  return failure();
1982  }
1983  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1984  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1985  if (indexingMap.isPermutation())
1986  continue;
1987 
1988  Operation *reduceOp = matchLinalgReduction(&opOperand);
1989  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1990  LDBG("reduction precondition failed: reduction detection failed\n");
1991  return failure();
1992  }
1993  }
1994  return success();
1995 }
1996 
1997 static LogicalResult
1998 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1999  bool flatten1DDepthwiseConv) {
2000  if (flatten1DDepthwiseConv) {
2001  LDBG("Vectorization of flattened convs with dynamic shapes is not "
2002  "supported\n");
2003  return failure();
2004  }
2005 
2006  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2007  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
2008  return failure();
2009  }
2010 
2011  // Support dynamic shapes in 1D depthwise convolution, but only in the
2012  // _channel_ dimension.
2013  Value lhs = conv.getDpsInputOperand(0)->get();
2014  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2015  auto shapeWithoutCh = lhsShape.drop_back(1);
2016  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2017  LDBG("Dynamically-shaped op vectorization precondition failed: only "
2018  "channel dim can be dynamic\n");
2019  return failure();
2020  }
2021 
2022  return success();
2023 }
2024 
2025 static LogicalResult
2026 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2027  bool flatten1DDepthwiseConv) {
2028  if (isa<ConvolutionOpInterface>(op.getOperation()))
2029  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2030 
2031  if (hasReductionIterator(op))
2032  return reductionPreconditions(op);
2033 
2034  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2035  // linalg.copy ops and ops that implement ContractionOpInterface for now.
2036  if (!isElementwise(op) &&
2037  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2038  op.getOperation()))
2039  return failure();
2040 
2041  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
2042  return success();
2043 }
2044 
2045 /// Need to check if the inner-tiles are static/constant.
2046 static LogicalResult
2047 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2048  ArrayRef<int64_t> inputVectorSizes) {
2049 
2050  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
2051  return !getConstantIntValue(res).has_value();
2052  })) {
2053  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
2054  return failure();
2055  }
2056  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2057  bool satisfyEmptyCond = inputVectorSizes.empty() &&
2058  unpackOp.getDestType().hasStaticShape() &&
2059  unpackOp.getSourceType().hasStaticShape();
2060  if (!satisfyEmptyCond &&
2061  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2062  return failure();
2063 
2064  return success();
2065 }
2066 
2067 static LogicalResult
2068 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2069  ArrayRef<int64_t> inputVectorSizes) {
2070 
2071  TypedValue<RankedTensorType> source = sliceOp.getSource();
2072  auto sourceType = source.getType();
2073  if (!VectorType::isValidElementType(sourceType.getElementType()))
2074  return failure();
2075 
2076  // Get the pad value.
2077  // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2078  // scalar padding value. Note that:
2079  // * for in-bounds accesses,
2080  // the value is actually irrelevant. There are 2 cases in which xfer.read
2081  // accesses are known to be in-bounds:
2082  // 1. The source shape is static (output vector sizes would be based on
2083  // the source shape and hence all memory accesses would be in-bounds),
2084  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2085  // this case it is safe to assume that all memory accesses are in-bounds.
2086  //
2087  // When the value is not known and not needed, use 0. Otherwise, bail out.
2088  Value padValue = getStaticPadVal(sliceOp);
2089  bool isOutOfBoundsRead =
2090  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2091 
2092  if (!padValue && isOutOfBoundsRead) {
2093  LDBG("Failed to get a pad value for out-of-bounds read access\n");
2094  return failure();
2095  }
2096  return success();
2097 }
2098 
2099 namespace {
2100 enum class ConvOperationKind { Conv, Pool };
2101 } // namespace
2102 
2103 static bool isCastOfBlockArgument(Operation *op) {
2104  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2105  isa<BlockArgument>(op->getOperand(0));
2106 }
2107 
2108 // Returns the ConvOperationKind of the op using reduceOp of the generic
2109 // payload. If it is neither a convolution nor a pooling, it returns
2110 // std::nullopt.
2111 //
2112 // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2113 // + yield) and rhs is not used) then it is the body of a pooling
2114 // If conv, check for single `mul` predecessor. The `mul` operands must be
2115 // block arguments or extension of block arguments.
2116 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2117 // must be block arguments or extension of block arguments.
2118 static std::optional<ConvOperationKind>
2119 getConvOperationKind(Operation *reduceOp) {
2120  int numBlockArguments =
2121  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2122 
2123  switch (numBlockArguments) {
2124  case 1: {
2125  // Will be convolution if feeder is a MulOp.
2126  // A strength reduced version of MulOp for i1 type is AndOp which is also
2127  // supported. Otherwise, it can be pooling. This strength reduction logic
2128  // is in `buildBinaryFn` helper in the Linalg dialect.
2129  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2130  llvm::IsaPred<BlockArgument>);
2131  assert(feedValIt != reduceOp->operand_end() &&
2132  "Expected a non-block argument operand");
2133  Operation *feedOp = (*feedValIt).getDefiningOp();
2134  if (isCastOfBlockArgument(feedOp)) {
2135  return ConvOperationKind::Pool;
2136  }
2137 
2138  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2139  (isa<arith::AndIOp>(feedOp) &&
2140  feedOp->getResultTypes()[0].isInteger(1))) &&
2141  llvm::all_of(feedOp->getOperands(), [](Value v) {
2142  if (isa<BlockArgument>(v))
2143  return true;
2144  if (Operation *op = v.getDefiningOp())
2145  return isCastOfBlockArgument(op);
2146  return false;
2147  }))) {
2148  return std::nullopt;
2149  }
2150 
2151  return ConvOperationKind::Conv;
2152  }
2153  case 2:
2154  // Must be pooling
2155  return ConvOperationKind::Pool;
2156  default:
2157  return std::nullopt;
2158  }
2159 }
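// Editorial examples (hypothetical generic bodies, not part of the upstream
// source) of how the classification above plays out:
//   %m = arith.mulf %in, %filter : f32
//   %r = arith.addf %acc, %m : f32      // one block-arg operand on the
//                                       // reduction, mul feeder -> Conv
//   %r = arith.maximumf %acc, %in : f32 // two block-arg operands -> Pool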
2160 
2161 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2162  switch (kind) {
2163  case vector::CombiningKind::ADD:
2164  case vector::CombiningKind::MAXNUMF:
2165  case vector::CombiningKind::MAXIMUMF:
2166  case vector::CombiningKind::MAXSI:
2167  case vector::CombiningKind::MAXUI:
2168  case vector::CombiningKind::MINNUMF:
2169  case vector::CombiningKind::MINIMUMF:
2170  case vector::CombiningKind::MINSI:
2171  case vector::CombiningKind::MINUI:
2172  return true;
2173  default:
2174  return false;
2175  }
2176 }
2177 
2178 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2179  auto getOperandType = [&](auto operand) {
2180  return dyn_cast<ShapedType>((operand->get()).getType());
2181  };
2182  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2183  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2184  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2185  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2186  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2187  // Note that this also ensures 2D and 3D convolutions are rejected.
2188  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2189  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2190  return failure();
2191 
2192  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2193  if (!reduceOp)
2194  return failure();
2195 
2196  auto maybeOper = getConvOperationKind(reduceOp);
2197  if (!maybeOper.has_value())
2198  return failure();
2199 
2200  auto maybeKind = getCombinerOpKind(reduceOp);
2201  // Typically convolution will have a `Add` CombiningKind but for i1 type it
2202  // can get strength reduced to `OR` which is also supported. This strength
2203  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2204  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2205  *maybeKind != vector::CombiningKind::OR) &&
2206  (*maybeOper != ConvOperationKind::Pool ||
2207  !isSupportedPoolKind(*maybeKind)))) {
2208  return failure();
2209  }
2210 
2211  auto rhsRank = rhsShapedType.getRank();
2212  if (*maybeOper == ConvOperationKind::Pool) {
2213  if (rhsRank != 1)
2214  return failure();
2215  } else {
2216  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2217  return failure();
2218  }
2219 
2220  return success();
2221 }
2222 
2223 static LogicalResult vectorizeLinalgOpPrecondition(
2224  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2225  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2226  // Tensors with a dimension of size 0 cannot be vectorized.
2227  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2228  return llvm::is_contained(linalgOp.getShape(&operand), 0);
2229  }))
2230  return failure();
2231  // Check API contract for input vector sizes.
2232  if (!inputVectorSizes.empty() &&
2233  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2234  inputVectorSizes)))
2235  return failure();
2236 
2237  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2238  linalgOp, flatten1DDepthwiseConv))) {
2239  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
2240  return failure();
2241  }
2242 
2243  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2244 
2245  // Register CustomVectorizationPrecondition for extractOp.
2246  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2247 
2248  // All types in the body should be a supported element type for VectorType.
2249  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2250  // Check if any custom hook can vectorize the inner op.
2251  if (llvm::any_of(
2252  customPreconditions,
2253  [&](const CustomVectorizationPrecondition &customPrecondition) {
2254  return succeeded(
2255  customPrecondition(&innerOp, vectorizeNDExtract));
2256  })) {
2257  continue;
2258  }
2259  if (!llvm::all_of(innerOp.getOperandTypes(),
2260  VectorType::isValidElementType)) {
2261  return failure();
2262  }
2263  if (!llvm::all_of(innerOp.getResultTypes(),
2264  VectorType::isValidElementType)) {
2265  return failure();
2266  }
2267  }
2268  if (isElementwise(linalgOp))
2269  return success();
2270 
2271  // TODO: isaConvolutionOpInterface that can also infer from generic
2272  // features. But we will still need stride/dilation attributes that will be
2273  // annoying to reverse-engineer...
2274  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2275  return vectorizeConvOpPrecondition(linalgOp);
2276 
2277  // TODO: the common vector shape is equal to the static loop sizes only when
2278  // all indexing maps are projected permutations. For convs and stencils the
2279  // logic will need to evolve.
2280  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2281  LDBG("precondition failed: not projected permutations\n");
2282  return failure();
2283  }
2284  if (failed(reductionPreconditions(linalgOp))) {
2285  LDBG("precondition failed: reduction preconditions\n");
2286  return failure();
2287  }
2288  return success();
2289 }
2290 
2291 static LogicalResult
2292 vectorizePackOpPrecondition(linalg::PackOp packOp,
2293  ArrayRef<int64_t> inputVectorSizes) {
2294  auto padValue = packOp.getPaddingValue();
2295  Attribute cstAttr;
2296  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2297  LDBG("pad value is not constant: " << packOp << "\n");
2298  return failure();
2299  }
2300  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2301  bool satisfyEmptyCond = true;
2302  if (inputVectorSizes.empty()) {
2303  if (!packOp.getDestType().hasStaticShape() ||
2304  !packOp.getSourceType().hasStaticShape())
2305  satisfyEmptyCond = false;
2306  }
2307 
2308  if (!satisfyEmptyCond &&
2309  failed(vector::isValidMaskedInputVector(
2310  resultTensorShape.take_front(packOp.getSourceRank()),
2311  inputVectorSizes)))
2312  return failure();
2313 
2314  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2315  return !getConstantIntValue(v).has_value();
2316  })) {
2317  LDBG("inner_tiles must be constant: " << packOp << "\n");
2318  return failure();
2319  }
2320 
2321  return success();
2322 }
2323 
2324 static LogicalResult
2325 vectorizePadOpPrecondition(tensor::PadOp padOp,
2326  ArrayRef<int64_t> inputVectorSizes) {
2327  auto padValue = padOp.getConstantPaddingValue();
2328  if (!padValue) {
2329  LDBG("pad value is not constant: " << padOp << "\n");
2330  return failure();
2331  }
2332 
2333  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2334  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2335  inputVectorSizes)))
2336  return failure();
2337 
2338  // Padding with non-zero low pad values is not supported, unless the
2339  // corresponding result dim is 1 as this would require shifting the results to
2340  // the right for the low padded dims by the required amount of low padding.
2341  // However, we do support low padding if the dims being low padded have result
2342  // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2343  // input size of that dimension will be dynamically zero (as the sum of the
2344  // low pad and input dim size has to be one) and hence we will create a zero
2345  // mask as the lowering logic just makes the mask one for the input dim size -
2346  // which is zero here. Hence we will load the pad value which is what we want
2347  // in this case. If the low pad is dynamically zero then the lowering is
2348  // correct as well as no shifts are necessary.
2349  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2350  Value padValue = en.value();
2351  unsigned pos = en.index();
2352  std::optional<int64_t> pad = getConstantIntValue(padValue);
2353  return (!pad.has_value() || pad.value() != 0) &&
2354  resultTensorShape[pos] != 1;
2355  })) {
2356  LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
2357  return failure();
2358  }
2359 
2360  return success();
2361 }
2362 
2363 /// Preconditions for scalable vectors. This is quite restrictive - it models
2364 /// the fact that in practice we would only make selected dimensions scalable.
2365 static LogicalResult
2366 vectorizeScalableVectorPrecondition(Operation *op,
2367  ArrayRef<int64_t> inputVectorSizes,
2368  ArrayRef<bool> inputScalableVecDims) {
2369  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2370  "Number of input vector sizes and scalable dims doesn't match");
2371 
2372  size_t numOfScalableDims =
2373  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2374 
2375  if (numOfScalableDims == 0)
2376  return success();
2377 
2378  auto linalgOp = dyn_cast<LinalgOp>(op);
2379 
2380  // Cond 1: There's been no need for scalable vectorisation of
2381  // non-linalg Ops so far
2382  if (!linalgOp)
2383  return failure();
2384 
2385  // Cond 2: There's been no need for more than 2 scalable dims so far
2386  if (numOfScalableDims > 2)
2387  return failure();
2388 
2389  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2390  // it matches one of the supported cases:
2391  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2392  // (*).
2393  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2394  // parallel dims.
2395  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2396  // dim.
2397  // The 2nd restriction above means that only Matmul-like Ops are supported
2398  // when 2 dims are scalable, e.g. :
2399  // * iterators = [parallel, parallel, reduction]
2400  // * scalable flags = [true, true, false]
2401  //
2402  // (*) Unit dims get folded away in practice.
2403  // TODO: Relax these conditions as good motivating examples are identified.
2404 
2405  // Find the first scalable flag.
2406  bool seenNonUnitParallel = false;
2407  auto iterators = linalgOp.getIteratorTypesArray();
2408  SmallVector<bool> scalableFlags(inputScalableVecDims);
2409  int64_t idx = scalableFlags.size() - 1;
2410  while (!scalableFlags[idx]) {
2411  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2412  seenNonUnitParallel |=
2413  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2414 
2415  iterators.pop_back();
2416  scalableFlags.pop_back();
2417  --idx;
2418  }
2419 
2420  // Analyze the iterator corresponding to the first scalable dim.
2421  switch (iterators.back()) {
2422  case utils::IteratorType::reduction: {
2423  // Check 3. above is met.
2424  if (iterators.size() != inputVectorSizes.size()) {
2425  LDBG("Non-trailing reduction dim requested for scalable "
2426  "vectorization\n");
2427  return failure();
2428  }
2429  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2430  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2431  "is not supported\n");
2432  return failure();
2433  }
2434  break;
2435  }
2436  case utils::IteratorType::parallel: {
2437  // Check 1. and 2. above are met.
2438  if (seenNonUnitParallel) {
2439  LDBG("Inner parallel dim not requested for scalable "
2440  "vectorization\n");
2441  return failure();
2442  }
2443  break;
2444  }
2445  }
2446 
2447  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2448  // supported, for which we expect the following config:
2449  // * iterators = [parallel, parallel, reduction]
2450  // * scalable flags = [true, true, false]
2451  if (numOfScalableDims == 2) {
2452  // Disallow below case which breaks 3. above:
2453  // * iterators = [..., parallel, reduction]
2454  // * scalable flags = [..., true, true]
2455  if (iterators.back() == utils::IteratorType::reduction) {
2456  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2457  "vectorization\n");
2458  return failure();
2459  }
2460  scalableFlags.pop_back();
2461  iterators.pop_back();
2462 
2463  if (!scalableFlags.back() ||
2464  (iterators.back() != utils::IteratorType::parallel))
2465  return failure();
2466  }
2467 
2468  // Don't let matmuls with extended semantics (user-defined indexing maps)
2469  // through this transform.
2470  if (linalgOp.hasUserDefinedMaps())
2471  return failure();
2472 
2473  // Cond 4: Only the following ops are supported in the
2474  // presence of scalable vectors
2475  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2476  isa<linalg::MatmulTransposeAOp>(op) ||
2477  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2478  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2479 }
2480 
2481 static LogicalResult vectorizeOpPrecondition(
2482  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2483  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2484  bool flatten1DDepthwiseConv) {
2485 
2486  if (!hasVectorizationImpl(op))
2487  return failure();
2488 
2489  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2490  inputScalableVecDims)))
2491  return failure();
2492 
2493  return TypeSwitch<Operation *, LogicalResult>(op)
2494  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2495  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2496  vectorizeNDExtract,
2497  flatten1DDepthwiseConv);
2498  })
2499  .Case<tensor::PadOp>([&](auto padOp) {
2500  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2501  })
2502  .Case<linalg::PackOp>([&](auto packOp) {
2503  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2504  })
2505  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2506  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2507  })
2508  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2509  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2510  })
2511  .Default([](auto) { return failure(); });
2512 }
2513 
2514 /// Converts affine.apply Ops to arithmetic operations.
2515 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2516  OpBuilder::InsertionGuard g(rewriter);
2517  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2518 
2519  for (auto op : make_early_inc_range(toReplace)) {
2520  rewriter.setInsertionPoint(op);
2521  auto expanded = affine::expandAffineExpr(
2522  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2523  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2524  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2525  rewriter.replaceOp(op, expanded);
2526  }
2527 }
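// Editorial example (hypothetical IR, not part of the upstream source): an
//   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%ofs]
// inside the linalg body is expanded by affine::expandAffineExpr into plain
// arithmetic such as
//   %0 = arith.addi %i, %ofs : index
// so that the generic vectorization path only has to handle elementwise
// arith/index ops.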
2528 
2529 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2530  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2531  tensor::InsertSliceOp>(op);
2532 }
2533 
2534 FailureOr<VectorizationResult>
2535 mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2536  ArrayRef<int64_t> inputVectorSizes,
2537  ArrayRef<bool> inputScalableVecDims,
2538  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2539  LDBG("Attempting to vectorize:\n" << *op << "\n");
2540  LDBG("Input vector sizes: ");
2541  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2542  LLVM_DEBUG(llvm::dbgs() << "\n");
2543  LDBG("Input scalable vector dims: ");
2544  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2545  LLVM_DEBUG(llvm::dbgs() << "\n");
2546 
2547  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2548  vectorizeNDExtract,
2549  flatten1DDepthwiseConv))) {
2550  LDBG("Vectorization pre-conditions failed\n");
2551  return failure();
2552  }
2553 
2554  // Initialize vectorization state.
2555  VectorizationState state(rewriter);
2556  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2557  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2558  inputScalableVecDims))) {
2559  LDBG("Vectorization state couldn't be initialized\n");
2560  return failure();
2561  }
2562  }
2563 
2564  SmallVector<Value> results;
2565  auto vectorizeResult =
2566  TypeSwitch<Operation *, LogicalResult>(op)
2567  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2568  // TODO: isaConvolutionOpInterface that can also infer from
2569  // generic features. Will require stride/dilation attributes
2570  // inference.
2571  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2572  FailureOr<Operation *> convOr = vectorizeConvolution(
2573  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2574  flatten1DDepthwiseConv);
2575  if (succeeded(convOr)) {
2576  llvm::append_range(results, (*convOr)->getResults());
2577  return success();
2578  }
2579 
2580  LDBG("Unsupported convolution can't be vectorized.\n");
2581  return failure();
2582  }
2583 
2584  LDBG("Vectorize generic by broadcasting to the canonical vector "
2585  "shape\n");
2586 
2587  // Pre-process before proceeding.
2588  convertAffineApply(rewriter, linalgOp);
2589 
2590  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2591  // to 'OpBuilder' when it is passed over to some methods like
2592  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2593  // erase an op within these methods, the actual rewriter won't be
2594  // notified and we will end up with read-after-free issues!
2595  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2596  })
2597  .Case<tensor::PadOp>([&](auto padOp) {
2598  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2599  results);
2600  })
2601  .Case<linalg::PackOp>([&](auto packOp) {
2602  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2603  results);
2604  })
2605  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2606  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2607  inputVectorSizes, results);
2608  })
2609  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2610  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2611  results);
2612  })
2613  .Default([](auto) { return failure(); });
2614 
2615  if (failed(vectorizeResult)) {
2616  LDBG("Vectorization failed\n");
2617  return failure();
2618  }
2619 
2620  return VectorizationResult{results};
2621 }
2622 
2623 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2624  memref::CopyOp copyOp) {
2625  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2626  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2627  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2628  return failure();
2629 
2630  auto srcElementType = getElementTypeOrSelf(srcType);
2631  auto dstElementType = getElementTypeOrSelf(dstType);
2632  if (!VectorType::isValidElementType(srcElementType) ||
2633  !VectorType::isValidElementType(dstElementType))
2634  return failure();
2635 
2636  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2637  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2638 
2639  Location loc = copyOp->getLoc();
2640  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2641  SmallVector<Value> indices(srcType.getRank(), zero);
2642 
2643  Value readValue = rewriter.create<vector::TransferReadOp>(
2644  loc, readType, copyOp.getSource(), indices,
2645  /*padding=*/std::nullopt,
2646  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2647  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2648  readValue =
2649  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2650  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2651  }
2652  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2653  loc, readValue, copyOp.getTarget(), indices,
2654  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2655  rewriter.replaceOp(copyOp, writeValue->getResults());
2656  return success();
2657 }
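// Editorial example (hypothetical types, not part of the upstream source): a
// statically shaped copy such as
//   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
// is rewritten by the function above into an identity-map transfer pair:
//   %v = vector.transfer_read %src[%c0, %c0] ... : memref<4x8xf32>, vector<4x8xf32>
//   vector.transfer_write %v, %dst[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>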
2658 
2659 //----------------------------------------------------------------------------//
2660 // Misc. vectorization patterns.
2661 //----------------------------------------------------------------------------//
2662 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2663 /// given operation type OpTy.
2664 template <typename OpTy>
2665 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2666  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2667 
2668  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2669  PatternRewriter &rewriter) const final {
2670  bool changed = false;
2671  // Insert users in vector, because some users may be replaced/removed.
2672  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2673  if (auto op = dyn_cast<OpTy>(user))
2674  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2675  return success(changed);
2676  }
2677 
2678 protected:
2679  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2680  tensor::PadOp padOp, OpTy op) const = 0;
2681 };
2682 
2683 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2684 /// ```
2685 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2686 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2687 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2688 /// ```
2689 /// is rewritten to:
2690 /// ```
2691 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2692 /// {in_bounds = [true, true]}
2693 /// : tensor<?x?xf32>, vector<17x5xf32>
2694 /// ```
2695 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2696 /// sure that the original padding value %cst was never used.
2697 ///
2698 /// This rewrite is possible if:
2699 /// - `xferOp` has no out-of-bounds dims or mask.
2700 /// - Low padding is static 0.
2701 /// - Single, scalar padding value.
2702 struct PadOpVectorizationWithTransferReadPattern
2703  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2704  using VectorizePadOpUserPattern<
2705  vector::TransferReadOp>::VectorizePadOpUserPattern;
2706 
2707  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2708  vector::TransferReadOp xferOp) const override {
2709  // Low padding must be static 0.
2710  if (!padOp.hasZeroLowPad())
2711  return failure();
2712  // Pad value must be a constant.
2713  auto padValue = padOp.getConstantPaddingValue();
2714  if (!padValue)
2715  return failure();
2716  // Padding value of existing `xferOp` is unused.
2717  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2718  return failure();
2719 
2720  rewriter.modifyOpInPlace(xferOp, [&]() {
2721  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2722  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2723  rewriter.getBoolArrayAttr(inBounds));
2724  xferOp.getBaseMutable().assign(padOp.getSource());
2725  xferOp.getPaddingMutable().assign(padValue);
2726  });
2727 
2728  return success();
2729  }
2730 };
2731 
2732 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2733 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2734 /// value, where the same amount of padding is immediately removed again after
2735 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2736 /// tensor value and apply out-of-bounds masking. E.g.:
2737 /// ```
2738 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2739 /// : tensor<...> to tensor<?x?xf32>
2740 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2741 /// %2 = vector.transfer_write %vec, %1[...]
2742 /// : vector<17x5xf32>, tensor<17x5xf32>
2743 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2744 /// : tensor<17x5xf32> to tensor<?x?xf32>
2745 /// ```
2746 /// is rewritten to:
2747 /// ```
2748 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2749 /// : tensor<...> to tensor<?x?xf32>
2750 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2751 /// tensor<?x?xf32>
2752 /// ```
2753 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2754 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2755 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2756 /// from %r's old dimensions.
2757 ///
2758 /// This rewrite is possible if:
2759 /// - Low padding is static 0.
2760 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2761 /// ExtractSliceOp trims the same amount of padding that was added
2762 /// beforehand.
2763 /// - Single, scalar padding value.
2764 struct PadOpVectorizationWithTransferWritePattern
2765  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2766  using VectorizePadOpUserPattern<
2767  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2768 
2769  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2770  vector::TransferWriteOp xferOp) const override {
2771  // TODO: support 0-d corner case.
2772  if (xferOp.getTransferRank() == 0)
2773  return failure();
2774 
2775  // Low padding must be static 0.
2776  if (!padOp.hasZeroLowPad())
2777  return failure();
2778  // Pad value must be a constant.
2779  auto padValue = padOp.getConstantPaddingValue();
2780  if (!padValue)
2781  return failure();
2782  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2783  if (!xferOp->hasOneUse())
2784  return failure();
2785  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2786  if (!trimPadding)
2787  return failure();
2788  // Only static zero offsets supported when trimming padding.
2789  if (!trimPadding.hasZeroOffset())
2790  return failure();
2791  // trimPadding must remove the amount of padding that was added earlier.
2792  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2793  return failure();
2794 
2795  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2796  rewriter.setInsertionPoint(xferOp);
2797 
2798  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2799  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2800  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2801  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2802  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2803  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2804 
2805  return success();
2806  }
2807 
2808  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2809  /// i.e., same dimensions.
2810  ///
2811  /// Dimensions may be static, dynamic, or a mix of both. In case of dynamic
2812  /// dimensions, this function tries to infer the (static) tensor size by
2813  /// looking at the defining op and utilizing op-specific knowledge.
2814  ///
2815  /// This is a conservative analysis. In case equal tensor sizes cannot be
2816  /// proven statically, this analysis returns `false` even though the tensor
2817  /// sizes may turn out to be equal at runtime.
2818  bool hasSameTensorSize(Value beforePadding,
2819  tensor::ExtractSliceOp afterTrimming) const {
2820  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2821  // result and CastOp operand.
2822  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2823  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2824  return true;
2825 
2826  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2827  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2828  // Only RankedTensorType supported.
2829  if (!t1 || !t2)
2830  return false;
2831  // Rank of both values must be the same.
2832  if (t1.getRank() != t2.getRank())
2833  return false;
2834 
2835  // All static dimensions must be the same. Mixed cases (e.g., dimension
2836  // static in `t1` but dynamic in `t2`) are not supported.
2837  for (unsigned i = 0; i < t1.getRank(); ++i) {
2838  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2839  return false;
2840  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2841  return false;
2842  }
2843 
2844  // Nothing more to check if all dimensions are static.
2845  if (t1.getNumDynamicDims() == 0)
2846  return true;
2847 
2848  // All dynamic sizes must be the same. The only supported case at the
2849  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2850  // thereof).
2851 
2852  // Apart from CastOp, only ExtractSliceOp is supported.
2853  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2854  if (!beforeSlice)
2855  return false;
2856 
2857  assert(static_cast<size_t>(t1.getRank()) ==
2858  beforeSlice.getMixedSizes().size());
2859  assert(static_cast<size_t>(t2.getRank()) ==
2860  afterTrimming.getMixedSizes().size());
2861 
2862  for (unsigned i = 0; i < t1.getRank(); ++i) {
2863  // Skip static dimensions.
2864  if (!t1.isDynamicDim(i))
2865  continue;
2866  auto size1 = beforeSlice.getMixedSizes()[i];
2867  auto size2 = afterTrimming.getMixedSizes()[i];
2868 
2869  // Case 1: Same value or same constant int.
2870  if (isEqualConstantIntOrValue(size1, size2))
2871  continue;
2872 
2873  // Other cases: Take a deeper look at defining ops of values.
2874  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2875  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2876  if (!v1 || !v2)
2877  return false;
2878 
2879  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2880  // CSE is run.)
2881  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2882  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2883  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2884  minOp1.getOperands() == minOp2.getOperands())
2885  continue;
2886 
2887  // Add additional cases as needed.
2888  }
2889 
2890  // All tests passed.
2891  return true;
2892  }
2893 };
2894 
2895 /// Returns the effective Pad value for the input op, provided it's a scalar.
2896 ///
2897 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2898 /// this Op performs padding, retrieve the padding value provided that it's
2899 /// a scalar and static/fixed for all the padded values. Returns an empty value
2900 /// otherwise.
2901 ///
2902 /// TODO: This is used twice (when checking vectorization pre-conditions and
2903 /// when vectorizing). Cache results instead of re-running.
2904 static Value getStaticPadVal(Operation *op) {
2905  if (!op)
2906  return {};
2907 
2908  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2909  // being broadcast, provided that it's a scalar.
2910  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2911  auto source = bcast.getSource();
2912  if (llvm::dyn_cast<VectorType>(source.getType()))
2913  return {};
2914 
2915  return source;
2916  }
2917 
2918  // 2. linalg.fill - use the scalar input value that used to fill the output
2919  // tensor.
2920  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2921  return fill.getInputs()[0];
2922  }
2923 
2924  // 3. tensor.generate - we can't guarantee that the value is fixed
2925  // without analysing it, so bail out.
2926  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2927  return {};
2928  }
2929 
2930  // 4. vector.transfer_write - inspect the vector that is being written. If
2931  // it contains a single value that has been broadcast (e.g. via
2932  // vector.broadcast), extract it; fail otherwise.
2933  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2934  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2935 
2936  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2937  // than the input tensor then, provided it was generated from a constant
2938  // (e.g. via linalg.fill), extract that value; fail otherwise.
2939  // TODO: Clarify the semantics when the input tensor is larger than the
2940  // destination.
2941  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2942  return getStaticPadVal(slice.getDest().getDefiningOp());
2943 
2944  return {};
2945 }
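// Example (an illustrative sketch, not part of the upstream file): when the
// destination of a tensor.insert_slice was produced by linalg.fill, the value
// is recovered through case 5 followed by case 2 above:
// ```
//   %cst = arith.constant 0.0 : f32
//   %filled = linalg.fill ins(%cst : f32) outs(%dest : tensor<8x8xf32>)
//             -> tensor<8x8xf32>
//   %ins = tensor.insert_slice %src into %filled[0, 0] [4, 4] [1, 1]
//          : tensor<4x4xf32> into tensor<8x8xf32>
// ```
// Calling getStaticPadVal on the tensor.insert_slice walks to the linalg.fill
// and returns %cst.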
2946 
2947 static LogicalResult
2948 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2949  ArrayRef<int64_t> inputVectorSizes,
2950  SmallVectorImpl<Value> &newResults) {
2951  // TODO: Introduce a parent class that will handle the insertion point update.
2952  OpBuilder::InsertionGuard g(rewriter);
2953  rewriter.setInsertionPoint(sliceOp);
2954 
2955  TypedValue<RankedTensorType> source = sliceOp.getSource();
2956  auto sourceType = source.getType();
2957  auto resultType = sliceOp.getResultType();
2958 
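// 1. Get the pad value; fall back to a zero constant when none can be
// inferred statically.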
2959  Value padValue = getStaticPadVal(sliceOp);
2960 
2961  if (!padValue) {
2962  auto elemType = sourceType.getElementType();
2963  padValue = rewriter.create<arith::ConstantOp>(
2964  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2965  }
2966 
2967  // 2. Get the vector shape
2968  SmallVector<int64_t> vecShape;
2969  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2970  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2971  if (!inputVectorSizes.empty()) {
2972  vecShape.push_back(inputVectorSizes[i]);
2973  } else if (!sourceType.isDynamicDim(i)) {
2974  vecShape.push_back(sourceType.getDimSize(i));
2975  } else if (!resultType.isDynamicDim(i)) {
2976  // Source shape is not statically known, but result shape is.
2977  // Vectorize with size of result shape. This may be larger than the
2978  // source size.
2979  // FIXME: Using rankDiff implies that the source tensor is inserted at
2980  // the end of the destination tensor. However, that's not required.
2981  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2982  } else {
2983  // Neither the source nor the result dim of the slice op is static.
2984  // Cannot vectorize the copy.
2985  return failure();
2986  }
2987  }
2988  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2989 
2990  // 3. Generate TransferReadOp + TransferWriteOp
2991  auto loc = sliceOp.getLoc();
2992 
2993  // Create read
2994  SmallVector<Value> readIndices(
2995  vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
2996  Value read = mlir::vector::createReadOrMaskedRead(
2997  rewriter, loc, source, vecType.getShape(), padValue,
2998  /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
2999 
3000  // Create write
3001  auto writeIndices =
3002  getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3003  Operation *write =
3004  createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3005  writeIndices, inputVectorSizes.empty());
3006 
3007  // 4. Finalize
3008  newResults.push_back(write->getResult(0));
3009 
3010  return success();
3011 }
3012 
3013 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3014 /// ```
3015 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3016 /// %r = tensor.insert_slice %0
3017 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3018 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3019 /// ```
3020 /// is rewritten to:
3021 /// ```
3022 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
3023 /// : tensor<?x?xf32>, vector<17x5xf32>
3024 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3025 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3026 /// ```
3027 ///
3028 /// This rewrite is possible if:
3029 /// - Low padding is static 0.
3030 /// - `padOp` result shape is static.
3031 /// - The entire padded tensor is inserted.
3032 /// (Implies that sizes of `insertOp` are all static.)
3033 /// - Only unit strides in `insertOp`.
3034 /// - Single, scalar padding value.
3035 /// - `padOp` result not used as destination.
3036 struct PadOpVectorizationWithInsertSlicePattern
3037  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3038  using VectorizePadOpUserPattern<
3039  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3040 
3041  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3042  tensor::InsertSliceOp insertOp) const override {
3043  // Low padding must be static 0.
3044  if (!padOp.hasZeroLowPad())
3045  return failure();
3046  // Only unit stride supported.
3047  if (!insertOp.hasUnitStride())
3048  return failure();
3049  // Pad value must be a constant.
3050  auto padValue = padOp.getConstantPaddingValue();
3051  if (!padValue)
3052  return failure();
3053  // Dynamic shapes not supported.
3054  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3055  return failure();
3056  // Pad result not used as destination.
3057  if (insertOp.getDest() == padOp.getResult())
3058  return failure();
3059 
3060  auto vecType = VectorType::get(padOp.getType().getShape(),
3061  padOp.getType().getElementType());
3062  unsigned vecRank = vecType.getRank();
3063  unsigned tensorRank = insertOp.getType().getRank();
3064 
3065  // Check if sizes match: Insert the entire tensor into most minor dims.
3066  // (No permutations allowed.)
3067  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3068  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3069  if (!llvm::all_of(
3070  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3071  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3072  }))
3073  return failure();
3074 
3075  // Insert the TransferReadOp and TransferWriteOp at the position of the
3076  // InsertSliceOp.
3077  rewriter.setInsertionPoint(insertOp);
3078 
3079  // Generate TransferReadOp: Read entire source tensor and add high
3080  // padding.
3081  SmallVector<Value> readIndices(
3082  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
3083  auto read = rewriter.create<vector::TransferReadOp>(
3084  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
3085 
3086  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3087  // specified offsets. Write is fully in-bounds because a InsertSliceOp's
3088  // source must fit into the destination at the specified offsets.
3089  auto writeIndices = getValueOrCreateConstantIndexOp(
3090  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3091  SmallVector<bool> inBounds(vecRank, true);
3092  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3093  insertOp, read, insertOp.getDest(), writeIndices,
3094  ArrayRef<bool>{inBounds});
3095 
3096  return success();
3097  }
3098 };
3099 
3100 void mlir::linalg::populatePadOpVectorizationPatterns(
3101  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3102  patterns.add<PadOpVectorizationWithTransferReadPattern,
3103  PadOpVectorizationWithTransferWritePattern,
3104  PadOpVectorizationWithInsertSlicePattern>(
3105  patterns.getContext(), baseBenefit.getBenefit() + 1);
3106 }
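// Usage sketch (illustrative only; assumes a pass that drives these rewrites
// with a greedy driver such as applyPatternsGreedily):
// ```
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
// ```
// Registering the specialized patterns at `baseBenefit.getBenefit() + 1` makes
// the driver try them before any patterns registered at the base benefit.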
3107 
3108 //----------------------------------------------------------------------------//
3109 // Forwarding patterns
3110 //----------------------------------------------------------------------------//
3111 
3112 /// Check whether there is any interleaved use of any `values` between
3113 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3114 /// is in a different block.
3115 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3116  ValueRange values) {
3117  if (firstOp->getBlock() != secondOp->getBlock() ||
3118  !firstOp->isBeforeInBlock(secondOp)) {
3119  LDBG("interleavedUses precondition failed, firstOp: "
3120  << *firstOp << ", second op: " << *secondOp << "\n");
3121  return true;
3122  }
3123  for (auto v : values) {
3124  for (auto &u : v.getUses()) {
3125  Operation *owner = u.getOwner();
3126  if (owner == firstOp || owner == secondOp)
3127  continue;
3128  // TODO: this is too conservative, use dominance info in the future.
3129  if (owner->getBlock() == firstOp->getBlock() &&
3130  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3131  continue;
3132  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3133  << ", second op: " << *secondOp << "\n");
3134  return true;
3135  }
3136  }
3137  return false;
3138 }
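// Illustrative example (assumed IR, not from the upstream file): given
// ```
//   %sv = memref.subview %alloc[0, 0] [4, 4] [1, 1] : ...   // firstOp
//   "some.op"(%alloc) : (memref<8x8xf32>) -> ()
//   memref.copy %in, %sv : ...                               // secondOp
// ```
// mayExistInterleavedUses(firstOp, secondOp, {%alloc, %sv}) returns true,
// because %alloc has a use strictly between the two ops; without that use
// (and absent any other interleaved use) it returns false.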
3139 
3140 /// Return the unique subview use of `v` if it is indeed unique, null
3141 /// otherwise.
3142 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3143  memref::SubViewOp subViewOp;
3144  for (auto &u : v.getUses()) {
3145  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3146  if (subViewOp)
3147  return memref::SubViewOp();
3148  subViewOp = newSubViewOp;
3149  }
3150  }
3151  return subViewOp;
3152 }
3153 
3154 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3155 /// when available.
3156 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3157  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3158 
3159  // TODO: support mask.
3160  if (xferOp.getMask())
3161  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3162 
3163  // Transfer into `view`.
3164  Value viewOrAlloc = xferOp.getBase();
3165  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3166  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3167  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3168 
3169  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3170  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3171  if (!subViewOp)
3172  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3173  Value subView = subViewOp.getResult();
3174 
3175  // Find the copy into `subView` without interleaved uses.
3176  memref::CopyOp copyOp;
3177  for (auto &u : subView.getUses()) {
3178  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3179  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3180  if (newCopyOp.getTarget() != subView)
3181  continue;
3182  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3183  continue;
3184  copyOp = newCopyOp;
3185  break;
3186  }
3187  }
3188  if (!copyOp)
3189  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3190 
3191  // Find the fill into `viewOrAlloc` without interleaved uses before the
3192  // copy.
3193  FillOp maybeFillOp;
3194  for (auto &u : viewOrAlloc.getUses()) {
3195  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3196  assert(isa<MemRefType>(newFillOp.output().getType()));
3197  if (newFillOp.output() != viewOrAlloc)
3198  continue;
3199  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3200  continue;
3201  maybeFillOp = newFillOp;
3202  break;
3203  }
3204  }
3205  // Ensure padding matches.
3206  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3207  return rewriter.notifyMatchFailure(xferOp,
3208  "padding value does not match fill");
3209 
3210  // `in` is the subview that memref.copy reads. Replace it.
3211  Value in = copyOp.getSource();
3212 
3213  // memref.copy + linalg.fill can be used to create a padded local buffer.
3214  // The `masked` attribute is only valid on this padded buffer.
3215  // When forwarding to vector.transfer_read, the attribute must be reset
3216  // conservatively.
3217  auto vectorType = xferOp.getVectorType();
3218  Value res = rewriter.create<vector::TransferReadOp>(
3219  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3220  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3221  rewriter.getBoolArrayAttr(
3222  SmallVector<bool>(vectorType.getRank(), false)));
3223 
3224  if (maybeFillOp)
3225  rewriter.eraseOp(maybeFillOp);
3226  rewriter.eraseOp(copyOp);
3227  rewriter.replaceOp(xferOp, res);
3228 
3229  return success();
3230 }
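// Illustrative sketch of the forwarding (assumed IR, not from the upstream
// file): a padded local buffer built with linalg.fill + memref.copy,
// ```
//   %alloc = memref.alloc() : memref<17x5xf32>
//   linalg.fill ins(%cst : f32) outs(%alloc : memref<17x5xf32>)
//   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
//         : memref<17x5xf32> to memref<?x?xf32, strided<[5, 1]>>
//   memref.copy %in, %sv : memref<?x?xf32> to memref<?x?xf32, strided<[5, 1]>>
//   %v = vector.transfer_read %alloc[%c0, %c0], %cst
//        : memref<17x5xf32>, vector<17x5xf32>
// ```
// is forwarded to a single vector.transfer_read of %in with the same padding
// value and all in_bounds flags conservatively reset to false; the fill and
// the copy are erased.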
3231 
3232 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3233 /// when available.
3234 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3235  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3236  // TODO: support mask.
3237  if (xferOp.getMask())
3238  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3239 
3240  // Transfer into `viewOrAlloc`.
3241  Value viewOrAlloc = xferOp.getBase();
3242  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3243  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3244  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3245 
3246  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3247  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3248  if (!subViewOp)
3249  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3250  Value subView = subViewOp.getResult();
3251 
3252  // Find the copy from `subView` without interleaved uses.
3253  memref::CopyOp copyOp;
3254  for (auto &u : subViewOp.getResult().getUses()) {
3255  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3256  if (newCopyOp.getSource() != subView)
3257  continue;
3258  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3259  continue;
3260  copyOp = newCopyOp;
3261  break;
3262  }
3263  }
3264  if (!copyOp)
3265  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3266 
3267  // `out` is the subview copied into that we replace.
3268  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3269  Value out = copyOp.getTarget();
3270 
3271  // Forward vector.transfer into copy.
3272  // memref.copy + linalg.fill can be used to create a padded local buffer.
3273  // The `masked` attribute is only valid on this padded buffer.
3274  // When forwarding to vector.transfer_write, the attribute must be reset
3275  // conservatively.
3276  auto vector = xferOp.getVector();
3277  rewriter.create<vector::TransferWriteOp>(
3278  xferOp.getLoc(), vector, out, xferOp.getIndices(),
3279  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3280  rewriter.getBoolArrayAttr(SmallVector<bool>(
3281  dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3282 
3283  rewriter.eraseOp(copyOp);
3284  rewriter.eraseOp(xferOp);
3285 
3286  return success();
3287 }
3288 
3289 //===----------------------------------------------------------------------===//
3290 // Convolution vectorization patterns
3291 //===----------------------------------------------------------------------===//
3292 
3293 template <int N>
3294 static void bindShapeDims(ShapedType shapedType) {}
3295 
3296 template <int N, typename IntTy, typename... IntTy2>
3297 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3298  val = shapedType.getShape()[N];
3299  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3300 }
3301 
3302 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3303 template <typename... IntTy>
3304 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3305  bindShapeDims<0>(shapedType, vals...);
3306 }
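// Example usage (illustrative): for a ShapedType with shape {n, w, c},
// ```
//   int64_t nSize, wSize, cSize;
//   bindShapeDims(shapedType, nSize, wSize, cSize);
// ```
// binds the three variables to the leading dimensions, mirroring how bindDims
// binds AffineExprs to dimension positions.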
3307 
3308 namespace {
3309 /// Generate a vector implementation for either:
3310 /// ```
3311 /// Op def: ( w, kw )
3312 /// Iters: ({Par(), Red()})
3313 /// Layout: {{w + kw}, {kw}, {w}}
3314 /// ```
3315 /// kw is unrolled.
3316 ///
3317 /// or
3318 ///
3319 /// ```
3320 /// Op def: ( n, w, c, kw, f )
3321 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3322 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3323 /// ```
3324 /// kw is unrolled, w is unrolled iff dilationW > 1.
3325 ///
3326 /// or
3327 ///
3328 /// ```
3329 /// Op def: ( n, c, w, f, kw )
3330 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3331 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3332 /// ```
3333 /// kw is unrolled, w is unrolled iff dilationW > 1.
3334 ///
3335 /// or
3336 ///
3337 /// ```
3338 /// Op def: ( n, w, c, kw )
3339 /// Iters: ({Par(), Par(), Par(), Red()})
3340 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3341 /// ```
3342 /// kw is unrolled, w is unrolled iff dilationW > 1.
3343 struct Conv1DGenerator
3344  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3345  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3346  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3347 
3348  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3349  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3350  resShaped = linalgOp.getDpsInitOperand(0)->get();
3351  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3352  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3353  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3354 
3355  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3356  redOp = reduceOp->getName().getIdentifier();
3357 
3358  setConvOperationKind(reduceOp);
3359 
3360  auto maybeKind = getCombinerOpKind(reduceOp);
3361  reductionKind = maybeKind.value();
3362 
3363  // The ConvolutionOpInterface gives us guarantees of existence for
3364  // strides/dilations. However, we do not need to rely on those: we simply
3365  // use them if present, otherwise fall back to the defaults and let the
3366  // generic conv matcher in the Conv1DGenerator succeed or fail.
3367  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3368  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3369  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3370  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3371  }
3372 
3373  /// Generate a vector implementation for:
3374  /// ```
3375  /// Op def: ( w, kw )
3376  /// Iters: ({Par(), Red()})
3377  /// Layout: {{w + kw}, {kw}, {w}}
3378  /// ```
3379  /// kw is always unrolled.
3380  ///
3381  /// or
3382  ///
3383  /// ```
3384  /// Op def: ( n, w, c, kw, f )
3385  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3386  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3387  /// ```
3388  /// kw is always unrolled.
3389  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3390  /// > 1.
3391  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3392  int64_t nSize, wSize, cSize, kwSize, fSize;
3393  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3394  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3395  switch (conv1DOpOrder) {
3396  case Conv1DOpOrder::W:
3397  // Initialize unused dimensions
3398  nSize = fSize = cSize = 0;
3399  // out{W}
3400  bindShapeDims(resShapedType, wSize);
3401  // kernel{kw}
3402  bindShapeDims(rhsShapedType, kwSize);
3403  lhsShape = {// iw = ow + kw - 1
3404  // (i.e. 16 convolved with 3 -> 14)
3405  (wSize + kwSize - 1)};
3406  rhsShape = {kwSize};
3407  resShape = {wSize};
3408  break;
3409  case Conv1DOpOrder::Nwc:
3410  // out{n, w, f}
3411  bindShapeDims(resShapedType, nSize, wSize, fSize);
3412  switch (oper) {
3413  case ConvOperationKind::Conv:
3414  // kernel{kw, c, f}
3415  bindShapeDims(rhsShapedType, kwSize, cSize);
3416  break;
3417  case ConvOperationKind::Pool:
3418  // kernel{kw}
3419  bindShapeDims(rhsShapedType, kwSize);
3420  cSize = fSize;
3421  break;
3422  }
3423  lhsShape = {nSize,
3424  // iw = ow * sw + kw * dw - 1
3425  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3426  // Perform the proper inclusive -> exclusive -> inclusive.
3427  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3428  1,
3429  cSize};
3430  switch (oper) {
3431  case ConvOperationKind::Conv:
3432  rhsShape = {kwSize, cSize, fSize};
3433  break;
3434  case ConvOperationKind::Pool:
3435  rhsShape = {kwSize};
3436  break;
3437  }
3438  resShape = {nSize, wSize, fSize};
3439  break;
3440  case Conv1DOpOrder::Ncw:
3441  // out{n, f, w}
3442  bindShapeDims(resShapedType, nSize, fSize, wSize);
3443  switch (oper) {
3444  case ConvOperationKind::Conv:
3445  // kernel{f, c, kw}
3446  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3447  break;
3448  case ConvOperationKind::Pool:
3449  // kernel{kw}
3450  bindShapeDims(rhsShapedType, kwSize);
3451  cSize = fSize;
3452  break;
3453  }
3454  lhsShape = {nSize, cSize,
3455  // iw = ow * sw + kw * dw - 1
3456  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3457  // Perform the proper inclusive -> exclusive -> inclusive.
3458  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3459  1};
3460  switch (oper) {
3461  case ConvOperationKind::Conv:
3462  rhsShape = {fSize, cSize, kwSize};
3463  break;
3464  case ConvOperationKind::Pool:
3465  rhsShape = {kwSize};
3466  break;
3467  }
3468  resShape = {nSize, fSize, wSize};
3469  break;
3470  }
3471 
3472  vector::TransferWriteOp write;
3473  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3474 
3475  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3476  // When strideW == 1, we can batch the contiguous loads and avoid
3477  // unrolling.
3478  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3479 
3480  Type lhsEltType = lhsShapedType.getElementType();
3481  Type rhsEltType = rhsShapedType.getElementType();
3482  Type resEltType = resShapedType.getElementType();
3483  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3484  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3485  auto resType = VectorType::get(resShape, resEltType);
3486  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3487  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3488  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3489  SmallVector<Value> resPadding(resShape.size(), zero);
3490 
3491  // Read the whole lhs, rhs and res in one shot (with zero padding).
3492  Value lhs = rewriter.create<vector::TransferReadOp>(
3493  loc, lhsType, lhsShaped, lhsPadding,
3494  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3495  // This is needed only for Conv.
3496  Value rhs = nullptr;
3497  if (oper == ConvOperationKind::Conv)
3498  rhs = rewriter.create<vector::TransferReadOp>(
3499  loc, rhsType, rhsShaped, rhsPadding,
3500  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3501  Value res = rewriter.create<vector::TransferReadOp>(
3502  loc, resType, resShaped, resPadding,
3503  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3504 
3505  // The base vectorization case for channeled convolution is input:
3506  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse that base case,
3507  // we pre-transpose the input, weight, and output when needed.
3508  switch (conv1DOpOrder) {
3509  case Conv1DOpOrder::W:
3510  case Conv1DOpOrder::Nwc:
3511  // Base case, so no transposes necessary.
3512  break;
3513  case Conv1DOpOrder::Ncw: {
3514  // To match base vectorization case, we pre-transpose current case.
3515  // ncw -> nwc
3516  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3517  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3518  // fcw -> wcf
3519  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3520 
3521  // This is needed only for Conv.
3522  if (oper == ConvOperationKind::Conv)
3523  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3524  // nfw -> nwf
3525  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3526  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3527  break;
3528  }
3529  }
3530 
3531  //===------------------------------------------------------------------===//
3532  // Begin vector-only rewrite part
3533  //===------------------------------------------------------------------===//
3534  // Unroll along kw and read slices of lhs and rhs.
3535  SmallVector<Value> lhsVals, rhsVals, resVals;
3536  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3537  kwSize, strideW, dilationW, wSizeStep,
3538  isSingleChanneled);
3539  // Do not do for pooling.
3540  if (oper == ConvOperationKind::Conv)
3541  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3542  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3543  wSizeStep, isSingleChanneled);
3544 
3545  auto linearIndex = [&](int64_t kw, int64_t w) {
3546  return kw * (wSize / wSizeStep) + w;
3547  };
3548 
3549  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3550  // or perform an outer product for non-channeled convolution, or a simple
3551  // arith operation for pooling.
3552  for (int64_t kw = 0; kw < kwSize; ++kw) {
3553  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3554  switch (oper) {
3555  case ConvOperationKind::Conv:
3556  if (isSingleChanneled) {
3557  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3558  lhsVals[linearIndex(kw, w)],
3559  rhsVals[kw], resVals[w]);
3560  } else {
3561  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3562  lhsVals[linearIndex(kw, w)],
3563  rhsVals[kw], resVals[w]);
3564  }
3565  break;
3566  case ConvOperationKind::Pool:
3567  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3568  resVals[w]);
3569  break;
3570  }
3571  }
3572  }
3573 
3574  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3575  isSingleChanneled);
3576  //===------------------------------------------------------------------===//
3577  // End vector-only rewrite part
3578  //===------------------------------------------------------------------===//
3579 
3580  // The base vectorization case for channeled convolution is output:
3581  // {n,w,f}. To reuse the result from the base case, we post-transpose
3582  // the base case result.
3583  switch (conv1DOpOrder) {
3584  case Conv1DOpOrder::W:
3585  case Conv1DOpOrder::Nwc:
3586  // Base case, so no transposes necessary.
3587  break;
3588  case Conv1DOpOrder::Ncw: {
3589  // nwf -> nfw
3590  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3591  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3592  break;
3593  }
3594  }
3595 
3596  return rewriter
3597  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3598  .getOperation();
3599  }
3600 
3601  // Take a value and widen it to have the same element type as `ty`.
3602  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3603  const Type srcElementType = getElementTypeOrSelf(val.getType());
3604  const Type dstElementType = getElementTypeOrSelf(ty);
3605  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3606  if (srcElementType == dstElementType)
3607  return val;
3608 
3609  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3610  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3611  const Type dstType =
3612  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3613 
3614  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3615  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3616  }
3617 
3618  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3619  srcWidth < dstWidth)
3620  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3621 
3622  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3623  srcWidth < dstWidth)
3624  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3625 
3626  assert(false && "unhandled promotion case");
3627  return nullptr;
3628  }
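// Example (illustrative): promoting an i8 operand to match an f32 accumulator,
// e.g. promote(rewriter, loc, %v : vector<4x3xi8>, vector<4x3xf32>), emits
// arith.sitofp to vector<4x3xf32>; integer-to-wider-integer promotion uses
// arith.extsi and float widening uses arith.extf.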
3629 
3630  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3631  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3632  Value lhs, Value rhs, Value res) {
3633  vector::IteratorType par = vector::IteratorType::parallel;
3634  vector::IteratorType red = vector::IteratorType::reduction;
3635  AffineExpr n, w, f, c;
3636  bindDims(ctx, n, w, f, c);
3637  lhs = promote(rewriter, loc, lhs, res.getType());
3638  rhs = promote(rewriter, loc, rhs, res.getType());
3639  auto contractionOp = rewriter.create<vector::ContractionOp>(
3640  loc, lhs, rhs, res,
3641  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3642  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3643  contractionOp.setKind(reductionKind);
3644  return contractionOp;
3645  }
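// Illustrative result (assumed shapes, not from the upstream file): for
// lhs : vector<1x3x4xf32>, rhs : vector<4x8xf32>, res : vector<1x3x8xf32>,
// the generated op is roughly
// ```
//   vector.contract {
//     indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
//                      affine_map<(n, w, f, c) -> (c, f)>,
//                      affine_map<(n, w, f, c) -> (n, w, f)>],
//     iterator_types = ["parallel", "parallel", "parallel", "reduction"],
//     kind = #vector.kind<add>
//   } %lhs, %rhs, %res
//   : vector<1x3x4xf32>, vector<4x8xf32> into vector<1x3x8xf32>
// ```
// with `kind` taken from the combiner matched on the linalg op (add here).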
3646 
3647  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3648  // convolution.
3649  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3650  Value lhs, Value rhs, Value res) {
3651  return rewriter.create<vector::OuterProductOp>(
3652  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3653  }
3654 
3655  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3656  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3657  Value res) {
3658  if (isPoolExt)
3659  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3660  return rewriter
3661  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3662  ->getResult(0);
3663  }
3664 
3665  /// Generate a vector implementation for:
3666  /// ```
3667  /// Op def: ( n, w, c, kw)
3668  /// Iters: ({Par(), Par(), Par(), Red()})
3669  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3670  /// ```
3671  /// kw is always unrolled.
3672  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3673  /// > 1.
3674  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3675  bool channelDimScalableFlag,
3676  bool flatten) {
3677  bool scalableChDim = false;
3678  bool useMasking = false;
3679  int64_t nSize, wSize, cSize, kwSize;
3680  // kernel{kw, c}
3681  bindShapeDims(rhsShapedType, kwSize, cSize);
3682  if (ShapedType::isDynamic(cSize)) {
3683  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3684  cSize = channelDimVecSize;
3685  // Scalable vectors are only used when both conditions are met:
3686  // 1. channel dim is dynamic
3687  // 2. channelDimScalableFlag is set
3688  scalableChDim = channelDimScalableFlag;
3689  useMasking = true;
3690  }
3691 
3692  assert(!(useMasking && flatten) &&
3693  "Unsupported flattened conv with dynamic shapes");
3694 
3695  // out{n, w, c}
3696  bindShapeDims(resShapedType, nSize, wSize);
3697 
3698  vector::TransferWriteOp write;
3699  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3700 
3701  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3702  // When strideW == 1, we can batch the contiguous loads and avoid
3703  // unrolling.
3704  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3705 
3706  Type lhsEltType = lhsShapedType.getElementType();
3707  Type rhsEltType = rhsShapedType.getElementType();
3708  Type resEltType = resShapedType.getElementType();
3709  VectorType lhsType = VectorType::get(
3710  {nSize,
3711  // iw = ow * sw + kw * dw - 1
3712  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3713  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3714  cSize},
3715  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3716  VectorType rhsType =
3717  VectorType::get({kwSize, cSize}, rhsEltType,
3718  /*scalableDims=*/{false, scalableChDim});
3719  VectorType resType =
3720  VectorType::get({nSize, wSize, cSize}, resEltType,
3721  /*scalableDims=*/{false, false, scalableChDim});
3722 
3723  // Masks the input xfer Op along the channel dim, iff the corresponding
3724  // scalable flag is set.
3725  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3726  ArrayRef<bool> scalableDims,
3727  Operation *opToMask) {
3728  if (!useMasking)
3729  return opToMask;
3730  auto maskType =
3731  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3732 
3733  SmallVector<bool> inBounds(maskShape.size(), true);
3734  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3735  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3736  rewriter.getBoolArrayAttr(inBounds));
3737 
3737 
3738  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3739  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3740 
3741  Value maskOp =
3742  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3743 
3744  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3745  };
3746 
3747  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3748  // 0].
3749  Value lhs = rewriter.create<vector::TransferReadOp>(
3750  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3751  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3752  auto maybeMaskedLhs = maybeMaskXferOp(
3753  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3754 
3755  // Read rhs slice of size {kw, c} @ [0, 0].
3756  Value rhs = rewriter.create<vector::TransferReadOp>(
3757  loc, rhsType, rhsShaped, ValueRange{zero, zero},
3758  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3759  auto maybeMaskedRhs = maybeMaskXferOp(
3760  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3761 
3762  // Read res slice of size {n, w, c} @ [0, 0, 0].
3763  Value res = rewriter.create<vector::TransferReadOp>(
3764  loc, resType, resShaped, ValueRange{zero, zero, zero},
3765  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3766  auto maybeMaskedRes = maybeMaskXferOp(
3767  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3768 
3769  //===------------------------------------------------------------------===//
3770  // Begin vector-only rewrite part
3771  //===------------------------------------------------------------------===//
3772  // Unroll along kw and read slices of lhs and rhs.
3773  SmallVector<Value> lhsVals, rhsVals, resVals;
3774  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3775  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3776 
3777  // Extract lhs slice of size {n, wSizeStep, c}
3778  // @ [0, sw * w + dw * kw, 0].
3779  for (int64_t kw = 0; kw < kwSize; ++kw) {
3780  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3781  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3782  loc, maybeMaskedLhs->getResult(0),
3783  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3784  inOutSliceSizes, inOutStrides));
3785  }
3786  }
3787  // Extract rhs slice of size {c} @ [kw].
3788  for (int64_t kw = 0; kw < kwSize; ++kw) {
3789  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3790  loc, maybeMaskedRhs->getResult(0),
3791  /*offsets=*/ArrayRef<int64_t>{kw}));
3792  }
3793  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3794  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3795  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3796  loc, maybeMaskedRes->getResult(0),
3797  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3798  inOutStrides));
3799  }
3800 
3801  auto linearIndex = [&](int64_t kw, int64_t w) {
3802  return kw * (wSize / wSizeStep) + w;
3803  };
3804 
3805  // Note - the scalable flags are ignored as flattening combined with
3806  // scalable vectorization is not supported.
3807  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3808  auto lhsTypeAfterFlattening =
3809  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3810  auto resTypeAfterFlattening =
3811  VectorType::get(inOutFlattenSliceSizes, resEltType);
3812 
3813  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3814  for (int64_t kw = 0; kw < kwSize; ++kw) {
3815  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3816  Value lhsVal = lhsVals[linearIndex(kw, w)];
3817  Value resVal = resVals[w];
3818  if (flatten) {
3819  // Flatten the input and output vectors (collapse the channel
3820  // dimension)
3821  lhsVal = rewriter.create<vector::ShapeCastOp>(
3822  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3823  resVal = rewriter.create<vector::ShapeCastOp>(
3824  loc, resTypeAfterFlattening, resVals[w]);
3825  }
3826  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3827  rhsVals[kw], resVal, flatten);
3828  if (flatten) {
3829  // Un-flatten the output vector (restore the channel dimension)
3830  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3831  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3832  }
3833  }
3834  }
3835 
3836  // It's possible we failed to create the FMA.
3837  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3838  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3839  for (auto &collection :
3840  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3841  for (Value v : collection)
3842  rewriter.eraseOp(v.getDefiningOp());
3843  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3844  }
3845 
3846  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3847  // This does not depend on kw.
3848  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3849  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3850  loc, resVals[w], maybeMaskedRes->getResult(0),
3851  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3852  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3853  }
3854  //===------------------------------------------------------------------===//
3855  // End vector-only rewrite part
3856  //===------------------------------------------------------------------===//
3857 
3858  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3859  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3860  loc, maybeMaskedRes->getResult(0), resShaped,
3861  ValueRange{zero, zero, zero});
3862  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3863  resOut);
3864  }
3865 
3866  /// Lower:
3867  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3868  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3869  /// to MulAcc.
3870  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3871  Value lhs, Value rhs, Value res,
3872  bool flatten) {
3873  auto rhsTy = cast<ShapedType>(rhs.getType());
3874  auto resTy = cast<ShapedType>(res.getType());
3875 
3876  // TODO(suderman): Change this to use a vector.ima intrinsic.
3877  lhs = promote(rewriter, loc, lhs, resTy);
3878 
3879  if (flatten) {
3880  // NOTE: The following logic won't work for scalable vectors. For this
3881  // reason, "flattening" is not supported when shapes are dynamic (this
3882  // should be captured by one of the pre-conditions).
3883 
3884  // There are two options for handling the filter:
3885  // * shape_cast(broadcast(filter))
3886  // * broadcast(shuffle(filter))
3887  // Opt for the option without shape_cast to simplify the codegen.
3888  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3889  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3890 
3891  SmallVector<int64_t, 16> indices;
3892  for (int i = 0; i < resSize / rhsSize; ++i) {
3893  for (int j = 0; j < rhsSize; ++j)
3894  indices.push_back(j);
3895  }
3896 
3897  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3898  }
3899  // Broadcast the filter to match the output vector
3900  rhs = rewriter.create<vector::BroadcastOp>(
3901  loc, resTy.clone(rhsTy.getElementType()), rhs);
3902 
3903  rhs = promote(rewriter, loc, rhs, resTy);
3904 
3905  if (!lhs || !rhs)
3906  return nullptr;
3907 
3908  if (isa<FloatType>(resTy.getElementType()))
3909  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3910 
3911  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3912  return rewriter.create<arith::AddIOp>(loc, mul, res);
3913  }
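// Illustrative lowering (assumed f32 operands, flatten = false): for
// lhs : vector<1x4x8xf32>, rhs : vector<8xf32>, res : vector<1x4x8xf32>,
// this produces roughly
// ```
//   %b = vector.broadcast %rhs : vector<8xf32> to vector<1x4x8xf32>
//   %r = vector.fma %lhs, %b, %res : vector<1x4x8xf32>
// ```
// while integer element types lower to arith.muli followed by arith.addi.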
3914 
3915  /// Entry point for non-channeled convolution:
3916  /// {{w + kw}, {kw}, {w}}
3917  FailureOr<Operation *> generateNonChanneledConv() {
3918  AffineExpr w, kw;
3919  bindDims(ctx, w, kw);
3920  if (!iters({Par(), Red()}))
3921  return rewriter.notifyMatchFailure(op,
3922  "failed to match conv::W 1-par 1-red");
3923 
3924  // No transposition needed.
3925  if (layout({/*lhsIndex*/ {w + kw},
3926  /*rhsIndex*/ {kw},
3927  /*resIndex*/ {w}}))
3928  return conv(Conv1DOpOrder::W);
3929 
3930  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3931  }
3932 
3933  /// Entry point that transposes into the common form:
3934  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3935  FailureOr<Operation *> generateNwcConv() {
3936  AffineExpr n, w, f, kw, c;
3937  bindDims(ctx, n, w, f, kw, c);
3938  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3939  return rewriter.notifyMatchFailure(
3940  op, "failed to match conv::Nwc 3-par 2-red");
3941 
3942  // No transposition needed.
3943  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3944  /*rhsIndex*/ {kw, c, f},
3945  /*resIndex*/ {n, w, f}}))
3946  return conv(Conv1DOpOrder::Nwc);
3947 
3948  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3949  }
3950 
3951  /// Entry point that transposes into the common form:
3952  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3953  FailureOr<Operation *> generateNcwConv() {
3954  AffineExpr n, w, f, kw, c;
3955  bindDims(ctx, n, f, w, c, kw);
3956  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3957  return rewriter.notifyMatchFailure(
3958  op, "failed to match conv::Ncw 3-par 2-red");
3959 
3960  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3961  /*rhsIndex*/ {f, c, kw},
3962  /*resIndex*/ {n, f, w}}))
3963  return conv(Conv1DOpOrder::Ncw);
3964 
3965  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3966  }
3967 
3968  /// Entry point that transposes into the common form:
3969  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3970  FailureOr<Operation *> generateNwcPooling() {
3971  AffineExpr n, w, c, kw;
3972  bindDims(ctx, n, w, c, kw);
3973  if (!iters({Par(), Par(), Par(), Red()}))
3974  return rewriter.notifyMatchFailure(op,
3975  "failed to match pooling 3-par 1-red");
3976 
3977  // No transposition needed.
3978  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3979  /*rhsIndex*/ {kw},
3980  /*resIndex*/ {n, w, c}}))
3981  return conv(Conv1DOpOrder::Nwc);
3982 
3983  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3984  }
3985 
3986  /// Entry point that transposes into the common form:
3987  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3988  FailureOr<Operation *> generateNcwPooling() {
3989  AffineExpr n, w, c, kw;
3990  bindDims(ctx, n, c, w, kw);
3991  if (!iters({Par(), Par(), Par(), Red()}))
3992  return rewriter.notifyMatchFailure(op,
3993  "failed to match pooling 3-par 1-red");
3994 
3995  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3996  /*rhsIndex*/ {kw},
3997  /*resIndex*/ {n, c, w}}))
3998  return conv(Conv1DOpOrder::Ncw);
3999 
4000  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4001  }
4002 
4003  /// Entry point that transposes into the common form:
4004  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4005  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4006  bool vecChDimScalableFlag = false,
4007  bool flatten = false) {
4008  AffineExpr n, w, c, kw;
4009  bindDims(ctx, n, w, c, kw);
4010  if (!iters({Par(), Par(), Par(), Red()}))
4011  return rewriter.notifyMatchFailure(
4012  op, "failed to match depthwise::Nwc conv 3-par 1-red");
4013 
4014  // No transposition needed.
4015  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4016  /*rhsIndex*/ {kw, c},
4017  /*resIndex*/ {n, w, c}}))
4018  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4019 
4020  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4021  }
4022 
4023 private:
4024  ConvOperationKind oper = ConvOperationKind::Conv;
4025  StringAttr redOp;
4026  StringAttr poolExtOp;
4027  bool isPoolExt = false;
4028  int strideW, dilationW;
4029  Value lhsShaped, rhsShaped, resShaped;
4030  ShapedType lhsShapedType, rhsShapedType, resShapedType;
4031  vector::CombiningKind reductionKind;
4032 
4033  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4034  void setConvOperationKind(Operation *reduceOp) {
4035  int numBlockArguments =
4036  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4037  if (numBlockArguments == 1) {
4038  // Will be convolution if feeder is a MulOp.
4039  // A strength reduced version of MulOp for i1 type is AndOp which is also
4040  // supported. Otherwise, it can be pooling. This strength reduction logic
4041  // is in `buildBinaryFn` helper in the Linalg dialect.
4042  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4043  llvm::IsaPred<BlockArgument>);
4044  Operation *feedOp = (*feedValIt).getDefiningOp();
4045  if (isCastOfBlockArgument(feedOp)) {
4046  oper = ConvOperationKind::Pool;
4047  isPoolExt = true;
4048  poolExtOp = feedOp->getName().getIdentifier();
4049  return;
4050  }
4051  oper = ConvOperationKind::Conv;
4052  return;
4053  }
4054  // numBlockArguments == 2 and this is a pooling op.
4055  oper = ConvOperationKind::Pool;
4056  isPoolExt = false;
4057  }
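// For example (illustrative): a combiner of the form add(%acc, mul(%in, %flt))
// has a single block-argument operand and a MulOp feeder, so it is classified
// as Conv; add(%acc, ext(%in)) has a cast-of-block-argument feeder and becomes
// Pool with isPoolExt = true; max(%acc, %in) has two block-argument operands
// and becomes a plain Pool.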
4058 };
4059 } // namespace
4060 
4061 /// Helper function to vectorize a LinalgOp with convolution semantics.
4062 // TODO: extend the generic vectorization to support windows and drop this.
4063 static FailureOr<Operation *> vectorizeConvolution(
4064  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4065  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4066  Conv1DGenerator conv1dGen(rewriter, op);
4067  auto res = conv1dGen.generateNonChanneledConv();
4068  if (succeeded(res))
4069  return res;
4070  res = conv1dGen.generateNwcConv();
4071  if (succeeded(res))
4072  return res;
4073  res = conv1dGen.generateNcwConv();
4074  if (succeeded(res))
4075  return res;
4076  res = conv1dGen.generateNwcPooling();
4077  if (succeeded(res))
4078  return res;
4079  res = conv1dGen.generateNcwPooling();
4080  if (succeeded(res))
4081  return res;
4082 
4083  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4084  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4085  // masked/scalable) is the channel dim (i.e. the trailing dim).
4086  uint64_t vecChDimSize = ShapedType::kDynamic;
4087  bool vecChDimScalableFlag = false;
4088  if (!inputVecSizes.empty()) {
4089  // Only use the input vector size corresponding to the channel dim. Other
4090  // vector dims will be inferred from the Ops.
4091  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4092  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4093  "Not a 1D depthwise conv!");
4094  size_t chDimIdx =
4095  TypeSwitch<Operation *, size_t>(op)
4096  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4097  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4098 
4099  vecChDimSize = inputVecSizes[chDimIdx];
4100  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4101  }
4102  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4103  flatten1DDepthwiseConv);
4104 }
4105 
4106 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4107  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4108 
4109  LogicalResult matchAndRewrite(LinalgOp op,
4110  PatternRewriter &rewriter) const override {
4111  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4112  if (failed(resultOrFail))
4113  return failure();
4114  Operation *newOp = *resultOrFail;
4115  if (newOp->getNumResults() == 0) {
4116  rewriter.eraseOp(op.getOperation());
4117  return success();
4118  }
4119  assert(newOp->getNumResults() == 1 && "expected single result");
4120  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4121  return success();
4122  }
4123 };
4124 
4125 void mlir::linalg::populateConvolutionVectorizationPatterns(
4126  RewritePatternSet &patterns, PatternBenefit benefit) {
4127  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4128 }
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:131
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:138
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:600
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:157
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:641
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:265
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:665
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:97
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:219
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition: ArithOps.cpp:301
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:203
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:242
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
enum WinogradConv2DFmr uint32_t std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:651
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2799
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:808
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:715
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:332
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Transformation information returned after vectorizing.
Definition: Transforms.h:868
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.