//===- Vectorization.cpp - Implementation of linalg Vectorization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);

/// Vectorize tensor::InsertSliceOp with:
///   * vector::TransferReadOp + vector::TransferWriteOp
/// The vector sizes are either:
///   * user-provided in `inputVectorSizes`, or
///   * inferred from the static dims in the input and output tensors.
/// Bails out if:
///   * vector sizes are not user-provided, and
///   * at least one dim is dynamic (in both the input and output tensors).
///
/// Before:
///     !t_in_type = tensor<1x2x3xf32>
///     !t_out_type = tensor<9x8x7x1x2x3xf32>
///     !v_type = vector<1x2x3xf32>
///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
///     into !t_out_type
/// After:
///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults);

/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty value
/// otherwise.
static Value getStaticPadVal(Operation *op);

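// Purely illustrative sketch (hypothetical IR, assumed shapes): one pad-like
// case is a `tensor.pad` whose region yields a static scalar; the yielded
// `%cst` below is the kind of effective pad value this helper is meant to
// recover.
//
//   %cst = arith.constant 0.0 : f32
//   %padded = tensor.pad %src low[0, 1] high[0, 1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x6xf32>
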
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Helper function to extract the input slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}

/// Helper function to extract the result slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
    }
  }
  return result;
}

/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {

  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
    // This does not depend on kw.
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution. This does not depend on kw.
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          strides);
    }
  }
  return res;
}

/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector shape used to vectorize the iteration space.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

  /// Returns the vector dimensions that are scalable in the canonical vector
  /// shape.
  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
  /// accordingly.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }

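  // Purely illustrative sketch (assumed values): with a canonical vector shape
  // of (8, 16, 4) where only the middle dim is scalable, and a permutation map
  // (d0, d1, d2) -> (d2, d1), a call such as
  //   getCanonicalVecType(f32Type, permMap)
  // would produce vector<4x[16]xf32>; without a permutation it would produce
  // vector<8x[16]x4xf32>.
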
  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if masking
  /// is not needed. If provided, the canonical mask for this operation is
  /// permuted using `maybeIndexingMap`.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information. This may become more complicated in the future.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
  /// all the static and dynamic dimensions of the iteration space to be
  /// vectorized and store them in `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
  /// mask is permuted using that permutation map. If a new mask is created,
  /// it will be cached for future users.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Check whether this permutation map can be used for masking. At the
  /// moment we only make sure that there are no broadcast dimensions, but this
  /// might change if indexing maps evolve.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.getBroadcastDims().size() == 0;
  }

  /// Turn the input indexing map into a valid masking map.
  ///
  /// The input indexing map may contain "zero" results, e.g.:
  ///    (d0, d1, d2, d3) -> (d2, d1, d0, 0)
  /// Applying such maps to canonical vector shapes like this one:
  ///    (1, 16, 16, 4)
  /// would yield an invalid vector shape like this:
  ///    (16, 16, 1, 0)
  /// Instead, drop the broadcasting dims that make no sense for masking perm.
  /// maps:
  ///    (d0, d1, d2, d3) -> (d2, d1, d0)
  /// This way, the corresponding vector/mask type will be:
  ///    vector<16x16x1xty>
  /// rather than this invalid Vector type:
  ///    vector<16x16x1x0xty>
  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }

  // Holds the compile-time static sizes of the iteration space to vectorize.
  // Dynamic dimensions are represented using ShapedType::kDynamic.
  SmallVector<int64_t> iterSpaceStaticSizes;

  /// Holds the value sizes of the iteration space to vectorize. Static
  /// dimensions are represented by 'arith.constant' and dynamic
  /// dimensions by 'tensor/memref.dim'.
  SmallVector<Value> iterSpaceValueSizes;

  /// Holds the canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;

  /// Holds the vector dimensions that are scalable in the canonical vector
  /// shape.
  SmallVector<bool> scalableVecDims;

  /// Holds the active masks for permutations of the canonical vector iteration
  /// space.
  DenseMap<AffineMap, Value> activeMaskCache;

  /// Global vectorization guard for the incoming rewriter. It's initialized
  /// when the vectorization state is initialized.
  OpBuilder::InsertionGuard rewriterGuard;
};

LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}

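// Purely illustrative sketch (hypothetical operand, assumed iteration space of
// (4, ?)): the precomputed value sizes would be built from IR along these
// lines, i.e. a constant for the static dim and a dim op for the dynamic one:
//
//   %c4 = arith.constant 4 : index
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %operand, %c1 : tensor<4x?xf32>
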
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided. This
    // path should be taken to vectorize code with dynamic shapes and when using
    // vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there are
    // dynamic shapes, the operation won't be vectorized. We assume all the
    // vector dimensions are fixed.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
  // all the static and dynamic dimensions of the iteration space, needed to
  // compute a mask during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}

/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
/// is permuted using that permutation map. If a new mask is created, it will be
/// cached for future users.
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  // TODO: Improve this check. Only projected permutation indexing maps are
  // supported.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}

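// Purely illustrative sketch (hypothetical operand, assumed canonical shape
// (8, 16) with an identity masking map over a dynamically shaped op): the
// cached mask would be created roughly as:
//
//   %d0 = tensor.dim %operand, %c0 : tensor<?x?xf32>
//   %d1 = tensor.dim %operand, %c1 : tensor<?x?xf32>
//   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
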
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}

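// Purely illustrative sketch (hypothetical operands): after masking, a
// transfer_read produced during vectorization ends up wrapped in a
// `vector.mask` region along these lines:
//
//   %0 = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<8x16xf32>
//   } : vector<8x16xi1> -> vector<8x16xf32>
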
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %0 = linalg.generic {
///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
///      }
///     ins(%0 : tensor<2x3x4xf32>)
///    outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}

/// Helper enum to represent conv1d input traversal order.
enum class Conv1DOpOrder {
  W,   // Corresponds to non-channeled 1D convolution operation.
  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
  Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
};

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate or
/// replace the results.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}

/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction. Return
/// nullptr otherwise. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X))
// TODO: use in LinalgOp verification, there is a circular dependency atm.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

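// Purely illustrative sketch (hypothetical op, details such as indexing maps
// elided): a reduction with a single combiner that this helper matches looks
// roughly like:
//
//   linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//       ins(%in : tensor<4x8xf32>) outs(%acc : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %s = arith.addf %a, %b : f32   // single combiner -> returned
//     linalg.yield %s : f32
//   }
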
/// Broadcast `value` to a vector of type `dstType` if possible. Return value
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

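// Purely illustrative sketch (assumed shapes): reducing the trailing dimension
// of a vector<4x8xf32> into an accumulator of vector<4xf32> with an `add`
// combiner corresponds to:
//
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//       : vector<4x8xf32> to vector<4xf32>
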
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build a memref.store.
/// Return the produced value or null if no value is produced.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any permutation
  // applied, given that any permutation in a transfer write happens as part of
  // the write itself.
  auto vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in-bounds to true. Masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook. Returns success if the corresponding custom
// hook can vectorize the op.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
    std::function<VectorizationResult(Operation *, const IRMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector making them available to the vectorization
/// algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  // Compute a one-dimensional index vector for the index op dimension.
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space since the vectorization algorithm in this
  // case can handle the broadcast.
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}

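// Purely illustrative sketch (assumed 2-D iteration space of shape (4, 8)):
// vectorizing `linalg.index 0` (a non-trailing dim) yields roughly:
//
//   %step  = vector.step : vector<4xindex>
//   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
//   %idx   = vector.transpose %bcast, [1, 0]
//       : vector<8x4xindex> to vector<4x8xindex>
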
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors (for which we do need
  // access indices).
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}

/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
///
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
///    offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 = 1233
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
///
/// Note that when calling this hook, it is assumed that the output vector is
/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
/// labelled as a gather load before entering this method.
///
/// Following on from the above, it is assumed that:
///  * for statically shaped loops, when no masks are used, only one dim is !=
///    1 (that's what the shape of the output vector is based on).
///  * for dynamically shaped loops, there might be more non-unit dims
///    as the output vector type is user-specified.
///
/// TODO: Statically shaped loops + vector masking
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}

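// Purely illustrative example (assumed values): for static loop ranges
// (1, 8, 1) the trailing non-unit dim is at index 1, so this helper returns 1;
// for (1, 1, 4) it returns 2.
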
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations. Note that for dynamic shapes, the corresponding dim will also
  // be conservatively treated as != 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}

/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
///
/// There are currently 3 types of values that are allowed here:
///   1. loop-invariant values,
///   2. values that increment by 1 with every loop iteration,
///   3. results of basic arithmetic operations (linear and continuous)
///      involving 1., 2. and 3.
/// This method returns True if indeed only such values are used in calculating
/// `val.`
///
/// Additionally, the trailing index for a contiguous load operation should
/// increment by 1 with every loop iteration, i.e. be based on:
///   * `linalg.index <dim>` ,
/// where <dim> is the trailing non-unit dim of the iteration space (this way,
/// `linalg.index <dim>` increments by 1 with every loop iteration).
/// `foundIndexOp` is updated to `true` when such Op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);

    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}

/// Infer the memory access pattern for the input ExtractOp
///
/// Based on the ExtractOp result shape and the access indices, decides whether
/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
/// or a gather load. When analysing the ExtractOp indices (to identify
/// contiguous loads), this method looks for "loop" invariant indices (e.g.
/// block arguments) and indices that change linearly (e.g. via `linalg.index`
/// Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`.
  // Look at the way each index is calculated and decide whether it is suitable
  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
  // gather load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index for `extractOp`.
  // At this point we know that the leading indices are loop invariant. This
  // means that it is potentially a scalar or a contiguous load. We can decide
  // based on the trailing idx.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load
  // If the trailing index is loop invariant then this is a scalar load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);

    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads
  // The trailing `extractOp` index should increment with every loop iteration.
  // This effectively means that it must be based on the trailing loop index.
  // This is what the following bool captures.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read Ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}

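// Purely illustrative sketch (hypothetical generic body, assumed shapes):
// inside a linalg.generic, an extract whose trailing index is the trailing
// loop index is classified as a contiguous load, e.g.
//
//   %i = linalg.index 1 : index
//   %v = tensor.extract %src[%c0, %i] : tensor<4x8xf32>
//
// whereas indexing with a value that varies non-linearly per iteration (e.g.
// an index loaded from another tensor) is classified as a gather load.
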
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle:
  //  a. scalar loads + broadcast,
  //  b. contiguous loads.
  // Both cases use vector.transfer_read.

  // Collect indices for `vector.transfer_read`. At this point, the indices will
  // either be scalars or would have been broadcast to vectors matching the
  // result type. For indices that are vectors, there are two options:
  //    * for non-trailing indices, all elements are identical (contiguous
  //      loads are identified by looking for non-trailing indices that are
  //      invariant with respect to the corresponding linalg.generic), or
  //    * for trailing indices, the index vector will contain values with stride
  //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
  //      needed.
  // This means that
  //    * for scalar indices - just re-use it,
  //    * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
  //      (0th) element and use that.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the generic
    // path (the generic path assumes identity masking map, which wouldn't be
    // valid here).
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  // For example, for dstRank = 3, srcRank = 2, the following map created
  // above:
  //    (d0, d1) --> (d0, d1)
  // is extended as:
  //    (d0, d1) --> (0, d0, d1)
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}

/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///      result on success.
///   2. Clone any constant in the current scope without vectorization: each
///      consumer of the constant will later determine the shape to which the
///      constant needs to be broadcast to.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///      of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///      operand of maximal rank. Other operands have smaller rank and are
///      broadcast accordingly. It is assumed this broadcast is always legal,
///      otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each op if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise, the result is the vector with the firstMaxRankedShape
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}

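// Purely illustrative sketch (hypothetical op): an elementwise
// `arith.addf %a, %b : f32` in the linalg body, whose operands are already
// mapped to vector<4x8xf32> values in `bvm`, is cloned along these lines:
//
//   %r = arith.addf %va, %vb : vector<4x8xf32>
//
// i.e. the same op name and attributes, with vectorized operand and result
// types.
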
1365 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1366 /// vector form. Generic vectorization proceeds as follows:
1367 /// 1. Verify the `linalgOp` has one non-empty region.
1368 /// 2. Values defined above the region are mapped to themselves and will be
1369 /// broadcasted on a per-need basis by their consumers.
1370 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1371 /// load).
1372 /// TODO: Reuse opportunities for RAR dependencies.
1373 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1374 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1375 /// iteration indices.
1376 /// 5. Iteratively call vectorizeOneOp on the region operations.
1377 ///
1378 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1379 /// performed to the maximal common vector size implied by the `linalgOp`
1380 /// iteration space. This eager broadcasting is introduced in the
1381 /// permutation_map of the vector.transfer_read operations. The eager
1382 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1383 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1384 /// the absence of good canonicalizations, the amount of work increases.
1385 /// This is not deemed a problem as we expect canonicalizations and foldings to
1386 /// aggressively clean up the useless work.
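///
/// An illustrative sketch only (shapes, maps and values below are made up and
/// not taken from this file): an elementwise generic such as
///
///   %0 = linalg.generic
///          {indexing_maps = [#id, #id, #id], iterator_types = ["parallel"]}
///          ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///          outs(%c : tensor<8xf32>) {
///        ^bb0(%in0: f32, %in1: f32, %out: f32):
///          %s = arith.addf %in0, %in1 : f32
///          linalg.yield %s : f32
///        } -> tensor<8xf32>
///
/// would roughly be rewritten into:
///
///   %lhs = vector.transfer_read %a[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %rhs = vector.transfer_read %b[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %sum = arith.addf %lhs, %rhs : vector<8xf32>
///   %res = vector.transfer_write %sum, %c[%c0] : vector<8xf32>, tensor<8xf32>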
1387 static LogicalResult
1388 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1389  LinalgOp linalgOp,
1390  SmallVectorImpl<Value> &newResults) {
1391  LDBG("Vectorizing operation as linalg generic\n");
1392  Block *block = linalgOp.getBlock();
1393 
1394  // 2. Values defined above the region can only be broadcast for now. Make them
1395  // map to themselves.
1396  IRMapping bvm;
1397  SetVector<Value> valuesSet;
1398  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1399  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1400 
1401  if (linalgOp.getNumDpsInits() == 0)
1402  return failure();
1403 
1404  // 3. Turn all BBArgs into vector.transfer_read / load.
1405  Location loc = linalgOp.getLoc();
1406  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1407  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1408  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1409  if (linalgOp.isScalar(opOperand)) {
1410  bvm.map(bbarg, opOperand->get());
1411  continue;
1412  }
1413 
1414  // 3.a. Convert the indexing map for this input/output to a transfer read
1415  // permutation map and masking map.
1416  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1417 
1418  AffineMap readMap;
1419  VectorType readType;
1420  Type elemType = getElementTypeOrSelf(opOperand->get());
1421  if (linalgOp.isDpsInput(opOperand)) {
1422  // 3.a.i. For input reads we use the canonical vector shape.
1423  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1424  readType = state.getCanonicalVecType(elemType);
1425  } else {
1426  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1427  // reductions), the vector shape is computed by mapping the canonical
1428  // vector shape to the output domain and back to the canonical domain.
1429  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1430  readType =
1431  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1432  }
1433 
1434  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1435 
1436  Operation *read = rewriter.create<vector::TransferReadOp>(
1437  loc, readType, opOperand->get(), indices, readMap);
1438  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1439  Value readValue = read->getResult(0);
1440 
1441  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1442  // will be in-bounds.
1443  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1444  SmallVector<bool> inBounds(readType.getRank(), true);
1445  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1446  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1447  }
1448 
1449  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1450  // TODO: remove this.
1451  if (readType.getRank() == 0)
1452  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1453  ArrayRef<int64_t>());
1454 
1455  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1456  << "\n");
1457  bvm.map(bbarg, readValue);
1458  bvm.map(opOperand->get(), readValue);
1459  }
1460 
1461  SmallVector<CustomVectorizationHook> hooks;
1462  // 4a. Register CustomVectorizationHook for yieldOp.
1463  CustomVectorizationHook vectorizeYield =
1464  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1465  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1466  };
1467  hooks.push_back(vectorizeYield);
1468 
1469  // 4b. Register CustomVectorizationHook for indexOp.
1470  CustomVectorizationHook vectorizeIndex =
1471  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1472  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1473  };
1474  hooks.push_back(vectorizeIndex);
1475 
1476  // 4c. Register CustomVectorizationHook for extractOp.
1477  CustomVectorizationHook vectorizeExtract =
1478  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1479  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1480  };
1481  hooks.push_back(vectorizeExtract);
1482 
1483  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1484  for (Operation &op : block->getOperations()) {
1485  VectorizationResult result =
1486  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1487  if (result.status == VectorizationStatus::Failure) {
1488  LDBG("failed to vectorize: " << op << "\n");
1489  return failure();
1490  }
1491  if (result.status == VectorizationStatus::NewOp) {
1492  Operation *maybeMaskedOp =
1493  state.maskOperation(rewriter, result.newOp, linalgOp);
1494  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1495  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1496  }
1497  }
1498 
1499  return success();
1500 }
1501 
1502 /// Given a linalg::PackOp, return the `dest` shape before any packing
1503 /// permutations.
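/// For example (illustrative values only): if getPackInverseDestPerm() returns
/// [1, 0, 2] for a given pack, a destShape of [a, b, c] maps to [b, a, c].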
1504 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1505  ArrayRef<int64_t> destShape) {
1506  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1507 }
1508 
1509 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1510 /// create an empty destination tensor and create a TransferWriteOp from the
1511 /// input to the empty tensor. If the destination shape is not the same as the
1512 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1513 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1514 /// inBounds attribute of the transfer write op instead of masking.
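///
/// A rough sketch of the emitted IR in the masking case (illustrative types
/// and values only, not taken from this file):
///
///   %dest = tensor.empty(%d0) : tensor<?x8xf32>
///   %mask = vector.create_mask %sz0, %sz1 : vector<4x8xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %dest[%c0, %c0]
///       : vector<4x8xf32>, tensor<?x8xf32>
///   } : vector<4x8xi1> -> tensor<?x8xf32>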
1515 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1516  Value input,
1517  SmallVector<OpFoldResult> destSizes,
1518  ArrayRef<int64_t> inputVectorSizes,
1519  bool useInBoundsInsteadOfMasking) {
1520 
1521  auto inputType = cast<VectorType>(input.getType());
1522  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1523  inputType.getElementType());
1524  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1525  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1526  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1527  SmallVector<bool> inBoundsVal(rank, true);
1528  if (useInBoundsInsteadOfMasking) {
1529  // Update the inBounds attribute.
1530  for (unsigned i = 0; i < rank; i++)
1531  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1532  !ShapedType::isDynamic(destShape[i]);
1533  }
1534  Operation *write = builder.create<vector::TransferWriteOp>(
1535  loc,
1536  /*vector=*/input,
1537  /*source=*/dest,
1538  /*indices=*/SmallVector<Value>(rank, zero),
1539  /*inBounds=*/inBoundsVal);
1540  assert(llvm::none_of(
1541  destShape.drop_front(inputVectorSizes.size()),
1542  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1543  "Only dims aligned with inputVectorSizes may be dynamic");
1544  if (useInBoundsInsteadOfMasking)
1545  return write;
1546  bool needMaskForWrite = !llvm::equal(
1547  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1548  if (needMaskForWrite) {
1549  SmallVector<int64_t> writeMaskShape;
1550  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1551  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1552  destShape.end());
1553  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1554  Value maskForWrite =
1555  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1556  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1557  }
1558  return write;
1559 }
1560 
1561 /// Vectorize linalg::PackOp with (1) static innerTiles (2) constant
1562 /// padding value and (3) input vector sizes into:
1563 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1564 /// As in the following example:
1565 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1566 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1567 ///
1568 /// This pack would be vectorized to:
1569 ///
1570 /// %load = vector.mask %mask {
1571 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1572 /// {in_bounds = [true, true, true]} :
1573 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1574 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1575 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1576 /// to vector<32x4x2x1x16xf32>
1577 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1578 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1579 /// %write = vector.transfer_write %transpose,
1580 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1581 /// {in_bounds = [true, true, true, true, true]}
1582 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1583 ///
1584 /// If the (3) input vector sizes are not provided, the vector sizes are
1585 /// determined by the result tensor shape. Also, we update the inBounds
1586 /// attribute instead of masking.
1587 static LogicalResult
1588 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1589  ArrayRef<int64_t> inputVectorSizes,
1590  SmallVectorImpl<Value> &newResults) {
1591  // TODO: Introduce a parent class that will handle the insertion point update.
1592  OpBuilder::InsertionGuard g(rewriter);
1593  rewriter.setInsertionPoint(packOp);
1594 
1595  Location loc = packOp.getLoc();
1596  auto padValue = packOp.getPaddingValue();
1597  if (!padValue) {
1598  padValue = rewriter.create<arith::ConstantOp>(
1599  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1600  }
1601  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1602  LogicalResult status =
1603  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1604  .reifyResultShapes(rewriter, reifiedReturnShapes);
1605  (void)status; // prevent unused variable warning on non-assert builds.
1606  assert(succeeded(status) && "failed to reify result shapes");
1607 
1608  // If the input vector sizes are not provided, then the vector sizes are
1609  // determined by the result tensor shape and, in that case, we update the
1610  // inBounds attribute instead of masking.
1611  bool useInBoundsInsteadOfMasking = false;
1612  if (inputVectorSizes.empty()) {
1613  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1614  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1615  useInBoundsInsteadOfMasking = true;
1616  }
1617 
1618  // Create masked TransferReadOp.
1619  SmallVector<int64_t> inputShape(inputVectorSizes);
1620  auto innerTiles = packOp.getStaticInnerTiles();
1621  auto innerDimsPos = packOp.getInnerDimsPos();
1622  auto outerDimsPerm = packOp.getOuterDimsPerm();
1623  if (!outerDimsPerm.empty())
1624  applyPermutationToVector(inputShape,
1625  invertPermutationVector(outerDimsPerm));
1626  for (auto [idx, size] : enumerate(innerTiles))
1627  inputShape[innerDimsPos[idx]] *= size;
1628  auto maskedRead = vector::createReadOrMaskedRead(
1629  rewriter, loc, packOp.getSource(), inputShape, padValue,
1630  useInBoundsInsteadOfMasking);
1631 
1632  // Create ShapeCastOp.
1633  SmallVector<int64_t> destShape(inputVectorSizes);
1634  destShape.append(innerTiles.begin(), innerTiles.end());
1635  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1636  packOp.getDestType().getElementType());
1637  auto shapeCastOp =
1638  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1639 
1640  // Create TransposeOp.
1641  auto destPermutation =
1642  invertPermutationVector(getPackInverseDestPerm(packOp));
1643  auto transposeOp = rewriter.create<vector::TransposeOp>(
1644  loc, shapeCastOp.getResult(), destPermutation);
1645 
1646  // Create TransferWriteOp.
1647  Operation *write = createWriteOrMaskedWrite(
1648  rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1649  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1650  newResults.push_back(write->getResult(0));
1651  return success();
1652 }
1653 
1654 /// Vectorize a `linalg::UnPackOp` to these 4 Ops:
1655 /// vector::TransferReadOp - Reads a vector from the source tensor.
1656 /// vector::TransposeOp - Transposes the read vector.
1657 /// vector::ShapeCastOp - Reshapes the data based on the target shape.
1658 /// vector::TransferWriteOp - Writes the result vector back to the destination
1659 /// tensor.
1660 /// If the vector sizes are not provided:
1661 /// * the vector sizes are determined by the input operand and attributes,
1662 /// * update the inBounds attribute instead of masking.
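///
/// An illustrative sketch only (shapes and attributes chosen for illustration,
/// not taken from this file):
///
///   %unpack = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [16]
///     into %dst : tensor<2x2x16xf32> -> tensor<32x2xf32>
///
/// would roughly become:
///
///   %read = vector.transfer_read %src[...], %pad
///     : tensor<2x2x16xf32>, vector<2x2x16xf32>
///   %tr = vector.transpose %read, [0, 2, 1]
///     : vector<2x2x16xf32> to vector<2x16x2xf32>
///   %sc = vector.shape_cast %tr : vector<2x16x2xf32> to vector<32x2xf32>
///   %write = vector.transfer_write %sc, %dst[...]
///     : vector<32x2xf32>, tensor<32x2xf32>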
1663 static LogicalResult
1664 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1665  ArrayRef<int64_t> inputVectorSizes,
1666  SmallVectorImpl<Value> &newResults) {
1667 
1668  // TODO: Introduce a parent class that will handle the insertion point update.
1669  OpBuilder::InsertionGuard g(rewriter);
1670  rewriter.setInsertionPoint(unpackOp);
1671 
1672  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1673 
1674  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1675  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1676  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1677  bool useInBoundsInsteadOfMasking = false;
1678  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1679 
1680  auto destSize = unpackOp.getDestRank();
1681 
1682  if (!inputVectorSizes.empty())
1683  assert(inputVectorSizes.size() == destSize &&
1684  "Incorrect number of input vector sizes");
1685 
1686  // vectorSizes is the shape of the vector that will be used to do final
1687  // write on the destination tensor. It is set like this: Let's say the
1688  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1689  // Thus:
1690  // 1. vectorSizes = sourceShape.take_front(N)
1691  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1692  // 3. multiply all the entries in vectorSizes pointed to by innerDimPos by the
1693  // innerTiles attribute value.
1694  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1695  if (vectorSizes.empty()) {
1696  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1697  if (!outerDimsPerm.empty())
1698  applyPermutationToVector(vectorSizes, outerDimsPerm);
1699  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1700  vectorSizes[pos] *= innerTiles[i];
1701 
1702  useInBoundsInsteadOfMasking = true;
1703  }
1704 
1705  // readVectorSizes is the shape of the vector used to read and to apply the
1706  // mask. It is set like this: Let's say the vectorSizes (VS) array has size 'N',
1707  // the sourceShape (SS) has size 'M' where M >= N, and the InnerTileSizes (IT)
1708  // have size M-N.
1709  // Thus:
1710  // - initially: readVectorSizes = vectorSizes
1711  // - Divide all the readVectorSizes entries pointed to by innerDimPos
1712  // by the corresponding innerTiles attribute value.
1713  // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1714  // - Append the remaining shape from SS.
1715  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1716  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1717  // 128] and outer_dims_perm is [1, 0] then read shape is:
1718  // ReadVectorSizes(initial): [512, 128]
1719  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1720  // = [16, 8]
1721  // After applying outer_dims_perm: [8, 16]
1722  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1723 
1724  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1725 
1726  for (auto [index, size] : enumerate(innerTiles)) {
1727  readVectorSizes[innerDimPos[index]] =
1728  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1729  }
1730  if (!outerDimsPerm.empty()) {
1731  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1732  }
1733  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1734  sourceShape.end());
1735 
1736  ReifiedRankedShapedTypeDims reifiedRetShapes;
1737  LogicalResult status =
1738  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1739  .reifyResultShapes(rewriter, reifiedRetShapes);
1740  if (status.failed()) {
1741  LDBG("Unable to reify result shapes of " << unpackOp);
1742  return failure();
1743  }
1744  Location loc = unpackOp->getLoc();
1745 
1746  auto padValue = rewriter.create<arith::ConstantOp>(
1747  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1748 
1749  // Read result, mask if necessary. If transferReadOp shape is not equal
1750  // to shape of source, then a mask is necessary.
1751  Value readResult = vector::createReadOrMaskedRead(
1752  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1753  /*useInBoundsInsteadOfMasking=*/false);
1754 
1755  PackingMetadata packMetadata;
1756  SmallVector<int64_t> lastDimToInsertPosPerm =
1757  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1758  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1759  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1760  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1761  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1762  RankedTensorType stripMineTensorType =
1763  RankedTensorType::get(stripMineShape, stripMineElemType);
1764  // Transpose the appropriate rows to match output.
1765  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1766  loc, readResult, lastDimToInsertPosPerm);
1767 
1768  // Collapse the vector to the size required by result.
1769  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1770  stripMineTensorType, packMetadata.reassociations);
1771  mlir::VectorType vecCollapsedType =
1772  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1773  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1774  loc, vecCollapsedType, transposeOp->getResult(0));
1775 
1776  // writeVectorSizes has to match the shape_cast result shape for dynamic sizes,
1777  // otherwise the verifier complains that the mask size is invalid.
1778  SmallVector<int64_t> writeVectorSizes(
1779  unpackOp.getDestType().hasStaticShape()
1780  ? vectorSizes
1781  : shapeCastOp.getResultVectorType().getShape());
1782  Operation *write = createWriteOrMaskedWrite(
1783  rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1784  writeVectorSizes, useInBoundsInsteadOfMasking);
1785  newResults.push_back(write->getResult(0));
1786  return success();
1787 }
1788 
1789 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1790 /// and (3) all-zero lowPad to
1791 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
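///
/// A rough sketch (illustrative shapes only, not taken from this file):
///
///   %pad = tensor.pad %src low[0, 0] high[2, 3] { ... tensor.yield %cst ... }
///     : tensor<5x6xf32> to tensor<7x9xf32>
///
/// with input vector sizes [7, 9] would roughly become:
///
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<5x6xf32>, vector<7x9xf32>
///   } : vector<7x9xi1> -> vector<7x9xf32>
///   %write = vector.transfer_write %read, %init[%c0, %c0]
///     {in_bounds = [true, true]} : vector<7x9xf32>, tensor<7x9xf32>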
1792 static LogicalResult
1793 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1794  ArrayRef<int64_t> inputVectorSizes,
1795  SmallVectorImpl<Value> &newResults) {
1796  auto padValue = padOp.getConstantPaddingValue();
1797  Location loc = padOp.getLoc();
1798 
1799  // TODO: Introduce a parent class that will handle the insertion point update.
1800  OpBuilder::InsertionGuard g(rewriter);
1801  rewriter.setInsertionPoint(padOp);
1802 
1803  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1804  LogicalResult status =
1805  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1806  .reifyResultShapes(rewriter, reifiedReturnShapes);
1807  (void)status; // prevent unused variable warning on non-assert builds
1808  assert(succeeded(status) && "failed to reify result shapes");
1809  auto maskedRead = vector::createReadOrMaskedRead(
1810  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1811  /*useInBoundsInsteadOfMasking=*/false);
1812  Operation *write = createWriteOrMaskedWrite(
1813  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1814  /*useInBoundsInsteadOfMasking=*/false);
1815  newResults.push_back(write->getResult(0));
1816  return success();
1817 }
1818 
1819 // TODO: probably need some extra checks for reduction followed by consumer
1820 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1821 static LogicalResult reductionPreconditions(LinalgOp op) {
1822  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1823  LDBG("reduction precondition failed: no reduction iterator\n");
1824  return failure();
1825  }
1826  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1827  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1828  if (indexingMap.isPermutation())
1829  continue;
1830 
1831  Operation *reduceOp = matchLinalgReduction(&opOperand);
1832  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1833  LDBG("reduction precondition failed: reduction detection failed\n");
1834  return failure();
1835  }
1836  }
1837  return success();
1838 }
1839 
1840 static LogicalResult
1841 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1842  bool flatten1DDepthwiseConv) {
1843  if (flatten1DDepthwiseConv) {
1844  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1845  "supported\n");
1846  return failure();
1847  }
1848 
1849  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1850  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1851  return failure();
1852  }
1853 
1854  // Support dynamic shapes in 1D depthwise convolution, but only in the
1855  // _channel_ dimension.
1856  Value lhs = conv.getDpsInputOperand(0)->get();
1857  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1858  auto shapeWithoutCh = lhsShape.drop_back(1);
1859  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1860  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1861  "channel dim can be dynamic\n");
1862  return failure();
1863  }
1864 
1865  return success();
1866 }
1867 
1868 static LogicalResult
1869 vectorizeDynamicLinalgOpPrecondition(LinalgOp op,
1870  bool flatten1DDepthwiseConv) {
1871  if (isa<ConvolutionOpInterface>(op.getOperation()))
1872  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1873 
1874  if (hasReductionIterator(op))
1875  return reductionPreconditions(op);
1876 
1877  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1878  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1879  if (!isElementwise(op) &&
1880  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1881  op.getOperation()))
1882  return failure();
1883 
1884  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1885  return success();
1886 }
1887 
1888 /// Need to check if the inner-tiles are static/constant.
1889 static LogicalResult
1890 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
1891  ArrayRef<int64_t> inputVectorSizes) {
1892 
1893  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1894  return !getConstantIntValue(res).has_value();
1895  })) {
1896  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1897  return failure();
1898  }
1899  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1900  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1901  unpackOp.getDestType().hasStaticShape() &&
1902  unpackOp.getSourceType().hasStaticShape();
1903  if (!satisfyEmptyCond &&
1904  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1905  return failure();
1906 
1907  return success();
1908 }
1909 
1910 static LogicalResult
1911 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
1912  ArrayRef<int64_t> inputVectorSizes) {
1913 
1914  TypedValue<RankedTensorType> source = sliceOp.getSource();
1915  auto sourceType = source.getType();
1916  if (!VectorType::isValidElementType(sourceType.getElementType()))
1917  return failure();
1918 
1919  // Get the pad value.
1920  // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
1921  // scalar padding value. Note that:
1922  // * for in-bounds accesses,
1923  // the value is actually irrelevant. There are 2 cases in which xfer.read
1924  // accesses are known to be in-bounds:
1925  // 1. The source shape is static (output vector sizes would be based on
1926  // the source shape and hence all memory accesses would be in-bounds),
1927  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
1928  // this case it is safe to assume that all memory accesses are in-bounds.
1929  //
1930  // When the value is not known and not needed, use 0. Otherwise, bail out.
1931  Value padValue = getStaticPadVal(sliceOp);
1932  bool isOutOfBoundsRead =
1933  !sourceType.hasStaticShape() && inputVectorSizes.empty();
1934 
1935  if (!padValue && isOutOfBoundsRead) {
1936  LDBG("Failed to get a pad value for out-of-bounds read access\n");
1937  return failure();
1938  }
1939  return success();
1940 }
1941 
1942 namespace {
1943 enum class ConvOperationKind { Conv, Pool };
1944 } // namespace
1945 
1946 static bool isCastOfBlockArgument(Operation *op) {
1947  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
1948  isa<BlockArgument>(op->getOperand(0));
1949 }
1950 
1951 // Returns the ConvOperationKind of the op using reduceOp of the generic
1952 // payload. If it is neither a convolution nor a pooling, it returns
1953 // std::nullopt.
1954 //
1955 // If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
1956 // + yield) and the rhs is not used, then it is the body of a pooling op.
1957 // If conv, check for single `mul` predecessor. The `mul` operands must be
1958 // block arguments or extension of block arguments.
1959 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
1960 // must be block arguments or extension of block arguments.
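// For intuition only (illustrative payloads, not an exhaustive list):
//   conv body:  %m = arith.mulf %in, %flt : f32
//               %a = arith.addf %out, %m : f32
//               linalg.yield %a : f32
//   pool body:  %a = arith.addf %out, %in : f32
//               linalg.yield %a : f32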
1961 static std::optional<ConvOperationKind>
1962 getConvOperationKind(Operation *reduceOp) {
1963  int numBlockArguments =
1964  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
1965 
1966  switch (numBlockArguments) {
1967  case 1: {
1968  // Will be convolution if feeder is a MulOp.
1969  // A strength reduced version of MulOp for i1 type is AndOp which is also
1970  // supported. Otherwise, it can be pooling. This strength reduction logic
1971  // is in `buildBinaryFn` helper in the Linalg dialect.
1972  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
1973  llvm::IsaPred<BlockArgument>);
1974  assert(feedValIt != reduceOp->operand_end() &&
1975  "Expected a non-block argument operand");
1976  Operation *feedOp = (*feedValIt).getDefiningOp();
1977  if (isCastOfBlockArgument(feedOp)) {
1978  return ConvOperationKind::Pool;
1979  }
1980 
1981  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
1982  (isa<arith::AndIOp>(feedOp) &&
1983  feedOp->getResultTypes()[0].isInteger(1))) &&
1984  llvm::all_of(feedOp->getOperands(), [](Value v) {
1985  if (isa<BlockArgument>(v))
1986  return true;
1987  if (Operation *op = v.getDefiningOp())
1988  return isCastOfBlockArgument(op);
1989  return false;
1990  }))) {
1991  return std::nullopt;
1992  }
1993 
1994  return ConvOperationKind::Conv;
1995  }
1996  case 2:
1997  // Must be pooling
1998  return ConvOperationKind::Pool;
1999  default:
2000  return std::nullopt;
2001  }
2002 }
2003 
2004 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2005  switch (kind) {
2006  case vector::CombiningKind::ADD:
2007  case vector::CombiningKind::MAXNUMF:
2008  case vector::CombiningKind::MAXIMUMF:
2009  case vector::CombiningKind::MAXSI:
2010  case vector::CombiningKind::MAXUI:
2011  case vector::CombiningKind::MINNUMF:
2012  case vector::CombiningKind::MINIMUMF:
2013  case vector::CombiningKind::MINSI:
2014  case vector::CombiningKind::MINUI:
2015  return true;
2016  default:
2017  return false;
2018  }
2019 }
2020 
2021 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2022  auto getOperandType = [&](auto operand) {
2023  return dyn_cast<ShapedType>((operand->get()).getType());
2024  };
2025  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2026  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2027  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2028  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2029  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2030  // Note that this also ensures 2D and 3D convolutions are rejected.
2031  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2032  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2033  return failure();
2034 
2035  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2036  if (!reduceOp)
2037  return failure();
2038 
2039  auto maybeOper = getConvOperationKind(reduceOp);
2040  if (!maybeOper.has_value())
2041  return failure();
2042 
2043  auto maybeKind = getCombinerOpKind(reduceOp);
2044  // Typically convolution will have a `Add` CombiningKind but for i1 type it
2045  // can get strength reduced to `OR` which is also supported. This strength
2046  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2047  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2048  *maybeKind != vector::CombiningKind::OR) &&
2049  (*maybeOper != ConvOperationKind::Pool ||
2050  !isSupportedPoolKind(*maybeKind)))) {
2051  return failure();
2052  }
2053 
2054  auto rhsRank = rhsShapedType.getRank();
2055  if (*maybeOper == ConvOperationKind::Pool) {
2056  if (rhsRank != 1)
2057  return failure();
2058  } else {
2059  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2060  return failure();
2061  }
2062 
2063  return success();
2064 }
2065 
2066 static LogicalResult vectorizeLinalgOpPrecondition(
2067  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2068  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2069  // Tensors with a dimension of 0 cannot be vectorized.
2070  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
2071  return failure();
2072  // Check API contract for input vector sizes.
2073  if (!inputVectorSizes.empty() &&
2074  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2075  inputVectorSizes)))
2076  return failure();
2077 
2078  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2079  linalgOp, flatten1DDepthwiseConv))) {
2080  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
2081  return failure();
2082  }
2083 
2084  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2085 
2086  // Register CustomVectorizationPrecondition for extractOp.
2087  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2088 
2089  // All types in the body should be a supported element type for VectorType.
2090  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2091  // Check if any custom hook can vectorize the inner op.
2092  if (llvm::any_of(
2093  customPreconditions,
2094  [&](const CustomVectorizationPrecondition &customPrecondition) {
2095  return succeeded(
2096  customPrecondition(&innerOp, vectorizeNDExtract));
2097  })) {
2098  continue;
2099  }
2100  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
2101  return !VectorType::isValidElementType(type);
2102  })) {
2103  return failure();
2104  }
2105  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
2106  return !VectorType::isValidElementType(type);
2107  })) {
2108  return failure();
2109  }
2110  }
2111  if (isElementwise(linalgOp))
2112  return success();
2113 
2114  // TODO: isaConvolutionOpInterface that can also infer from generic
2115  // features. But we will still need stride/dilation attributes that will be
2116  // annoying to reverse-engineer...
2117  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2118  return vectorizeConvOpPrecondition(linalgOp);
2119 
2120  // TODO: the common vector shape is equal to the static loop sizes only when
2121  // all indexing maps are projected permutations. For convs and stencils the
2122  // logic will need to evolve.
2123  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2124  LDBG("precondition failed: not projected permutations\n");
2125  return failure();
2126  }
2127  if (failed(reductionPreconditions(linalgOp))) {
2128  LDBG("precondition failed: reduction preconditions\n");
2129  return failure();
2130  }
2131  return success();
2132 }
2133 
2134 static LogicalResult
2135 vectorizePackOpPrecondition(linalg::PackOp packOp,
2136  ArrayRef<int64_t> inputVectorSizes) {
2137  auto padValue = packOp.getPaddingValue();
2138  Attribute cstAttr;
2139  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2140  LDBG("pad value is not constant: " << packOp << "\n");
2141  return failure();
2142  }
2143  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2144  bool satisfyEmptyCond = true;
2145  if (inputVectorSizes.empty()) {
2146  if (!packOp.getDestType().hasStaticShape() ||
2147  !packOp.getSourceType().hasStaticShape())
2148  satisfyEmptyCond = false;
2149  }
2150 
2151  if (!satisfyEmptyCond &&
2152  failed(vector::isValidMaskedInputVector(
2153  resultTensorShape.take_front(packOp.getSourceRank()),
2154  inputVectorSizes)))
2155  return failure();
2156 
2157  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2158  return !getConstantIntValue(v).has_value();
2159  })) {
2160  LDBG("inner_tiles must be constant: " << packOp << "\n");
2161  return failure();
2162  }
2163 
2164  return success();
2165 }
2166 
2167 static LogicalResult
2168 vectorizePadOpPrecondition(tensor::PadOp padOp,
2169  ArrayRef<int64_t> inputVectorSizes) {
2170  auto padValue = padOp.getConstantPaddingValue();
2171  if (!padValue) {
2172  LDBG("pad value is not constant: " << padOp << "\n");
2173  return failure();
2174  }
2175 
2176  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2177  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2178  inputVectorSizes)))
2179  return failure();
2180 
2181  // Padding with non-zero low pad values is not supported, unless the
2182  // corresponding result dim is 1 as this would require shifting the results to
2183  // the right for the low padded dims by the required amount of low padding.
2184  // However, we do support low padding if the dims being low padded have result
2185  // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2186  // input size of that dimension will be dynamically zero (as the sum of the
2187  // low pad and input dim size has to be one) and hence we will create a zero
2188  // mask as the lowering logic just makes the mask one for the input dim size -
2189  // which is zero here. Hence we will load the pad value which is what we want
2190  // in this case. If the low pad is dynamically zero then the lowering is
2191  // correct as well as no shifts are necessary.
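// For example (illustrative): low = [1, 0] with a static result shape of
// [1, 8] is accepted, whereas low = [1, 0] with a result shape of [5, 8] is
// rejected.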
2192  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2193  Value padValue = en.value();
2194  unsigned pos = en.index();
2195  std::optional<int64_t> pad = getConstantIntValue(padValue);
2196  return (!pad.has_value() || pad.value() != 0) &&
2197  resultTensorShape[pos] != 1;
2198  })) {
2199  LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
2200  return failure();
2201  }
2202 
2203  return success();
2204 }
2205 
2206 /// Preconditions for scalable vectors. This is quite restrictive - it models
2207 /// the fact that in practice we would only make selected dimensions scalable.
2208 static LogicalResult
2209 vectorizeScalableVectorPrecondition(Operation *op,
2210  ArrayRef<int64_t> inputVectorSizes,
2211  ArrayRef<bool> inputScalableVecDims) {
2212  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2213  "Number of input vector sizes and scalable dims doesn't match");
2214 
2215  size_t numOfScalableDims =
2216  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2217 
2218  if (numOfScalableDims == 0)
2219  return success();
2220 
2221  auto linalgOp = dyn_cast<LinalgOp>(op);
2222 
2223  // Cond 1: There's been no need for scalable vectorization of
2224  // non-linalg Ops so far.
2225  if (!linalgOp)
2226  return failure();
2227 
2228  // Cond 2: There's been no need for more than 2 scalable dims so far
2229  if (numOfScalableDims > 2)
2230  return failure();
2231 
2232  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2233  // it matches one of the supported cases:
2234  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2235  // (*).
2236  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2237  // parallel dims.
2238  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2239  // dim.
2240  // The 2nd restriction above means that only Matmul-like Ops are supported
2241  // when 2 dims are scalable, e.g. :
2242  // * iterators = [parallel, parallel, reduction]
2243  // * scalable flags = [true, true, false]
2244  //
2245  // (*) Non-unit dims get folded away in practice.
2246  // TODO: Relax these conditions as good motivating examples are identified.
2247 
2248  // Find the first scalable flag.
2249  bool seenNonUnitParallel = false;
2250  auto iterators = linalgOp.getIteratorTypesArray();
2251  SmallVector<bool> scalableFlags(inputScalableVecDims);
2252  int64_t idx = scalableFlags.size() - 1;
2253  while (!scalableFlags[idx]) {
2254  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2255  seenNonUnitParallel |=
2256  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2257 
2258  iterators.pop_back();
2259  scalableFlags.pop_back();
2260  --idx;
2261  }
2262 
2263  // Analyze the iterator corresponding to the first scalable dim.
2264  switch (iterators.back()) {
2265  case utils::IteratorType::reduction: {
2266  // Check 3. above is met.
2267  if (iterators.size() != inputVectorSizes.size()) {
2268  LDBG("Non-trailing reduction dim requested for scalable "
2269  "vectorization\n");
2270  return failure();
2271  }
2272  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2273  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2274  "is not supported\n");
2275  return failure();
2276  }
2277  break;
2278  }
2279  case utils::IteratorType::parallel: {
2280  // Check 1. and 2. above are met.
2281  if (seenNonUnitParallel) {
2282  LDBG("Inner parallel dim not requested for scalable "
2283  "vectorization\n");
2284  return failure();
2285  }
2286  break;
2287  }
2288  }
2289 
2290  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2291  // supported, for which we expect the following config:
2292  // * iterators = [parallel, parallel, reduction]
2293  // * scalable flags = [true, true, false]
2294  if (numOfScalableDims == 2) {
2295  // Disallow below case which breaks 3. above:
2296  // * iterators = [..., parallel, reduction]
2297  // * scalable flags = [..., true, true]
2298  if (iterators.back() == utils::IteratorType::reduction) {
2299  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2300  "vectorization\n");
2301  return failure();
2302  }
2303  scalableFlags.pop_back();
2304  iterators.pop_back();
2305 
2306  if (!scalableFlags.back() ||
2307  (iterators.back() != utils::IteratorType::parallel))
2308  return failure();
2309  }
2310 
2311  // Don't let matmul ops with extended semantics (user-defined indexing maps)
2312  // through this transform.
2313  if (linalgOp.hasUserDefinedMaps())
2314  return failure();
2315 
2316  // Cond 4: Only the following ops are supported in the
2317  // presence of scalable vectors
2318  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2319  isa<linalg::MatmulTransposeAOp>(op) ||
2320  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2321  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2322 }
2323 
2324 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2325  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2326  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2327  bool flatten1DDepthwiseConv) {
2328 
2329  if (!hasVectorizationImpl(op))
2330  return failure();
2331 
2332  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2333  inputScalableVecDims)))
2334  return failure();
2335 
2336  return TypeSwitch<Operation *, LogicalResult>(op)
2337  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2338  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2339  vectorizeNDExtract,
2340  flatten1DDepthwiseConv);
2341  })
2342  .Case<tensor::PadOp>([&](auto padOp) {
2343  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2344  })
2345  .Case<linalg::PackOp>([&](auto packOp) {
2346  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2347  })
2348  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2349  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2350  })
2351  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2352  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2353  })
2354  .Default([](auto) { return failure(); });
2355 }
2356 
2357 /// Converts affine.apply Ops to arithmetic operations.
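/// For example (illustrative): `affine.apply affine_map<(d0) -> (d0 + 4)>(%i)`
/// is expanded to an `arith.constant 4 : index` plus an `arith.addi`.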
2358 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2359  OpBuilder::InsertionGuard g(rewriter);
2360  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2361 
2362  for (auto op : make_early_inc_range(toReplace)) {
2363  rewriter.setInsertionPoint(op);
2364  auto expanded = affine::expandAffineExpr(
2365  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2366  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2367  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2368  rewriter.replaceOp(op, expanded);
2369  }
2370 }
2371 
2372 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2373  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2374  tensor::InsertSliceOp>(op);
2375 }
2376 
2377 /// Emit a suitable vector form for an operation. If provided,
2378 /// `inputVectorSizes` are used to vectorize this operation.
2379 /// `inputVectorSizes` must match the rank of the iteration space of the
2380 /// operation and the input vector sizes must be greater than or equal to
2381 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2382 /// also allows the vectorization of operations with dynamic shapes.
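///
/// A minimal usage sketch (hypothetical caller code; `rewriter` and `op` are
/// assumed to be in scope and the default arguments declared in the header are
/// assumed):
///
///   if (failed(linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
///                                /*inputScalableVecDims=*/{})))
///     return failure();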
2383 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2384  ArrayRef<int64_t> inputVectorSizes,
2385  ArrayRef<bool> inputScalableVecDims,
2386  bool vectorizeNDExtract,
2387  bool flatten1DDepthwiseConv) {
2388  LDBG("Attempting to vectorize:\n" << *op << "\n");
2389  LDBG("Input vector sizes: ");
2390  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2391  LLVM_DEBUG(llvm::dbgs() << "\n");
2392  LDBG("Input scalable vector dims: ");
2393  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2394  LLVM_DEBUG(llvm::dbgs() << "\n");
2395 
2396  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2397  vectorizeNDExtract,
2398  flatten1DDepthwiseConv))) {
2399  LDBG("Vectorization pre-conditions failed\n");
2400  return failure();
2401  }
2402 
2403  // Initialize vectorization state.
2404  VectorizationState state(rewriter);
2405  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2406  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2407  inputScalableVecDims))) {
2408  LDBG("Vectorization state couldn't be initialized\n");
2409  return failure();
2410  }
2411  }
2412 
2413  SmallVector<Value> results;
2414  auto vectorizeResult =
2415  TypeSwitch<Operation *, LogicalResult>(op)
2416  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2417  // TODO: isaConvolutionOpInterface that can also infer from
2418  // generic features. Will require stride/dilation attributes
2419  // inference.
2420  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2421  FailureOr<Operation *> convOr = vectorizeConvolution(
2422  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2423  flatten1DDepthwiseConv);
2424  if (succeeded(convOr)) {
2425  llvm::append_range(results, (*convOr)->getResults());
2426  return success();
2427  }
2428 
2429  LDBG("Unsupported convolution can't be vectorized.\n");
2430  return failure();
2431  }
2432 
2433  LDBG("Vectorize generic by broadcasting to the canonical vector "
2434  "shape\n");
2435 
2436  // Pre-process before proceeding.
2437  convertAffineApply(rewriter, linalgOp);
2438 
2439  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2440  // to 'OpBuilder' when it is passed over to some methods like
2441  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2442  // erase an op within these methods, the actual rewriter won't be
2443  // notified and we will end up with read-after-free issues!
2444  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2445  })
2446  .Case<tensor::PadOp>([&](auto padOp) {
2447  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2448  results);
2449  })
2450  .Case<linalg::PackOp>([&](auto packOp) {
2451  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2452  results);
2453  })
2454  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2455  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2456  inputVectorSizes, results);
2457  })
2458  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2459  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2460  results);
2461  })
2462  .Default([](auto) { return failure(); });
2463 
2464  if (failed(vectorizeResult)) {
2465  LDBG("Vectorization failed\n");
2466  return failure();
2467  }
2468 
2469  if (!results.empty())
2470  rewriter.replaceOp(op, results);
2471  else
2472  rewriter.eraseOp(op);
2473 
2474  return success();
2475 }
2476 
2477 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2478  memref::CopyOp copyOp) {
2479  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2480  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2481  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2482  return failure();
2483 
2484  auto srcElementType = getElementTypeOrSelf(srcType);
2485  auto dstElementType = getElementTypeOrSelf(dstType);
2486  if (!VectorType::isValidElementType(srcElementType) ||
2487  !VectorType::isValidElementType(dstElementType))
2488  return failure();
2489 
2490  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2491  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2492 
2493  Location loc = copyOp->getLoc();
2494  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2495  SmallVector<Value> indices(srcType.getRank(), zero);
2496 
2497  Value readValue = rewriter.create<vector::TransferReadOp>(
2498  loc, readType, copyOp.getSource(), indices,
2499  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2500  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2501  readValue =
2502  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2503  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2504  }
2505  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2506  loc, readValue, copyOp.getTarget(), indices,
2507  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2508  rewriter.replaceOp(copyOp, writeValue->getResults());
2509  return success();
2510 }
2511 
2512 //----------------------------------------------------------------------------//
2513 // Misc. vectorization patterns.
2514 //----------------------------------------------------------------------------//
2515 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2516 /// given operation type OpTy.
2517 template <typename OpTy>
2518 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2519  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2520 
2521  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2522  PatternRewriter &rewriter) const final {
2523  bool changed = false;
2524  // Gather users into a vector first, because some users may be replaced/removed.
2525  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2526  if (auto op = dyn_cast<OpTy>(user))
2527  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2528  return success(changed);
2529  }
2530 
2531 protected:
2532  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2533  tensor::PadOp padOp, OpTy op) const = 0;
2534 };
2535 
2536 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2537 /// ```
2538 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2539 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2540 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2541 /// ```
2542 /// is rewritten to:
2543 /// ```
2544 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2545 /// {in_bounds = [true, true]}
2546 /// : tensor<?x?xf32>, vector<17x5xf32>
2547 /// ```
2548 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2549 /// sure that the original padding value %cst was never used.
2550 ///
2551 /// This rewrite is possible if:
2552 /// - `xferOp` has no out-of-bounds dims or mask.
2553 /// - Low padding is static 0.
2554 /// - Single, scalar padding value.
2555 struct PadOpVectorizationWithTransferReadPattern
2556  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2557  using VectorizePadOpUserPattern<
2558  vector::TransferReadOp>::VectorizePadOpUserPattern;
2559 
2560  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2561  vector::TransferReadOp xferOp) const override {
2562  // Low padding must be static 0.
2563  if (!padOp.hasZeroLowPad())
2564  return failure();
2565  // Pad value must be a constant.
2566  auto padValue = padOp.getConstantPaddingValue();
2567  if (!padValue)
2568  return failure();
2569  // Padding value of existing `xferOp` is unused.
2570  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2571  return failure();
2572 
2573  rewriter.modifyOpInPlace(xferOp, [&]() {
2574  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2575  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2576  rewriter.getBoolArrayAttr(inBounds));
2577  xferOp.getSourceMutable().assign(padOp.getSource());
2578  xferOp.getPaddingMutable().assign(padValue);
2579  });
2580 
2581  return success();
2582  }
2583 };
2584 
2585 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2586 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2587 /// value, where the same amount of padding is immediately removed again after
2588 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2589 /// tensor value and apply out-of-bounds masking. E.g.:
2590 /// ```
2591 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2592 /// : tensor<...> to tensor<?x?xf32>
2593 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2594 /// %2 = vector.transfer_write %vec, %1[...]
2595 /// : vector<17x5xf32>, tensor<17x5xf32>
2596 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2597 /// : tensor<17x5xf32> to tensor<?x?xf32>
2598 /// ```
2599 /// is rewritten to:
2600 /// ```
2601 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2602 /// : tensor<...> to tensor<?x?xf32>
2603 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2604 /// tensor<?x?xf32>
2605 /// ```
2606 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2607 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2608 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2609 /// from %r's old dimensions.
2610 ///
2611 /// This rewrite is possible if:
2612 /// - Low padding is static 0.
2613 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2614 /// ExtractSliceOp trims the same amount of padding that was added
2615 /// beforehand.
2616 /// - Single, scalar padding value.
2617 struct PadOpVectorizationWithTransferWritePattern
2618  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2619  using VectorizePadOpUserPattern<
2620  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2621 
2622  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2623  vector::TransferWriteOp xferOp) const override {
2624  // TODO: support 0-d corner case.
2625  if (xferOp.getTransferRank() == 0)
2626  return failure();
2627 
2628  // Low padding must be static 0.
2629  if (!padOp.hasZeroLowPad())
2630  return failure();
2631  // Pad value must be a constant.
2632  auto padValue = padOp.getConstantPaddingValue();
2633  if (!padValue)
2634  return failure();
2635  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2636  if (!xferOp->hasOneUse())
2637  return failure();
2638  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2639  if (!trimPadding)
2640  return failure();
2641  // Only static zero offsets supported when trimming padding.
2642  if (!trimPadding.hasZeroOffset())
2643  return failure();
2644  // trimPadding must remove the amount of padding that was added earlier.
2645  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2646  return failure();
2647 
2648  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2649  rewriter.setInsertionPoint(xferOp);
2650 
2651  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2652  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2653  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2654  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2655  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2656  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2657 
2658  return success();
2659  }
2660 
2661  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2662  /// i.e., same dimensions.
2663  ///
2664  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2665  /// dimensions, this function tries to infer the (static) tensor size by
2666  /// looking at the defining op and utilizing op-specific knowledge.
2667  ///
2668  /// This is a conservative analysis. In case equal tensor sizes cannot be
2669  /// proven statically, this analysis returns `false` even though the tensor
2670  /// sizes may turn out to be equal at runtime.
2671  bool hasSameTensorSize(Value beforePadding,
2672  tensor::ExtractSliceOp afterTrimming) const {
2673  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2674  // result and CastOp operand.
2675  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2676  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2677  return true;
2678 
2679  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2680  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2681  // Only RankedTensorType supported.
2682  if (!t1 || !t2)
2683  return false;
2684  // Rank of both values must be the same.
2685  if (t1.getRank() != t2.getRank())
2686  return false;
2687 
2688  // All static dimensions must be the same. Mixed cases (e.g., dimension
2689  // static in `t1` but dynamic in `t2`) are not supported.
2690  for (unsigned i = 0; i < t1.getRank(); ++i) {
2691  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2692  return false;
2693  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2694  return false;
2695  }
2696 
2697  // Nothing more to check if all dimensions are static.
2698  if (t1.getNumDynamicDims() == 0)
2699  return true;
2700 
2701  // All dynamic sizes must be the same. The only supported case at the
2702  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2703  // thereof).
2704 
2705  // Apart from CastOp, only ExtractSliceOp is supported.
2706  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2707  if (!beforeSlice)
2708  return false;
2709 
2710  assert(static_cast<size_t>(t1.getRank()) ==
2711  beforeSlice.getMixedSizes().size());
2712  assert(static_cast<size_t>(t2.getRank()) ==
2713  afterTrimming.getMixedSizes().size());
2714 
2715  for (unsigned i = 0; i < t1.getRank(); ++i) {
2716  // Skip static dimensions.
2717  if (!t1.isDynamicDim(i))
2718  continue;
2719  auto size1 = beforeSlice.getMixedSizes()[i];
2720  auto size2 = afterTrimming.getMixedSizes()[i];
2721 
2722  // Case 1: Same value or same constant int.
2723  if (isEqualConstantIntOrValue(size1, size2))
2724  continue;
2725 
2726  // Other cases: Take a deeper look at defining ops of values.
2727  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2728  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2729  if (!v1 || !v2)
2730  return false;
2731 
2732  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2733  // CSE is run.)
2734  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2735  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2736  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2737  minOp1.getOperands() == minOp2.getOperands())
2738  continue;
2739 
2740  // Add additional cases as needed.
2741  }
2742 
2743  // All tests passed.
2744  return true;
2745  }
2746 };
2747 
2748 /// Returns the effective Pad value for the input op, provided it's a scalar.
2749 ///
2750 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2751 /// this Op performs padding, retrieve the padding value provided that it's
2752 /// a scalar and static/fixed for all the padded values. Returns an empty value
2753 /// otherwise.
2754 ///
2755 /// TODO: This is used twice (when checking vectorization pre-conditions and
2756 /// when vectorizing). Cache results instead of re-running.
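///
/// As an illustration (hypothetical IR, value names invented):
/// ```
///   %pad  = arith.constant 0.0 : f32
///   %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<8x8xf32>)
///           -> tensor<8x8xf32>
///   %res  = tensor.insert_slice %src into %fill[0, 0] [4, 4] [1, 1]
///           : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
/// Called on the tensor.insert_slice, this helper walks to the destination's
/// defining linalg.fill (case 5, then case 2 below) and returns %pad.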
2757 static Value getStaticPadVal(Operation *op) {
2758  if (!op)
2759  return {};
2760 
2761  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
2762  // being broadcast, provided that it's a scalar.
2763  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2764  auto source = bcast.getSource();
2765  if (llvm::dyn_cast<VectorType>(source.getType()))
2766  return {};
2767 
2768  return source;
2769  }
2770 
2771  // 2. linalg.fill - use the scalar input value that was used to fill the
2772  // output tensor.
2773  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2774  return fill.getInputs()[0];
2775  }
2776 
2777  // 3. tensor.generate - we can't guarantee that the padding value is fixed
2778  // without further analysis, so bail out.
2779  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2780  return {};
2781  }
2782 
2783  // 4. vector.transfer_write - inspect the vector that is being written. If
2784  // it contains a single value that has been broadcast (e.g. via
2785  // vector.broadcast), extract it; fail otherwise.
2786  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2787  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2788 
2789  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2790  // than the input tensor and its contents are constant, extract the value
2791  // that was used to generate it (e.g. via linalg.fill); fail otherwise.
2792  // TODO: Clarify the semantics when the input tensor is larger than the
2793  // destination.
2794  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2795  return getStaticPadVal(slice.getDest().getDefiningOp());
2796 
2797  return {};
2798 }
2799 
2800 static LogicalResult
2801 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2802  ArrayRef<int64_t> inputVectorSizes,
2803  SmallVectorImpl<Value> &newResults) {
2804  // TODO: Introduce a parent class that will handle the insertion point update.
2805  OpBuilder::InsertionGuard g(rewriter);
2806  rewriter.setInsertionPoint(sliceOp);
2807 
2808  TypedValue<RankedTensorType> source = sliceOp.getSource();
2809  auto sourceType = source.getType();
2810  auto resultType = sliceOp.getResultType();
2811 
2812  Value padValue = getStaticPadVal(sliceOp);
2813 
2814  if (!padValue) {
2815  auto elemType = sourceType.getElementType();
2816  padValue = rewriter.create<arith::ConstantOp>(
2817  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2818  }
2819 
2820  // 2. Get the vector shape and in-bounds attributes
2821  SmallVector<int64_t> vecShape;
2822  SmallVector<bool> readInBounds;
2823  SmallVector<bool> writeInBounds;
2824  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2825  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2826  if (!inputVectorSizes.empty()) {
2827  vecShape.push_back(inputVectorSizes[i]);
2828  readInBounds.push_back(false);
2829  writeInBounds.push_back(false);
2830  } else if (!sourceType.isDynamicDim(i)) {
2831  vecShape.push_back(sourceType.getDimSize(i));
2832  // Source shape is statically known: neither the read nor the write is
2833  // out-of-bounds.
2834  readInBounds.push_back(true);
2835  writeInBounds.push_back(true);
2836  } else if (!resultType.isDynamicDim(i)) {
2837  // Source shape is not statically known, but result shape is.
2838  // Vectorize with size of result shape. This may be larger than the
2839  // source size.
2840  // FIXME: Using rankDiff implies that the source tensor is inserted at
2841  // the end of the destination tensor. However, that's not required.
2842  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2843  // Read may be out-of-bounds because the result size could be larger
2844  // than the source size.
2845  readInBounds.push_back(false);
2846  // Write will be in-bounds provided that the corresponding write idx is 0.
2847  // To keep this logic simple, conservatively mark as out-of-bounds.
2848  writeInBounds.push_back(false);
2849  } else {
2850  // Neither the source nor the result dim is static. Cannot vectorize
2851  // the copy.
2852  // TODO: Add support for masking
2853  return failure();
2854  }
2855  }
2856  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2857 
2858  // 3. Generate TransferReadOp + TransferWriteOp
2859  ReifiedRankedShapedTypeDims reifiedSrcSizes;
2860  Value maskOp;
2861 
2862  // If vector sizes are user-provided, make sure to mask. First, generate the
2863  // mask.
2864  if (!inputVectorSizes.empty()) {
2865  auto *srcDefOp = source.getDefiningOp();
2866  if (!srcDefOp) {
2867  LDBG("Unable to get the defining Op of " << sliceOp);
2868  return failure();
2869  }
2870 
2871  LogicalResult status =
2872  cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
2873  rewriter, reifiedSrcSizes);
2874  if (status.failed()) {
2875  LDBG("Unable to reify result shapes of " << srcDefOp);
2876  return failure();
2877  }
2878 
2879  // Create the mask
2880  auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
2881  maskOp = rewriter.create<vector::CreateMaskOp>(
2882  sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
2883  }
2884 
2885  SmallVector<Value> readIndices(
2886  vecType.getRank(),
2887  rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2888  Operation *read = rewriter.create<vector::TransferReadOp>(
2889  sliceOp.getLoc(), vecType, source, readIndices, padValue,
2890  ArrayRef<bool>{readInBounds});
2891 
2892  if (maskOp) {
2893  read = mlir::vector::maskOperation(rewriter, read, maskOp);
2894  }
2895 
2896  auto writeIndices = getValueOrCreateConstantIndexOp(
2897  rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2898 
2899  Operation *write = rewriter.create<vector::TransferWriteOp>(
2900  sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
2901  ArrayRef<bool>{writeInBounds});
2902 
2903  if (maskOp) {
2904  write = mlir::vector::maskOperation(rewriter, write, maskOp);
2905  }
2906 
2907  // 4. Finalize
2908  newResults.push_back(write->getResult(0));
2909 
2910  return success();
2911 }
2912 
2913 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2914 /// ```
2915 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2916 /// %r = tensor.insert_slice %0
2917 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2918 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2919 /// ```
2920 /// is rewritten to:
2921 /// ```
2922 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2923 /// : tensor<?x?xf32>, vector<17x5xf32>
2924 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2925 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2926 /// ```
2927 ///
2928 /// This rewrite is possible if:
2929 /// - Low padding is static 0.
2930 /// - `padOp` result shape is static.
2931 /// - The entire padded tensor is inserted.
2932 /// (Implies that sizes of `insertOp` are all static.)
2933 /// - Only unit strides in `insertOp`.
2934 /// - Single, scalar padding value.
2935 /// - `padOp` result not used as destination.
2936 struct PadOpVectorizationWithInsertSlicePattern
2937     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2938   using VectorizePadOpUserPattern<
2939       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2940 
2941  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2942  tensor::InsertSliceOp insertOp) const override {
2943  // Low padding must be static 0.
2944  if (!padOp.hasZeroLowPad())
2945  return failure();
2946  // Only unit stride supported.
2947  if (!insertOp.hasUnitStride())
2948  return failure();
2949  // Pad value must be a constant.
2950  auto padValue = padOp.getConstantPaddingValue();
2951  if (!padValue)
2952  return failure();
2953  // Dynamic shapes not supported.
2954  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2955  return failure();
2956  // Pad result not used as destination.
2957  if (insertOp.getDest() == padOp.getResult())
2958  return failure();
2959 
2960  auto vecType = VectorType::get(padOp.getType().getShape(),
2961  padOp.getType().getElementType());
2962  unsigned vecRank = vecType.getRank();
2963  unsigned tensorRank = insertOp.getType().getRank();
2964 
2965  // Check if sizes match: Insert the entire tensor into most minor dims.
2966  // (No permutations allowed.)
2967  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2968  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2969  if (!llvm::all_of(
2970  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2971  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2972  }))
2973  return failure();
2974 
2975  // Insert the TransferReadOp and TransferWriteOp at the position of the
2976  // InsertSliceOp.
2977  rewriter.setInsertionPoint(insertOp);
2978 
2979  // Generate TransferReadOp: Read entire source tensor and add high
2980  // padding.
2981  SmallVector<Value> readIndices(
2982  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2983  auto read = rewriter.create<vector::TransferReadOp>(
2984  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2985 
2986  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2987  // specified offsets. The write is fully in-bounds because an InsertSliceOp's
2988  // source must fit into the destination at the specified offsets.
2989  auto writeIndices = getValueOrCreateConstantIndexOp(
2990  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2991  SmallVector<bool> inBounds(vecRank, true);
2992  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2993  insertOp, read, insertOp.getDest(), writeIndices,
2994  ArrayRef<bool>{inBounds});
2995 
2996  return success();
2997  }
2998 };
2999 
3000 void mlir::linalg::populatePadOpVectorizationPatterns(
3001     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3002   patterns.add<PadOpVectorizationWithTransferReadPattern,
3003                PadOpVectorizationWithTransferWritePattern,
3004                PadOpVectorizationWithInsertSlicePattern>(
3005       patterns.getContext(), baseBenefit.getBenefit() + 1);
3006 }
3007 
3008 //----------------------------------------------------------------------------//
3009 // Forwarding patterns
3010 //----------------------------------------------------------------------------//
3011 
3012 /// Check whether there is any interleaved use of any `values` between
3013 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3014 /// is in a different block.
3015 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3016  ValueRange values) {
3017  if (firstOp->getBlock() != secondOp->getBlock() ||
3018  !firstOp->isBeforeInBlock(secondOp)) {
3019  LDBG("interleavedUses precondition failed, firstOp: "
3020  << *firstOp << ", second op: " << *secondOp << "\n");
3021  return true;
3022  }
3023  for (auto v : values) {
3024  for (auto &u : v.getUses()) {
3025  Operation *owner = u.getOwner();
3026  if (owner == firstOp || owner == secondOp)
3027  continue;
3028  // TODO: this is too conservative, use dominance info in the future.
3029  if (owner->getBlock() == firstOp->getBlock() &&
3030  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3031  continue;
3032  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3033  << ", second op: " << *secondOp << "\n");
3034  return true;
3035  }
3036  }
3037  return false;
3038 }
3039 
3040 /// Return the unique subview use of `v` if it is indeed unique, null
3041 /// otherwise.
3042 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3043  memref::SubViewOp subViewOp;
3044  for (auto &u : v.getUses()) {
3045  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3046  if (subViewOp)
3047  return memref::SubViewOp();
3048  subViewOp = newSubViewOp;
3049  }
3050  }
3051  return subViewOp;
3052 }
3053 
3054 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3055 /// when available.
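///
/// Sketch of the rewrite performed below (illustrative only; value names are
/// invented and details such as indices and types are omitted):
/// ```
///   linalg.fill ins(%pad ...) outs(%alloc ...)     // pad the local buffer
///   memref.copy %in, %subview_of_alloc             // fill the interior
///   %v = vector.transfer_read %alloc[...], %pad ...
/// ```
/// is forwarded to a single
/// ```
///   %v = vector.transfer_read %in[...], %pad {in_bounds = [false, ...]}
/// ```
/// with the fill (if found) and the copy erased.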
3056 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3057     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3058 
3059  // TODO: support mask.
3060  if (xferOp.getMask())
3061  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3062 
3063  // Transfer into `view`.
3064  Value viewOrAlloc = xferOp.getSource();
3065  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3066  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3067  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3068 
3069  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3070  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3071  if (!subViewOp)
3072  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3073  Value subView = subViewOp.getResult();
3074 
3075  // Find the copy into `subView` without interleaved uses.
3076  memref::CopyOp copyOp;
3077  for (auto &u : subView.getUses()) {
3078  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3079  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3080  if (newCopyOp.getTarget() != subView)
3081  continue;
3082  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3083  continue;
3084  copyOp = newCopyOp;
3085  break;
3086  }
3087  }
3088  if (!copyOp)
3089  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3090 
3091  // Find the fill into `viewOrAlloc` without interleaved uses before the
3092  // copy.
3093  FillOp maybeFillOp;
3094  for (auto &u : viewOrAlloc.getUses()) {
3095  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3096  assert(isa<MemRefType>(newFillOp.output().getType()));
3097  if (newFillOp.output() != viewOrAlloc)
3098  continue;
3099  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3100  continue;
3101  maybeFillOp = newFillOp;
3102  break;
3103  }
3104  }
3105  // Ensure padding matches.
3106  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3107  return rewriter.notifyMatchFailure(xferOp,
3108  "padding value does not match fill");
3109 
3110  // `in` is the source buffer that memref.copy reads from. Forward the read to it.
3111  Value in = copyOp.getSource();
3112 
3113  // memref.copy + linalg.fill can be used to create a padded local buffer.
3114  // The `in_bounds` attribute is only valid on this padded buffer.
3115  // When forwarding to vector.transfer_read, the attribute must be reset
3116  // conservatively.
3117  auto vectorType = xferOp.getVectorType();
3118  Value res = rewriter.create<vector::TransferReadOp>(
3119  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3120  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3121  rewriter.getBoolArrayAttr(
3122  SmallVector<bool>(vectorType.getRank(), false)));
3123 
3124  if (maybeFillOp)
3125  rewriter.eraseOp(maybeFillOp);
3126  rewriter.eraseOp(copyOp);
3127  rewriter.replaceOp(xferOp, res);
3128 
3129  return success();
3130 }
3131 
3132 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3133 /// when available.
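///
/// Sketch of the rewrite performed below (illustrative only; value names are
/// invented and details such as indices and types are omitted):
/// ```
///   vector.transfer_write %v, %alloc[...]
///   memref.copy %subview_of_alloc, %out            // copy back out
/// ```
/// is forwarded to a single
/// ```
///   vector.transfer_write %v, %out[...] {in_bounds = [false, ...]}
/// ```
/// with the copy and the original transfer_write erased.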
3134 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3135     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3136  // TODO: support mask.
3137  if (xferOp.getMask())
3138  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3139 
3140  // Transfer into `viewOrAlloc`.
3141  Value viewOrAlloc = xferOp.getSource();
3142  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3143  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3144  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3145 
3146  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3147  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3148  if (!subViewOp)
3149  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3150  Value subView = subViewOp.getResult();
3151 
3152  // Find the copy from `subView` without interleaved uses.
3153  memref::CopyOp copyOp;
3154  for (auto &u : subViewOp.getResult().getUses()) {
3155  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3156  if (newCopyOp.getSource() != subView)
3157  continue;
3158  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3159  continue;
3160  copyOp = newCopyOp;
3161  break;
3162  }
3163  }
3164  if (!copyOp)
3165  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3166 
3167  // `out` is the buffer that the subview is copied into; the write is forwarded to it.
3168  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3169  Value out = copyOp.getTarget();
3170 
3171  // Forward vector.transfer into copy.
3172  // memref.copy + linalg.fill can be used to create a padded local buffer.
3173  // The `in_bounds` attribute is only valid on this padded buffer.
3174  // When forwarding to vector.transfer_write, the attribute must be reset
3175  // conservatively.
3176  auto vector = xferOp.getVector();
3177  rewriter.create<vector::TransferWriteOp>(
3178  xferOp.getLoc(), vector, out, xferOp.getIndices(),
3179  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3180       rewriter.getBoolArrayAttr(SmallVector<bool>(
3181           dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3182 
3183  rewriter.eraseOp(copyOp);
3184  rewriter.eraseOp(xferOp);
3185 
3186  return success();
3187 }
3188 
3189 //===----------------------------------------------------------------------===//
3190 // Convolution vectorization patterns
3191 //===----------------------------------------------------------------------===//
3192 
3193 template <int N>
3194 static void bindShapeDims(ShapedType shapedType) {}
3195 
3196 template <int N, typename IntTy, typename... IntTy2>
3197 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3198  val = shapedType.getShape()[N];
3199  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3200 }
3201 
3202 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
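/// For example, as used in `conv()` below:
///   int64_t n, w, f;
///   bindShapeDims(resShapedType, n, w, f); // n = dim 0, w = dim 1, f = dim 2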
3203 template <typename... IntTy>
3204 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3205  bindShapeDims<0>(shapedType, vals...);
3206 }
3207 
3208 namespace {
3209 /// Generate a vector implementation for either:
3210 /// ```
3211 /// Op def: ( w, kw )
3212 /// Iters: ({Par(), Red()})
3213 /// Layout: {{w + kw}, {kw}, {w}}
3214 /// ```
3215 /// kw is unrolled.
3216 ///
3217 /// or
3218 ///
3219 /// ```
3220 /// Op def: ( n, w, c, kw, f )
3221 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3222 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3223 /// ```
3224 /// kw is unrolled, w is unrolled iff dilationW > 1.
3225 ///
3226 /// or
3227 ///
3228 /// ```
3229 /// Op def: ( n, c, w, f, kw )
3230 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3231 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3232 /// ```
3233 /// kw is unrolled, w is unrolled iff dilationW > 1.
3234 ///
3235 /// or
3236 ///
3237 /// ```
3238 /// Op def: ( n, w, c, kw )
3239 /// Iters: ({Par(), Par(), Par(), Red()})
3240 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3241 /// ```
3242 /// kw is unrolled, w is unrolled iff dilationW > 1.
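///
/// As an illustration, a typical op matched by the second (Nwc) case above is
/// (shapes chosen arbitrarily):
/// ```
///   %res = linalg.conv_1d_nwc_wcf
///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
///     ins(%in, %filter : tensor<4x18x3xf32>, tensor<3x3x8xf32>)
///     outs(%out : tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
/// ```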
3243 struct Conv1DGenerator
3244  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3245  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3246  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3247 
3248  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3249  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3250  resShaped = linalgOp.getDpsInitOperand(0)->get();
3251  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3252  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3253  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3254 
3255  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3256  redOp = reduceOp->getName().getIdentifier();
3257 
3258  setConvOperationKind(reduceOp);
3259 
3260  auto maybeKind = getCombinerOpKind(reduceOp);
3261  reductionKind = maybeKind.value();
3262 
3263  // The ConvolutionOpInterface gives us guarantees of existence for
3264  // strides/dilations. However, we do not need to rely on those; we simply
3265  // use them if present and otherwise fall back to the defaults, letting the
3266  // generic conv matcher in the ConvGenerator succeed or fail.
3267  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3268  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3269  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3270  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3271  }
3272 
3273  /// Generate a vector implementation for:
3274  /// ```
3275  /// Op def: ( w, kw )
3276  /// Iters: ({Par(), Red()})
3277  /// Layout: {{w + kw}, {kw}, {w}}
3278  /// ```
3279  /// kw is always unrolled.
3280  ///
3281  /// or
3282  ///
3283  /// ```
3284  /// Op def: ( n, w, c, kw, f )
3285  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3286  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3287  /// ```
3288  /// kw is always unrolled.
3289  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3290  /// > 1.
3291  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3292  int64_t nSize, wSize, cSize, kwSize, fSize;
3293  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3294  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3295  switch (conv1DOpOrder) {
3296  case Conv1DOpOrder::W:
3297  // Initialize unused dimensions
3298  nSize = fSize = cSize = 0;
3299  // out{W}
3300  bindShapeDims(resShapedType, wSize);
3301  // kernel{kw}
3302  bindShapeDims(rhsShapedType, kwSize);
3303  lhsShape = {// iw = ow + kw - 1
3304  // (i.e. 16 convolved with 3 -> 14)
3305  (wSize + kwSize - 1)};
3306  rhsShape = {kwSize};
3307  resShape = {wSize};
3308  break;
3309  case Conv1DOpOrder::Nwc:
3310  // out{n, w, f}
3311  bindShapeDims(resShapedType, nSize, wSize, fSize);
3312  switch (oper) {
3313  case ConvOperationKind::Conv:
3314  // kernel{kw, c, f}
3315  bindShapeDims(rhsShapedType, kwSize, cSize);
3316  break;
3317  case ConvOperationKind::Pool:
3318  // kernel{kw}
3319  bindShapeDims(rhsShapedType, kwSize);
3320  cSize = fSize;
3321  break;
3322  }
3323  lhsShape = {nSize,
3324  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
3325  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3326  // Perform the proper inclusive -> exclusive -> inclusive.
3327  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3328  1,
3329  cSize};
3330  switch (oper) {
3331  case ConvOperationKind::Conv:
3332  rhsShape = {kwSize, cSize, fSize};
3333  break;
3334  case ConvOperationKind::Pool:
3335  rhsShape = {kwSize};
3336  break;
3337  }
3338  resShape = {nSize, wSize, fSize};
3339  break;
3340  case Conv1DOpOrder::Ncw:
3341  // out{n, f, w}
3342  bindShapeDims(resShapedType, nSize, fSize, wSize);
3343  switch (oper) {
3344  case ConvOperationKind::Conv:
3345  // kernel{f, c, kw}
3346  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3347  break;
3348  case ConvOperationKind::Pool:
3349  // kernel{kw}
3350  bindShapeDims(rhsShapedType, kwSize);
3351  cSize = fSize;
3352  break;
3353  }
3354  lhsShape = {nSize, cSize,
3355  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
3356  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3357  // Perform the proper inclusive -> exclusive -> inclusive.
3358  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3359  1};
3360  switch (oper) {
3361  case ConvOperationKind::Conv:
3362  rhsShape = {fSize, cSize, kwSize};
3363  break;
3364  case ConvOperationKind::Pool:
3365  rhsShape = {kwSize};
3366  break;
3367  }
3368  resShape = {nSize, fSize, wSize};
3369  break;
3370  }
3371 
3372  vector::TransferWriteOp write;
3373  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3374 
3375  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3376  // When strideW == 1, we can batch the contiguous loads and avoid
3377  // unrolling
3378  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3379 
3380  Type lhsEltType = lhsShapedType.getElementType();
3381  Type rhsEltType = rhsShapedType.getElementType();
3382  Type resEltType = resShapedType.getElementType();
3383  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3384  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3385  auto resType = VectorType::get(resShape, resEltType);
3386  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3387  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3388  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3389  SmallVector<Value> resPadding(resShape.size(), zero);
3390 
3391  // Read the whole lhs, rhs and res in one shot (with zero padding).
3392  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3393  lhsPadding);
3394  // This is needed only for Conv.
3395  Value rhs = nullptr;
3396  if (oper == ConvOperationKind::Conv)
3397  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3398  rhsPadding);
3399  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3400  resPadding);
3401 
3402  // The base vectorization case for channeled convolution is input:
3403  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3404  // vectorization case, we pre-transpose the input, weight, and output.
3405  switch (conv1DOpOrder) {
3406  case Conv1DOpOrder::W:
3407  case Conv1DOpOrder::Nwc:
3408  // Base case, so no transposes necessary.
3409  break;
3410  case Conv1DOpOrder::Ncw: {
3411  // To match base vectorization case, we pre-transpose current case.
3412  // ncw -> nwc
3413  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3414  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3415  // fcw -> wcf
3416  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3417 
3418  // This is needed only for Conv.
3419  if (oper == ConvOperationKind::Conv)
3420  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3421  // nfw -> nwf
3422  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3423  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3424  break;
3425  }
3426  }
3427 
3428  //===------------------------------------------------------------------===//
3429  // Begin vector-only rewrite part
3430  //===------------------------------------------------------------------===//
3431  // Unroll along kw and read slices of lhs and rhs.
3432  SmallVector<Value> lhsVals, rhsVals, resVals;
3433  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3434  kwSize, strideW, dilationW, wSizeStep,
3435  isSingleChanneled);
3436  // Do not do for pooling.
3437  if (oper == ConvOperationKind::Conv)
3438  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3439  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3440  wSizeStep, isSingleChanneled);
3441 
3442  auto linearIndex = [&](int64_t kw, int64_t w) {
3443  return kw * (wSize / wSizeStep) + w;
3444  };
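    // E.g. when w is not unrolled (wSizeStep == wSize), linearIndex(kw, 0) == kw;
    // when w is unrolled (wSizeStep == 1), linearIndex(kw, w) == kw * wSize + w.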
3445 
3446  // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3447  // or an outer product for non-channeled convolution, or a simple arith
3448  // operation for pooling.
3449  for (int64_t kw = 0; kw < kwSize; ++kw) {
3450  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3451  switch (oper) {
3452  case ConvOperationKind::Conv:
3453  if (isSingleChanneled) {
3454  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3455  lhsVals[linearIndex(kw, w)],
3456  rhsVals[kw], resVals[w]);
3457  } else {
3458  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3459  lhsVals[linearIndex(kw, w)],
3460  rhsVals[kw], resVals[w]);
3461  }
3462  break;
3463  case ConvOperationKind::Pool:
3464  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3465  resVals[w]);
3466  break;
3467  }
3468  }
3469  }
3470 
3471  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3472  isSingleChanneled);
3473  //===------------------------------------------------------------------===//
3474  // End vector-only rewrite part
3475  //===------------------------------------------------------------------===//
3476 
3477  // The base vectorization case for channeled convolution is output:
3478  // {n,w,f}. To reuse the result from the base pattern vectorization case,
3479  // we post-transpose the base case result.
3480  switch (conv1DOpOrder) {
3481  case Conv1DOpOrder::W:
3482  case Conv1DOpOrder::Nwc:
3483  // Base case, so no transposes necessary.
3484  break;
3485  case Conv1DOpOrder::Ncw: {
3486  // nwf -> nfw
3487  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3488  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3489  break;
3490  }
3491  }
3492 
3493  return rewriter
3494  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3495  .getOperation();
3496  }
3497 
3498  // Take a value and widen it to have the same element type as `ty`.
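  // For example, an integer operand is converted with arith.sitofp when `ty`
  // has a float element type, widened with arith.extsi when `ty` has a wider
  // integer element type, and a narrower float is widened with arith.extf.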
3499  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3500  const Type srcElementType = getElementTypeOrSelf(val.getType());
3501  const Type dstElementType = getElementTypeOrSelf(ty);
3502  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3503  if (srcElementType == dstElementType)
3504  return val;
3505 
3506  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3507  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3508  const Type dstType =
3509  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3510 
3511  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3512  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3513  }
3514 
3515  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3516  srcWidth < dstWidth)
3517  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3518 
3519  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3520  srcWidth < dstWidth)
3521  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3522 
3523  assert(false && "unhandled promotion case");
3524  return nullptr;
3525  }
3526 
3527  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3528  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3529  Value lhs, Value rhs, Value res) {
3530  vector::IteratorType par = vector::IteratorType::parallel;
3531  vector::IteratorType red = vector::IteratorType::reduction;
3532  AffineExpr n, w, f, c;
3533  bindDims(ctx, n, w, f, c);
3534  lhs = promote(rewriter, loc, lhs, res.getType());
3535  rhs = promote(rewriter, loc, rhs, res.getType());
3536  auto contractionOp = rewriter.create<vector::ContractionOp>(
3537      loc, lhs, rhs, res,
3538      /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3539      /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3540  contractionOp.setKind(reductionKind);
3541  return contractionOp;
3542  }
3543 
3544  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3545  // convolution.
3546  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3547  Value lhs, Value rhs, Value res) {
3548  return rewriter.create<vector::OuterProductOp>(
3549  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3550  }
3551 
3552  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3553  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3554  Value res) {
3555  if (isPoolExt)
3556  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3557  return rewriter
3558  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3559  ->getResult(0);
3560  }
3561 
3562  /// Generate a vector implementation for:
3563  /// ```
3564  /// Op def: ( n, w, c, kw)
3565  /// Iters: ({Par(), Par(), Par(), Red()})
3566  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3567  /// ```
3568  /// kw is always unrolled.
3569  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3570  /// > 1.
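  ///
  /// As an illustration, a typical op handled here is (shapes chosen
  /// arbitrarily):
  /// ```
  ///   %res = linalg.depthwise_conv_1d_nwc_wc
  ///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  ///     ins(%in, %filter : tensor<1x10x8xf32>, tensor<3x8xf32>)
  ///     outs(%out : tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  /// ```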
3571  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3572  bool channelDimScalableFlag,
3573  bool flatten) {
3574  bool scalableChDim = false;
3575  bool useMasking = false;
3576  int64_t nSize, wSize, cSize, kwSize;
3577  // kernel{kw, c}
3578  bindShapeDims(rhsShapedType, kwSize, cSize);
3579  if (ShapedType::isDynamic(cSize)) {
3580  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3581  cSize = channelDimVecSize;
3582  // Scalable vectors are only used when both conditions are met:
3583  // 1. channel dim is dynamic
3584  // 2. channelDimScalableFlag is set
3585  scalableChDim = channelDimScalableFlag;
3586  useMasking = true;
3587  }
3588 
3589  assert(!(useMasking && flatten) &&
3590  "Unsupported flattened conv with dynamic shapes");
3591 
3592  // out{n, w, c}
3593  bindShapeDims(resShapedType, nSize, wSize);
3594 
3595  vector::TransferWriteOp write;
3596  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3597 
3598  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3599  // When strideW == 1, we can batch the contiguous loads and avoid
3600  // unrolling
3601  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3602 
3603  Type lhsEltType = lhsShapedType.getElementType();
3604  Type rhsEltType = rhsShapedType.getElementType();
3605  Type resEltType = resShapedType.getElementType();
3606  VectorType lhsType = VectorType::get(
3607  {nSize,
3608  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
3609  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3610  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3611  cSize},
3612  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3613  VectorType rhsType =
3614  VectorType::get({kwSize, cSize}, rhsEltType,
3615  /*scalableDims=*/{false, scalableChDim});
3616  VectorType resType =
3617  VectorType::get({nSize, wSize, cSize}, resEltType,
3618  /*scalableDims=*/{false, false, scalableChDim});
3619 
3620  // Masks the xfer Op along the channel dim, iff masking is required (i.e.
3621  // the channel dim is dynamic).
3622  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3623  ArrayRef<bool> scalableDims,
3624  Operation *opToMask) {
3625  if (!useMasking)
3626  return opToMask;
3627  auto maskType =
3628  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3629 
3630  SmallVector<bool> inBounds(maskShape.size(), true);
3631  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3632  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3633  rewriter.getBoolArrayAttr(inBounds));
3634 
3635       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3636           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3637 
3638  Value maskOp =
3639  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3640 
3641  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3642  };
3643 
3644  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3645  // 0].
3646  Value lhs = rewriter.create<vector::TransferReadOp>(
3647  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3648  auto maybeMaskedLhs = maybeMaskXferOp(
3649  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3650 
3651  // Read rhs slice of size {kw, c} @ [0, 0].
3652  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3653  ValueRange{zero, zero});
3654  auto maybeMaskedRhs = maybeMaskXferOp(
3655  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3656 
3657  // Read res slice of size {n, w, c} @ [0, 0, 0].
3658  Value res = rewriter.create<vector::TransferReadOp>(
3659  loc, resType, resShaped, ValueRange{zero, zero, zero});
3660  auto maybeMaskedRes = maybeMaskXferOp(
3661  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3662 
3663  //===------------------------------------------------------------------===//
3664  // Begin vector-only rewrite part
3665  //===------------------------------------------------------------------===//
3666  // Unroll along kw and read slices of lhs and rhs.
3667  SmallVector<Value> lhsVals, rhsVals, resVals;
3668  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3669  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3670 
3671  // Extract lhs slice of size {n, wSizeStep, c}
3672  // @ [0, sw * w + dw * kw, 0].
3673  for (int64_t kw = 0; kw < kwSize; ++kw) {
3674  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3675  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3676  loc, maybeMaskedLhs->getResult(0),
3677  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3678  inOutSliceSizes, inOutStrides));
3679  }
3680  }
3681  // Extract rhs slice of size {c} @ [kw].
3682  for (int64_t kw = 0; kw < kwSize; ++kw) {
3683  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3684  loc, maybeMaskedRhs->getResult(0),
3685  /*offsets=*/ArrayRef<int64_t>{kw}));
3686  }
3687  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3688  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3689  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3690  loc, maybeMaskedRes->getResult(0),
3691  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3692  inOutStrides));
3693  }
3694 
3695  auto linearIndex = [&](int64_t kw, int64_t w) {
3696  return kw * (wSize / wSizeStep) + w;
3697  };
3698 
3699  // Note - the scalable flags are ignored as flattening combined with
3700  // scalable vectorization is not supported.
3701  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3702  auto lhsTypeAfterFlattening =
3703  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3704  auto resTypeAfterFlattening =
3705  VectorType::get(inOutFlattenSliceSizes, resEltType);
3706 
3707  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3708  for (int64_t kw = 0; kw < kwSize; ++kw) {
3709  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3710  Value lhsVal = lhsVals[linearIndex(kw, w)];
3711  Value resVal = resVals[w];
3712  if (flatten) {
3713  // Flatten the input and output vectors (collapse the channel
3714  // dimension)
3715  lhsVal = rewriter.create<vector::ShapeCastOp>(
3716  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3717  resVal = rewriter.create<vector::ShapeCastOp>(
3718  loc, resTypeAfterFlattening, resVals[w]);
3719  }
3720  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3721  rhsVals[kw], resVal, flatten);
3722  if (flatten) {
3723  // Un-flatten the output vector (restore the channel dimension)
3724  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3725  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3726  }
3727  }
3728  }
3729 
3730  // It's possible we failed to create the FMA.
3731  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3732  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3733  for (auto &collection :
3734  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3735  for (Value v : collection)
3736  rewriter.eraseOp(v.getDefiningOp());
3737  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3738  }
3739 
3740  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3741  // This does not depend on kw.
3742  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3743  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3744  loc, resVals[w], maybeMaskedRes->getResult(0),
3745  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3746  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3747  }
3748  //===------------------------------------------------------------------===//
3749  // End vector-only rewrite part
3750  //===------------------------------------------------------------------===//
3751 
3752  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3753  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3754  loc, maybeMaskedRes->getResult(0), resShaped,
3755  ValueRange{zero, zero, zero});
3756  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3757  resOut);
3758  }
3759 
3760  /// Lower:
3761  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3762  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3763  /// to MulAcc.
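  /// For example (illustrative), with flatten = true, rhs : vector<4xf32> and
  /// res : vector<1x8xf32>, resSize / rhsSize == 2, so the filter is repeated
  /// via a shuffle with indices [0, 1, 2, 3, 0, 1, 2, 3] and then broadcast to
  /// the shape of `res` before the multiply-accumulate.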
3764  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3765  Value lhs, Value rhs, Value res,
3766  bool flatten) {
3767  auto rhsTy = cast<ShapedType>(rhs.getType());
3768  auto resTy = cast<ShapedType>(res.getType());
3769 
3770  // TODO(suderman): Change this to use a vector.ima intrinsic.
3771  lhs = promote(rewriter, loc, lhs, resTy);
3772 
3773  if (flatten) {
3774  // NOTE: The following logic won't work for scalable vectors. For this
3775  // reason, "flattening" is not supported when shapes are dynamic (this
3776  // should be captured by one of the pre-conditions).
3777 
3778  // There are two options for handling the filter:
3779  // * shape_cast(broadcast(filter))
3780  // * broadcast(shuffle(filter))
3781  // Opt for the option without shape_cast to simplify the codegen.
3782  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3783  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3784 
3785  SmallVector<int64_t, 16> indices;
3786  for (int i = 0; i < resSize / rhsSize; ++i) {
3787  for (int j = 0; j < rhsSize; ++j)
3788  indices.push_back(j);
3789  }
3790 
3791  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3792  }
3793  // Broadcast the filter to match the output vector
3794  rhs = rewriter.create<vector::BroadcastOp>(
3795  loc, resTy.clone(rhsTy.getElementType()), rhs);
3796 
3797  rhs = promote(rewriter, loc, rhs, resTy);
3798 
3799  if (!lhs || !rhs)
3800  return nullptr;
3801 
3802  if (isa<FloatType>(resTy.getElementType()))
3803  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3804 
3805  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3806  return rewriter.create<arith::AddIOp>(loc, mul, res);
3807  }
3808 
3809  /// Entry point for non-channeled convolution:
3810  /// {{w + kw}, {kw}, {w}}
3811  FailureOr<Operation *> generateNonChanneledConv() {
3812  AffineExpr w, kw;
3813  bindDims(ctx, w, kw);
3814  if (!iters({Par(), Red()}))
3815  return rewriter.notifyMatchFailure(op,
3816  "failed to match conv::W 1-par 1-red");
3817 
3818  // No transposition needed.
3819  if (layout({/*lhsIndex*/ {w + kw},
3820  /*rhsIndex*/ {kw},
3821  /*resIndex*/ {w}}))
3822  return conv(Conv1DOpOrder::W);
3823 
3824  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3825  }
3826 
3827  /// Entry point that transposes into the common form:
3828  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3829  FailureOr<Operation *> generateNwcConv() {
3830  AffineExpr n, w, f, kw, c;
3831  bindDims(ctx, n, w, f, kw, c);
3832  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3833  return rewriter.notifyMatchFailure(
3834  op, "failed to match conv::Nwc 3-par 2-red");
3835 
3836  // No transposition needed.
3837  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3838  /*rhsIndex*/ {kw, c, f},
3839  /*resIndex*/ {n, w, f}}))
3840  return conv(Conv1DOpOrder::Nwc);
3841 
3842  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3843  }
3844 
3845  /// Entry point that transposes into the common form:
3846  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3847  FailureOr<Operation *> generateNcwConv() {
3848  AffineExpr n, w, f, kw, c;
3849  bindDims(ctx, n, f, w, c, kw);
3850  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3851  return rewriter.notifyMatchFailure(
3852  op, "failed to match conv::Ncw 3-par 2-red");
3853 
3854  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3855  /*rhsIndex*/ {f, c, kw},
3856  /*resIndex*/ {n, f, w}}))
3857  return conv(Conv1DOpOrder::Ncw);
3858 
3859  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3860  }
3861 
3862  /// Entry point that transposes into the common form:
3863  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3864  FailureOr<Operation *> generateNwcPooling() {
3865  AffineExpr n, w, c, kw;
3866  bindDims(ctx, n, w, c, kw);
3867  if (!iters({Par(), Par(), Par(), Red()}))
3868  return rewriter.notifyMatchFailure(op,
3869  "failed to match pooling 3-par 1-red");
3870 
3871  // No transposition needed.
3872  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3873  /*rhsIndex*/ {kw},
3874  /*resIndex*/ {n, w, c}}))
3875  return conv(Conv1DOpOrder::Nwc);
3876 
3877  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3878  }
3879 
3880  /// Entry point that transposes into the common form:
3881  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3882  FailureOr<Operation *> generateNcwPooling() {
3883  AffineExpr n, w, c, kw;
3884  bindDims(ctx, n, c, w, kw);
3885  if (!iters({Par(), Par(), Par(), Red()}))
3886  return rewriter.notifyMatchFailure(op,
3887  "failed to match pooling 3-par 1-red");
3888 
3889  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3890  /*rhsIndex*/ {kw},
3891  /*resIndex*/ {n, c, w}}))
3892  return conv(Conv1DOpOrder::Ncw);
3893 
3894  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3895  }
3896 
3897  /// Entry point that transposes into the common form:
3898  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3899  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3900  bool vecChDimScalableFlag = false,
3901  bool flatten = false) {
3902  AffineExpr n, w, c, kw;
3903  bindDims(ctx, n, w, c, kw);
3904  if (!iters({Par(), Par(), Par(), Red()}))
3905  return rewriter.notifyMatchFailure(
3906  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3907 
3908  // No transposition needed.
3909  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3910  /*rhsIndex*/ {kw, c},
3911  /*resIndex*/ {n, w, c}}))
3912  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3913 
3914  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3915  }
3916 
3917 private:
3918  ConvOperationKind oper = ConvOperationKind::Conv;
3919  StringAttr redOp;
3920  StringAttr poolExtOp;
3921  bool isPoolExt = false;
3922  int strideW, dilationW;
3923  Value lhsShaped, rhsShaped, resShaped;
3924  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3925  vector::CombiningKind reductionKind;
3926 
3927  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3928  void setConvOperationKind(Operation *reduceOp) {
3929  int numBlockArguments =
3930  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3931  if (numBlockArguments == 1) {
3932  // This will be a convolution if the feeder is a MulOp.
3933  // A strength-reduced version of MulOp for the i1 type is AndOp, which is
3934  // also supported. Otherwise, it can be pooling. This strength-reduction
3935  // logic lives in the `buildBinaryFn` helper in the Linalg dialect.
3936  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3937  llvm::IsaPred<BlockArgument>);
3938  Operation *feedOp = (*feedValIt).getDefiningOp();
3939  if (isCastOfBlockArgument(feedOp)) {
3940  oper = ConvOperationKind::Pool;
3941  isPoolExt = true;
3942  poolExtOp = feedOp->getName().getIdentifier();
3943  return;
3944  }
3945  oper = ConvOperationKind::Conv;
3946  return;
3947  }
3948  // numBlockArguments == 2 and this is a pooling op.
3949  oper = ConvOperationKind::Pool;
3950  isPoolExt = false;
3951  }
3952 };
3953 } // namespace
3954 
3955 /// Helper function to vectorize a LinalgOp with convolution semantics.
3956 // TODO: extend the generic vectorization to support windows and drop this.
3957 static FailureOr<Operation *> vectorizeConvolution(
3958  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3959  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3960  Conv1DGenerator conv1dGen(rewriter, op);
3961  auto res = conv1dGen.generateNonChanneledConv();
3962  if (succeeded(res))
3963  return res;
3964  res = conv1dGen.generateNwcConv();
3965  if (succeeded(res))
3966  return res;
3967  res = conv1dGen.generateNcwConv();
3968  if (succeeded(res))
3969  return res;
3970  res = conv1dGen.generateNwcPooling();
3971  if (succeeded(res))
3972  return res;
3973  res = conv1dGen.generateNcwPooling();
3974  if (succeeded(res))
3975  return res;
3976 
3977  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3978  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3979  // masked/scalable) is the channel dim (i.e. the trailing dim).
3980  uint64_t vecChDimSize = ShapedType::kDynamic;
3981  bool vecChDimScalableFlag = false;
3982  if (!inputVecSizes.empty()) {
3983  // Only use the input vector size corresponding to the channel dim. Other
3984  // vector dims will be inferred from the Ops.
3985  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3986  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3987  "Not a 1D depthwise conv!");
3988  size_t chDimIdx =
3989       llvm::TypeSwitch<Operation *, size_t>(op.getOperation())
3990           .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3991  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3992 
3993  vecChDimSize = inputVecSizes[chDimIdx];
3994  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3995  }
3996  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3997  flatten1DDepthwiseConv);
3998 }
3999 
4000 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4001   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4002
4003  LogicalResult matchAndRewrite(LinalgOp op,
4004  PatternRewriter &rewriter) const override {
4005  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4006  if (failed(resultOrFail))
4007  return failure();
4008  Operation *newOp = *resultOrFail;
4009  if (newOp->getNumResults() == 0) {
4010  rewriter.eraseOp(op.getOperation());
4011  return success();
4012  }
4013  assert(newOp->getNumResults() == 1 && "expected single result");
4014  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4015  return success();
4016  }
4017 };
4018 
4019 void mlir::linalg::populateConvolutionVectorizationPatterns(
4020     RewritePatternSet &patterns, PatternBenefit benefit) {
4021   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4022 }
SmallVector< int64_t > outerDimsPerm
Definition: LinalgOps.cpp:4511
SmallVector< OpFoldResult > innerTiles
Definition: LinalgOps.cpp:4510
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:4509
static std::optional< VectorShape > vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static std::optional< ConvOperationKind > getConvOperationKind(Operation *reduceOp)
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a linalg::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize linalg::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::InsertSliceOp with:
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static bool isCastOfBlockArgument(Operation *op)
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
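A minimal sketch of the contract these enumerators describe, written against the CustomVectorizationHook and VectorizationResult entries listed elsewhere on this page and assuming the surrounding file's context (no extra includes shown); the op check is hypothetical:
// Hypothetical hook matching the CustomVectorizationHook signature: report
// Failure for ops it does not handle, NewOp once a replacement was built.
static VectorizationResult exampleHook(Operation *op, const IRMapping &bvm) {
  (void)bvm;
  VectorizationResult res;
  if (!isa<linalg::IndexOp>(op)) {
    res.status = VectorizationStatus::Failure; // defer to other hooks
    res.newOp = nullptr;
    return res;
  }
  Operation *replacement = /* build the vector form of `op` here */ nullptr;
  res.status = VectorizationStatus::NewOp; // results replace op's results
  res.newOp = replacement;
  return res;
}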
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Checks that the inner tiles are static/constant.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:606
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
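A minimal sketch, assuming only a live MLIRContext, that exercises a few of the AffineMap helpers listed above:
#include "mlir/IR/AffineMap.h"
#include <cassert>
using namespace mlir;

static void affineMapSketch(MLIRContext *ctx) {
  // (d0, d1, d2) -> (d0, d1, d2)
  AffineMap id = AffineMap::getMultiDimIdentityMap(/*numDims=*/3, ctx);
  // (d0, d1, d2) -> (d2, d0, d1)
  AffineMap perm = AffineMap::getPermutationMap({2u, 0u, 1u}, ctx);
  assert(perm.isPermutation() && perm.isProjectedPermutation());
  // Composing with the identity leaves the permutation unchanged.
  AffineMap composed = perm.compose(id);
  (void)composed;
}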
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:307
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:266
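A minimal sketch of the Builder convenience getters listed above; the builder is assumed to come from a surrounding rewrite:
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

static void builderSketch(OpBuilder &b) {
  IndexType idxTy = b.getIndexType();
  IntegerType i1Ty = b.getI1Type();
  TypedAttr zero = b.getZeroAttr(idxTy);            // 0 : index
  ArrayAttr inBounds = b.getBoolArrayAttr({true, false});
  AffineMap id = b.getMultiDimIdentityMap(/*rank=*/2);
  (void)i1Ty; (void)zero; (void)inBounds; (void)id;
}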
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
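A minimal sketch of the two IRMapping members listed above; oldVal and newVal are hypothetical values taken from an op being rewritten:
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
using namespace mlir;

static Value remapSketch(Value oldVal, Value newVal) {
  IRMapping bvm;
  bvm.map(oldVal, newVal);   // record oldVal -> newVal
  return bvm.lookup(oldVal); // returns newVal
}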
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
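A minimal sketch of the insertion-point management listed above: the RAII guard restores the previous insertion point on scope exit, and the op created is an arbitrary arith constant chosen only for illustration:
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
using namespace mlir;

static void insertionSketch(OpBuilder &b, Block *block, Location loc) {
  OpBuilder::InsertionGuard guard(b);          // restored when guard dies
  b.setInsertionPoint(block, block->begin());  // insert at the block start
  b.create<arith::ConstantIndexOp>(loc, 0);    // 0 : index
}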
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:362
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:670
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
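A minimal sketch of how the rewriter hooks listed above are typically used inside a pattern; the matched trunci/extsi pair and the folding logic are hypothetical:
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;

struct FoldTruncOfExtSketch : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::TruncIOp op,
                                PatternRewriter &rewriter) const override {
    auto ext = op.getIn().getDefiningOp<arith::ExtSIOp>();
    if (!ext)
      return rewriter.notifyMatchFailure(op, "source is not an extsi");
    if (ext.getIn().getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "types do not round-trip");
    // trunci(extsi(x)) -> x when the narrow types match.
    rewriter.replaceOp(op, ext.getIn());
    return success();
  }
};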
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type; asserts on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
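A minimal sketch combining the Value accessors listed above with OpOperand::getOwner from further below; it counts the uses of a value that live in its defining block:
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
using namespace mlir;

static unsigned countUsesInDefiningBlock(Value v) {
  Operation *def = v.getDefiningOp(); // null for block arguments
  unsigned count = 0;
  for (OpOperand &use : v.getUses())
    if (def && use.getOwner()->getBlock() == def->getBlock())
      ++count;
  return count;
}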
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:219
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:203
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:242
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
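A minimal sketch that collects the two pattern-population entry points listed above and hands them to a pattern driver; the choice of the greedy driver (and its header) is an assumption, not something this file prescribes:
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

static LogicalResult applyVectorizationPatternsSketch(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populatePadOpVectorizationPatterns(patterns);
  linalg::populateConvolutionVectorizationPatterns(patterns);
  // Driver choice is illustrative; any consumer of a FrozenRewritePatternSet
  // would do.
  return applyPatternsGreedily(root, std::move(patterns));
}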
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
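A minimal sketch using the vectorizeOpPrecondition and vectorize signatures exactly as listed above; the vector sizes, the scalability flags, and the Transforms.h header are assumptions for illustration:
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;

static LogicalResult vectorizeSketch(RewriterBase &rewriter, Operation *op) {
  SmallVector<int64_t> vectorSizes = {8, 16};      // hypothetical sizes
  SmallVector<bool> scalableDims = {false, false};
  if (!linalg::hasVectorizationImpl(op) ||
      failed(linalg::vectorizeOpPrecondition(op, vectorSizes, scalableDims)))
    return failure();
  return linalg::vectorize(rewriter, op, vectorSizes, scalableDims);
}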
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:649
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2498
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
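A minimal sketch of createReadOrMaskedRead, following the signature listed above; the read shape and padding value are hypothetical, and the VectorUtils.h header is an assumption:
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
using namespace mlir;

static Value readSketch(OpBuilder &b, Location loc, Value source,
                        Value padValue) {
  SmallVector<int64_t> readShape = {4, 8};
  // Emits a transfer_read of vector<4x8> from `source`, masked when the
  // corresponding source dims are dynamic.
  return vector::createReadOrMaskedRead(b, loc, source, readShape, padValue,
                                        /*useInBoundsInsteadOfMasking=*/false);
}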
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:815
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:70
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
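A minimal sketch of the permutation-vector helpers listed above; the IndexingUtils.h header and the example permutation are assumptions:
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;

static void permutationSketch() {
  SmallVector<int64_t> perm = {2, 0, 1};
  // inverse[perm[i]] == i, so inverse == {1, 2, 0}.
  SmallVector<int64_t> inverse = invertPermutationVector(perm);
  SmallVector<int64_t> shape = {4, 8, 16};
  applyPermutationToVector(shape, perm); // shape becomes {16, 4, 8}
  (void)inverse;
}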
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:334
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:336
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.