1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Support/LLVM.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/DebugLog.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 /// Try to vectorize `convOp` as a convolution.
53 static FailureOr<Operation *>
54 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
55  ArrayRef<int64_t> inputVecSizes = {},
56  ArrayRef<bool> inputVecScalableFlags = {},
57  bool flatten1DDepthwiseConv = false);
58 
59 /// Vectorize tensor::InsertSliceOp with:
60 /// * vector::TransferReadOp + vector::TransferWriteOp
61 /// The vector sizes are either:
62 /// * user-provided in `inputVectorSizes`, or
63 /// * inferred from the static dims in the input and output tensors.
64 /// Bails out if:
65 /// * vector sizes are not user-provided, and
66 /// * at least one dim is dynamic (in both the input and output tensors).
67 ///
68 /// Before:
69 /// !t_in_type = tensor<1x2x3xf32>
70 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
71 /// !v_type = vector<1x2x3xf32>
72 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
73 /// into !t_out_type
74 /// After:
75 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
76 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
77 static LogicalResult
78 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
79  ArrayRef<int64_t> inputVectorSizes,
80  SmallVectorImpl<Value> &newResults);
81 
82 /// Returns the effective Pad value for the input op, provided it's a scalar.
83 ///
84 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
85 /// this Op performs padding, retrieve the padding value provided that it's
86 /// a scalar and static/fixed for all the padded values. Returns an empty value
87 /// otherwise.
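/// Illustrative sketch (hypothetical example, values assumed): for
///   %p = tensor.pad %src low[0, 0] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<2x2xf32> to tensor<3x3xf32>
/// the effective pad value is %cst, provided %cst is defined outside the pad
/// region (i.e. the same scalar is used for every padded element).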
88 static Value getStaticPadVal(Operation *op);
89 
90 /// Return the unique instance of OpType in `block` if it is indeed unique.
91 /// Return null if none or more than 1 instances exist.
92 template <typename OpType>
93 static OpType getSingleOpOfType(Block &block) {
94  OpType res;
95  block.walk([&](OpType op) {
96  if (res) {
97  res = nullptr;
98  return WalkResult::interrupt();
99  }
100  res = op;
101  return WalkResult::advance();
102  });
103  return res;
104 }
105 
106 /// Helper function to extract the input slices after filter is unrolled along
107 /// kw.
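/// Illustrative sketch (hypothetical sizes, channeled case): with kwSize = 2,
/// wSize = 4, wSizeStep = 2, strideW = 1 and dilationW = 2, slices of size
/// {n, 2, c} are extracted at offsets [0, 0, 0] and [0, 2, 0] for kw = 0, and
/// [0, 2, 0] and [0, 4, 0] for kw = 1, i.e. at [0, w * strideW + kw * dilationW, 0].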
108 static SmallVector<Value>
109 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
110  int64_t nSize, int64_t wSize, int64_t cSize,
111  int64_t kwSize, int strideW, int dilationW,
112  int64_t wSizeStep, bool isSingleChanneled) {
113  SmallVector<Value> result;
114  if (isSingleChanneled) {
115  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
116  // convolution.
117  SmallVector<int64_t> sizes = {wSizeStep};
118  SmallVector<int64_t> strides = {1};
119  for (int64_t kw = 0; kw < kwSize; ++kw) {
120  for (int64_t w = 0; w < wSize; w += wSizeStep) {
121  result.push_back(vector::ExtractStridedSliceOp::create(
122  rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
123  strides));
124  }
125  }
126  } else {
127  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128  // for channeled convolution.
129  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130  SmallVector<int64_t> strides = {1, 1, 1};
131  for (int64_t kw = 0; kw < kwSize; ++kw) {
132  for (int64_t w = 0; w < wSize; w += wSizeStep) {
133  result.push_back(vector::ExtractStridedSliceOp::create(
134  rewriter, loc, input,
135  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136  sizes, strides));
137  }
138  }
139  }
140  return result;
141 }
142 
143 /// Helper function to extract the filter slices after filter is unrolled along
144 /// kw.
145 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
146  Location loc, Value filter,
147  int64_t kwSize) {
148  SmallVector<Value> result;
149  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
151  for (int64_t kw = 0; kw < kwSize; ++kw) {
152  result.push_back(vector::ExtractOp::create(
153  rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154  }
155  return result;
156 }
157 
158 /// Helper function to extract the result slices after filter is unrolled along
159 /// kw.
160 static SmallVector<Value>
161 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
162  int64_t nSize, int64_t wSize, int64_t fSize,
163  int64_t wSizeStep, bool isSingleChanneled) {
164  SmallVector<Value> result;
165  if (isSingleChanneled) {
166  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167  SmallVector<int64_t> sizes = {wSizeStep};
168  SmallVector<int64_t> strides = {1};
169  for (int64_t w = 0; w < wSize; w += wSizeStep) {
170  result.push_back(vector::ExtractStridedSliceOp::create(
171  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
172  strides));
173  }
174  } else {
175  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
176  // convolution.
177  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
178  SmallVector<int64_t> strides = {1, 1, 1};
179  for (int64_t w = 0; w < wSize; w += wSizeStep) {
180  result.push_back(vector::ExtractStridedSliceOp::create(
181  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
182  strides));
183  }
184  }
185  return result;
186 }
187 
188 /// Helper function to insert the computed result slices.
189 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190  Value res, int64_t wSize, int64_t wSizeStep,
191  SmallVectorImpl<Value> &resVals,
192  bool isSingleChanneled) {
193 
194  if (isSingleChanneled) {
195  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196  // This does not depend on kw.
197  SmallVector<int64_t> strides = {1};
198  for (int64_t w = 0; w < wSize; w += wSizeStep) {
199  res = vector::InsertStridedSliceOp::create(
200  rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
201  strides);
202  }
203  } else {
204  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205  // convolution. This does not depend on kw.
206  SmallVector<int64_t> strides = {1, 1, 1};
207  for (int64_t w = 0; w < wSize; w += wSizeStep) {
208  res = vector::InsertStridedSliceOp::create(
209  rewriter, loc, resVals[w], res,
210  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
211  }
212  }
213  return res;
214 }
215 
216 /// Contains the vectorization state and related methods used across the
217 /// vectorization process of a given operation.
218 struct VectorizationState {
219  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220 
221  /// Initializes the vectorization state, including the computation of the
222  /// canonical vector shape for vectorization.
223  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224  ArrayRef<int64_t> inputVectorSizes,
225  ArrayRef<bool> inputScalableVecDims,
226  bool assumeDynamicDimsMatchVecSizes = false);
227 
228  /// Returns the canonical vector shape used to vectorize the iteration space.
229  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
230 
231  /// Returns the vector dimensions that are scalable in the canonical vector
232  /// shape.
233  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
234 
235  /// Returns a vector type of the provided `elementType` with the canonical
236  /// vector shape and the corresponding fixed/scalable dimensions bit. If
237  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
238  /// accordingly.
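 /// Illustrative sketch (hypothetical values): with a canonical vector shape
 /// of [8, 16, 4], scalable dims [false, true, false], and a permutation map
 /// (d0, d1, d2) -> (d2, d0), this returns vector<4x8xf32> for an f32 element
 /// type (d1, the only scalable dim, is dropped by the permutation).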
239  VectorType getCanonicalVecType(
240  Type elementType,
241  std::optional<AffineMap> dimPermutation = std::nullopt) const {
242  SmallVector<int64_t> vectorShape;
243  SmallVector<bool> scalableDims;
244  if (dimPermutation.has_value()) {
245  vectorShape =
246  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
247  scalableDims =
248  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
249  } else {
250  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
252  }
253 
254  return VectorType::get(vectorShape, elementType, scalableDims);
255  }
256 
257  /// Masks an operation with the canonical vector mask if the operation needs
258  /// masking. Returns the masked operation or the original operation if masking
259  /// is not needed. If provided, the canonical mask for this operation is
260  /// permuted using `maybeIndexingMap`.
261  Operation *
262  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
263  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264 
265 private:
266  /// Initializes the iteration space static sizes using the Linalg op
267  /// information. This may become more complicated in the future.
268  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
270  }
271 
272  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
273  /// all the static and dynamic dimensions of the iteration space to be
274  /// vectorized and store them in `iterSpaceValueSizes`.
275  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276  LinalgOp linalgOp);
277 
278  /// Create or retrieve an existing mask value to mask `opToMask` in the
279  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
280  /// permuted using that permutation map. If a new mask is created, it will be
281  /// cached for future users.
282  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
283  LinalgOp linalgOp,
284  std::optional<AffineMap> maybeMaskingMap);
285 
286  /// Check whether this permutation map can be used for masking. At the
287  /// moment we only make sure that there are no broadcast dimensions, but this
288  /// might change if indexing maps evolve.
289  bool isValidMaskingMap(AffineMap maskingMap) {
290  return maskingMap.getBroadcastDims().size() == 0;
291  }
292 
293  /// Turn the input indexing map into a valid masking map.
294  ///
295  /// The input indexing map may contain "zero" results, e.g.:
296  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
297  /// Applying such maps to canonical vector shapes like this one:
298  /// (1, 16, 16, 4)
299  /// would yield an invalid vector shape like this:
300  /// (16, 16, 1, 0)
301  /// Instead, drop the broadcasting dims that make no sense for masking perm.
302  /// maps:
303  /// (d0, d1, d2, d3) -> (d2, d1, d0)
304  /// This way, the corresponding vector/mask type will be:
305  /// vector<16x16x1xty>
306  /// rather than this invalid Vector type:
307  /// vector<16x16x1x0xty>
308  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
309  return indexingMap.dropZeroResults();
310  }
311 
312  // Holds the compile-time static sizes of the iteration space to vectorize.
313  // Dynamic dimensions are represented using ShapedType::kDynamic.
314  SmallVector<int64_t> iterSpaceStaticSizes;
315 
316  /// Holds the value sizes of the iteration space to vectorize. Static
317  /// dimensions are represented by 'arith.constant' and dynamic
318  /// dimensions by 'tensor/memref.dim'.
319  SmallVector<Value> iterSpaceValueSizes;
320 
321  /// Holds the canonical vector shape used to vectorize the iteration space.
322  SmallVector<int64_t> canonicalVecShape;
323 
324  /// Holds the vector dimensions that are scalable in the canonical vector
325  /// shape.
326  SmallVector<bool> scalableVecDims;
327 
328  /// Holds the active masks for permutations of the canonical vector iteration
329  /// space.
330  DenseMap<AffineMap, Value> activeMaskCache;
331 
332  /// Global vectorization guard for the incoming rewriter. It's initialized
333  /// when the vectorization state is initialized.
334  OpBuilder::InsertionGuard rewriterGuard;
335 
336  /// Do all dynamic dims match the corresponding vector sizes?
337  ///
338  /// When a dynamic tensor/memref dimension matches the corresponding vector
339  /// dimension, masking can be safely skipped, despite the presence of dynamic
340  /// shapes. Use this flag with care and only for cases where you are
341  /// confident the assumption holds.
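 ///
 /// Illustrative sketch (hypothetical example): when vectorizing a linalg op
 /// over tensor<?xf32> with a user-provided vector size of [8], setting this
 /// flag asserts that the dynamic dim always equals 8 at runtime, so no
 /// vector.create_mask is generated for it.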
342  bool assumeDynamicDimsMatchVecSizes = false;
343 };
344 
345 LogicalResult
346 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
347  LinalgOp linalgOp) {
348  // TODO: Support 0-d vectors.
349  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350  if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
351  // Create constant index op for static dimensions.
352  iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
353  rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
354  continue;
355  }
356 
357  // Find an operand defined on this dimension of the iteration space to
358  // extract the runtime dimension size.
359  Value operand;
360  unsigned operandDimPos;
361  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
362  operandDimPos)))
363  return failure();
364 
365  Value dynamicDim =
366  linalgOp.hasPureTensorSemantics()
367  ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
368  operandDimPos)
369  : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370  operandDimPos);
371  iterSpaceValueSizes.push_back(dynamicDim);
372  }
373 
374  return success();
375 }
376 
377 /// Initializes the vectorization state, including the computation of the
378 /// canonical vector shape for vectorization.
379 // TODO: Move this to the constructor when we can remove the failure cases.
380 LogicalResult VectorizationState::initState(RewriterBase &rewriter,
381  LinalgOp linalgOp,
382  ArrayRef<int64_t> inputVectorSizes,
383  ArrayRef<bool> inputScalableVecDims,
384  bool assumeDimsMatchVec) {
385  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
386  // Initialize the insertion point.
387  rewriter.setInsertionPoint(linalgOp);
388 
389  if (!inputVectorSizes.empty()) {
390  // Get the canonical vector shape from the input vector sizes provided. This
391  // path should be taken to vectorize code with dynamic shapes and when using
392  // vector sizes greater than the iteration space sizes.
393  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394  scalableVecDims.append(inputScalableVecDims.begin(),
395  inputScalableVecDims.end());
396  } else {
397  // Compute the canonical vector shape from the operation shape. If there are
398  // dynamic shapes, the operation won't be vectorized. We assume all the
399  // vector dimensions are fixed.
400  canonicalVecShape = linalgOp.getStaticLoopRanges();
401  scalableVecDims.append(linalgOp.getNumLoops(), false);
402  }
403 
404  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
406 
407  if (ShapedType::isDynamicShape(canonicalVecShape))
408  return failure();
409 
410  // Initialize iteration space static sizes.
411  initIterSpaceStaticSizes(linalgOp);
412 
413  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
414  // all the static and dynamic dimensions of the iteration space, needed to
415  // compute a mask during vectorization.
416  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
417  return failure();
418 
419  return success();
420 }
421 
422 /// Create or retrieve an existing mask value to mask `opToMask` in the
423 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
424 /// using that permutation map. If a new mask is created, it will be cached for
425 /// future users.
426 Value VectorizationState::getOrCreateMaskFor(
427  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
428  std::optional<AffineMap> maybeMaskingMap) {
429 
430  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431  "Ill-formed masking map.");
432 
433  // No mask is needed if the operation is not maskable.
434  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
435  if (!maskableOp)
436  return Value();
437 
438  assert(!maskableOp.isMasked() &&
439  "Masking an operation that is already masked");
440 
441  // If no masking map was provided, use an identity map with the loop dims.
442  assert((!maybeMaskingMap || *maybeMaskingMap) &&
443  "Unexpected null mask permutation map");
444  AffineMap maskingMap =
445  maybeMaskingMap ? *maybeMaskingMap
446  : AffineMap::getMultiDimIdentityMap(
447  linalgOp.getNumLoops(), rewriter.getContext());
448 
449  LDBG() << "Masking map: " << maskingMap;
450 
451  // Return the active mask for the masking map of this operation if it was
452  // already created.
453  auto activeMaskIt = activeMaskCache.find(maskingMap);
454  if (activeMaskIt != activeMaskCache.end()) {
455  Value mask = activeMaskIt->second;
456  LDBG() << "Reusing mask: " << mask;
457  return mask;
458  }
459 
460  // Compute permuted projection of the iteration space to be masked and the
461  // corresponding mask shape. If the resulting iteration space dimensions are
462  // static and identical to the mask shape, masking is not needed for this
463  // operation.
464  // TODO: Improve this check. Only projected permutation indexing maps are
465  // supported.
466  SmallVector<int64_t> permutedStaticSizes =
467  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
469  auto maskShape = maskType.getShape();
470 
471  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
472 
473  if (permutedStaticSizes == maskShape) {
474  LDBG() << "Masking is not needed for masking map: " << maskingMap;
475  activeMaskCache[maskingMap] = Value();
476  return Value();
477  }
478 
479  if (assumeDynamicDimsMatchVecSizes) {
480  // While for _dynamic_ dim sizes we can _assume_ that the corresponding
481  // vector sizes match, we still need to check the _static_ dim sizes. Only
482  // then we can be 100% sure that masking is not required.
483  if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
484  [](auto it) {
485  return std::get<0>(it) == ShapedType::kDynamic
486  ? true
487  : std::get<0>(it) == std::get<1>(it);
488  })) {
489  LDBG()
490  << "Dynamic + static dimensions match vector sizes, masking is not "
491  "required.";
492  activeMaskCache[maskingMap] = Value();
493  return Value();
494  }
495  }
496 
497  // Permute the iteration space value sizes to compute the mask upper bounds.
498  SmallVector<Value> upperBounds =
499  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
500  assert(!maskShape.empty() && !upperBounds.empty() &&
501  "Masked 0-d vectors are not supported yet");
502 
503  // Create the mask based on the dimension values.
504  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505  maskType, upperBounds);
506  LDBG() << "Creating new mask: " << mask;
507  activeMaskCache[maskingMap] = mask;
508  return mask;
509 }
510 
511 Operation *
512 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
513  LinalgOp linalgOp,
514  std::optional<AffineMap> maybeIndexingMap) {
515  LDBG() << "Trying to mask: " << *opToMask;
516 
517  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518  if (maybeIndexingMap)
519  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
520 
521  // Create or retrieve mask for this operation.
522  Value mask =
523  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
524 
525  if (!mask) {
526  LDBG() << "No mask required";
527  return opToMask;
528  }
529 
530  // Wrap the operation with a new `vector.mask` and update D-U chain.
531  assert(opToMask && "Expected a valid operation to mask");
532  auto maskOp = cast<vector::MaskOp>(
533  mlir::vector::maskOperation(rewriter, opToMask, mask));
534  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
535 
536  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
537  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
538  maskOpTerminator);
539 
540  LDBG() << "Masked operation: " << *maskOp;
541  return maskOp;
542 }
543 
544 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
545 /// projectedPermutation, compress the unused dimensions to serve as a
546 /// permutation_map for a vector transfer operation.
547 /// For example, given a linalg op such as:
548 ///
549 /// ```
550 /// %0 = linalg.generic {
551 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
552 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
553 /// }
554 /// ins(%0 : tensor<2x3x4xf32>)
555 /// outs(%1 : tensor<5x6xf32>)
556 /// ```
557 ///
558 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
559 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
560 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
561 static AffineMap reindexIndexingMap(AffineMap map) {
562  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
563  "expected projected permutation");
564  auto res = compressUnusedDims(map);
565  assert(res.getNumDims() ==
566  (res.getNumResults() - res.getNumOfZeroResults()) &&
567  "expected reindexed map with same number of dims and results");
568  return res;
569 }
570 
571 /// Helper enum to represent conv1d input traversal order.
572 enum class Conv1DOpOrder {
573  W, // Corresponds to non-channeled 1D convolution operation.
574  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
575  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
576 };
577 
578 /// Helper data structure to represent the result of vectorization for a single
579 /// operation. In certain specific cases, like terminators, we do not want to
580 /// propagate.
581 enum VectorizationHookStatus {
582  /// Op failed to vectorize.
583  Failure = 0,
584  /// Op vectorized and custom function took care of replacement logic.
585  NoReplace,
586  /// Op vectorized into a new Op whose results will replace original Op's
587  /// results.
588  NewOp
589  // TODO: support values if Op vectorized to Many-Ops whose results we need to
590  // aggregate for replacement.
591 };
592 /// VectorizationHookResult contains the vectorized op returned from a
593 /// CustomVectorizationHook. This is an internal implementation detail of
594 /// linalg vectorization, not to be confused with VectorizationResult.
595 struct VectorizationHookResult {
596  /// Return status from vectorizing the current op.
597  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
598  /// New vectorized operation to replace the current op.
599  /// Replacement behavior is specified by `status`.
600  Operation *newOp;
601 };
602 
603 std::optional<vector::CombiningKind>
604 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
605  using ::mlir::vector::CombiningKind;
606 
607  if (!combinerOp)
608  return std::nullopt;
609  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
610  .Case<arith::AddIOp, arith::AddFOp>(
611  [&](auto op) { return CombiningKind::ADD; })
612  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
613  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
614  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
615  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
616  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
617  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
618  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
619  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
620  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
621  .Case<arith::MulIOp, arith::MulFOp>(
622  [&](auto op) { return CombiningKind::MUL; })
623  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
624  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
625  .Default([&](auto op) { return std::nullopt; });
626 }
627 
628 /// Check whether `outputOperand` is a reduction with a single combiner
629 /// operation. Return the combiner operation of the reduction. Return
630 /// nullptr otherwise. Multiple reduction operations would impose an
631 /// ordering between reduction dimensions and is currently unsupported in
632 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
633 /// max(min(X))
634 // TODO: use in LinalgOp verification, there is a circular dependency atm.
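// Illustrative sketch (hypothetical example): for a linalg.generic whose body is
//   ^bb0(%in: f32, %acc: f32):
//     %0 = arith.addf %in, %acc : f32
//     linalg.yield %0 : f32
// with a "reduction" iterator on the reduced dim, this returns the arith.addf,
// which getCombinerOpKind then maps to vector::CombiningKind::ADD.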
635 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
636  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
637  unsigned outputPos =
638  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
639  // Only single combiner operations are supported for now.
640  SmallVector<Operation *, 4> combinerOps;
641  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
642  combinerOps.size() != 1)
643  return nullptr;
644 
645  // Return the combiner operation.
646  return combinerOps[0];
647 }
648 
649 /// Broadcast `value` to a vector of `shape` if possible. Return value
650 /// otherwise.
651 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
652  auto dstVecType = dyn_cast<VectorType>(dstType);
653  // If no shape to broadcast to, just return `value`.
654  if (dstVecType.getRank() == 0)
655  return value;
656  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
657  vector::BroadcastableToResult::Success)
658  return value;
659  Location loc = b.getInsertionPoint()->getLoc();
660  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
661 }
662 
663 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
664 /// assumes that `reductionOp` has two operands and one of them is the reduction
665 /// initial value.
666 // Note: this is a true builder that notifies the OpBuilder listener.
667 // TODO: Consider moving as a static helper on the ReduceOp.
668 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
669  Value valueToReduce, Value acc,
670  ArrayRef<bool> dimsToMask) {
671  auto maybeKind = getCombinerOpKind(reduceOp);
672  assert(maybeKind && "Failed precondition: could not get reduction kind");
673  return vector::MultiDimReductionOp::create(
674  b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
675 }
676 
677 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
678  return llvm::to_vector(
679  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
680 }
681 
682 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
683 /// reduction iterator.
684 static bool hasReductionIterator(LinalgOp &op) {
685  return isa<linalg::ReduceOp>(op) ||
686  (isa<linalg::GenericOp>(op) &&
687  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
688 }
689 
690 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
691 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
692 /// currently being vectorized. If `dest` has null rank, build a memref.store.
693 /// Return the produced value or null if no value is produced.
694 // Note: this is a true builder that notifies the OpBuilder listener.
695 // TODO: Consider moving as a static helper on the ReduceOp.
696 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
697  OpOperand *outputOperand,
698  VectorizationState &state) {
699  Location loc = value.getLoc();
700  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
701  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
702 
703  // Compute the vector type of the value to store. This type should be an
704  // identity or projection of the canonical vector type without any permutation
705  // applied, given that any permutation in a transfer write happens as part of
706  // the write itself.
708  opOperandMap.getContext(), opOperandMap.getNumInputs(),
709  [&](AffineDimExpr dimExpr) -> bool {
710  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
711  });
712  auto vectorType = state.getCanonicalVecType(
713  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
714 
715  Operation *write;
716  if (vectorType.getRank() > 0) {
717  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
718  SmallVector<Value> indices(
719  linalgOp.getRank(outputOperand),
720  arith::ConstantIndexOp::create(rewriter, loc, 0));
721  value = broadcastIfNeeded(rewriter, value, vectorType);
722  assert(value.getType() == vectorType && "Incorrect type");
723  write = vector::TransferWriteOp::create(
724  rewriter, loc, value, outputOperand->get(), indices, writeMap);
725  } else {
726  // 0-d case is still special: do not invert the reindexing writeMap.
727  if (!isa<VectorType>(value.getType()))
728  value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
729  assert(value.getType() == vectorType && "Incorrect type");
730  write = vector::TransferWriteOp::create(rewriter, loc, value,
731  outputOperand->get(), ValueRange{});
732  }
733 
734  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
735 
736  // If masked, set in-bounds to true. Masking guarantees that the access will
737  // be in-bounds.
738  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
739  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
740  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
741  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
742  }
743 
744  LDBG() << "vectorized op: " << *write;
745  if (!write->getResults().empty())
746  return write->getResult(0);
747  return Value();
748 }
749 
750 // Custom vectorization precondition function type. This is intended to be used
751 // with CustomVectorizationHook. Returns success if the corresponding custom
752 // hook can vectorize the op.
753 using CustomVectorizationPrecondition =
754  std::function<LogicalResult(Operation *, bool)>;
755 
756 // Custom vectorization function type. Produce a vector form of Operation*
757 // assuming all its vectorized operands are already in the IRMapping.
758 // Return nullptr if the Operation cannot be vectorized.
759 using CustomVectorizationHook =
760  std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
761 
762 /// Helper function to vectorize the terminator of a `linalgOp`. New result
763 /// vector values are appended to `newResults`. Return
764 /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
765 /// that it should not try to map produced operations and instead return the
766 /// results using the `newResults` vector making them available to the
767 /// vectorization algorithm for RAUW. This function is meant to be used as a
768 /// CustomVectorizationHook.
769 static VectorizationHookResult
770 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
771  const IRMapping &bvm, VectorizationState &state,
772  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
773  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
774  if (!yieldOp)
775  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
776  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
777  // TODO: Scan for an opportunity for reuse.
778  // TODO: use a map.
779  Value vectorValue = bvm.lookup(output.value());
780  Value newResult =
781  buildVectorWrite(rewriter, vectorValue,
782  linalgOp.getDpsInitOperand(output.index()), state);
783  if (newResult)
784  newResults.push_back(newResult);
785  }
786 
787  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
788 }
789 
790 /// Helper function to vectorize the index operations of a `linalgOp`. Return
791 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
792 /// should map the produced operations. This function is meant to be used as a
793 /// CustomVectorizationHook.
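/// Illustrative sketch (hypothetical shapes): for a canonical vector shape of
/// [4, 8], `linalg.index 1` (the trailing dim) vectorizes directly into a
/// vector.step of vector<8xindex>, while `linalg.index 0` is materialized as a
/// vector.step of vector<4xindex> that is broadcast to vector<8x4xindex> and
/// then transposed back to vector<4x8xindex>.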
794 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
795  VectorizationState &state,
796  Operation *op,
797  LinalgOp linalgOp) {
798  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
799  if (!indexOp)
800  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
801  auto loc = indexOp.getLoc();
802  // Compute the static loop sizes of the index op.
803  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
804  auto dim = indexOp.getDim();
805  // Compute a one-dimensional index vector for the index op dimension.
806  auto indexVectorType =
807  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
808  state.getScalableVecDims()[dim]);
809  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
810  // Return the one-dimensional index vector if it lives in the trailing
811  // dimension of the iteration space since the vectorization algorithm in this
812  // case can handle the broadcast.
813  if (dim == targetShape.size() - 1)
814  return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
815  // Otherwise permute the targetShape to move the index dimension last,
816  // broadcast the one-dimensional index vector to the permuted shape, and
817  // finally transpose the broadcasted index vector to undo the permutation.
818  auto permPattern =
819  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
820  std::swap(permPattern[dim], permPattern.back());
821  auto permMap =
822  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
823 
824  auto broadCastOp = vector::BroadcastOp::create(
825  rewriter, loc,
826  state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
827  SmallVector<int64_t> transposition =
828  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
829  std::swap(transposition.back(), transposition[dim]);
830  auto transposeOp =
831  vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
832  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
833 }
834 
835 /// Helper function to check if the tensor.extract can be vectorized by the
836 /// custom hook vectorizeTensorExtract.
837 static LogicalResult
838 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
839  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
840  if (!extractOp)
841  return failure();
842 
843  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
844  return failure();
845 
846  // Check the index type, but only for non 0-d tensors (for which we do need
847  // access indices).
848  if (not extractOp.getIndices().empty()) {
849  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
850  return failure();
851  }
852 
853  if (!llvm::all_of(extractOp->getResultTypes(),
854  VectorType::isValidElementType)) {
855  return failure();
856  }
857 
858  return success();
859 }
860 
861 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
862 /// generated from `tensor.extract`. The offset is calculated as follows
863 /// (example using scalar values):
864 ///
865 /// offset = extractOp.indices[0]
866 /// for (i = 1; i < numIndices; i++)
867 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
868 ///
869 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
870 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
871 static Value calculateGatherOffset(RewriterBase &rewriter,
872  VectorizationState &state,
873  tensor::ExtractOp extractOp,
874  const IRMapping &bvm) {
875  // The vector of indices for GatherOp should be shaped as the output vector.
876  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
877  auto loc = extractOp.getLoc();
878 
879  Value offset = broadcastIfNeeded(
880  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
881 
882  const size_t numIndices = extractOp.getIndices().size();
883  for (size_t i = 1; i < numIndices; i++) {
884  Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
885 
886  auto dimSize = broadcastIfNeeded(
887  rewriter,
888  tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
889  indexVecType);
890 
891  offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
892 
893  auto extractOpIndex = broadcastIfNeeded(
894  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
895 
896  offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
897  }
898 
899  return offset;
900 }
901 
902 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
903 
904 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
905 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
906 /// represents a contiguous load operation.
907 ///
908 /// Note that when calling this hook, it is assumed that the output vector is
909 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
910 /// labelled as a gather load before entering this method.
911 ///
912 /// Following on from the above, it is assumed that:
913 /// * for statically shaped loops, when no masks are used, only one dim is !=
914 /// 1 (that's what the shape of the output vector is based on).
915 /// * for dynamically shaped loops, there might be more non-unit dims
916 /// as the output vector type is user-specified.
917 ///
918 /// TODO: Statically shaped loops + vector masking
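/// Illustrative sketch (hypothetical values): for static loop ranges
/// [1, 16, 1] this returns 1; for [1, 1, 1] (all unit dims) it returns 0.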
919 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
920  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
921  assert(
922  (linalgOp.hasDynamicShape() ||
923  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
924  "For statically shaped Linalg Ops, only one "
925  "non-unit loop dim is expected");
926  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
927 
928  size_t idx = loopRanges.size() - 1;
929  for (; idx != 0; idx--)
930  if (loopRanges[idx] != 1)
931  break;
932 
933  return idx;
934 }
935 
936 /// Checks whether `val` can be used for calculating a loop invariant index.
937 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
938  VectorType resType) {
939 
940  assert(((llvm::count_if(resType.getShape(),
941  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942  "n-D vectors are not yet supported");
943 
944  // Blocks outside _this_ linalg.generic are effectively loop invariant.
945  // However, analysing block arguments for _this_ linalg.generic Op is a bit
946  // tricky. Just bail out in the latter case.
947  // TODO: We could try analysing the corresponding affine map here.
948  auto *block = linalgOp.getBlock();
949  if (isa<BlockArgument>(val))
950  return !llvm::is_contained(block->getArguments(), val);
951 
952  Operation *defOp = val.getDefiningOp();
953  assert(defOp && "This is neither a block argument nor an operation result");
954 
955  // IndexOp is loop invariant as long as its result remains constant across
956  // iterations. Note that for dynamic shapes, the corresponding dim will also
957  // be conservatively treated as != 1.
958  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
959  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
960  }
961 
962  auto *ancestor = block->findAncestorOpInBlock(*defOp);
963 
964  // Values defined outside `linalgOp` are loop invariant.
965  if (!ancestor)
966  return true;
967 
968  // Values defined inside `linalgOp`, which are constant, are loop invariant.
969  if (isa<arith::ConstantOp>(ancestor))
970  return true;
971 
972  bool result = true;
973  for (auto op : ancestor->getOperands())
974  result &= isLoopInvariantIdx(linalgOp, op, resType);
975 
976  return result;
977 }
978 
979 /// Check whether `val` could be used for calculating the trailing index for a
980 /// contiguous load operation.
981 ///
982 /// There are currently 3 types of values that are allowed here:
983 /// 1. loop-invariant values,
984 /// 2. values that increment by 1 with every loop iteration,
985 /// 3. results of basic arithmetic operations (linear and continuous)
986 /// involving 1., 2. and 3.
987 /// This method returns True if indeed only such values are used in calculating
988 /// `val.`
989 ///
990 /// Additionally, the trailing index for a contiguous load operation should
991 /// increment by 1 with every loop iteration, i.e. be based on:
992 /// * `linalg.index <dim>` ,
993 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
994 /// `linalg.index <dim>` increments by 1 with every loop iteration).
995 /// `foundIndexOp` is updated to `true` when such Op is found.
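/// Illustrative sketch (hypothetical IR, trailing non-unit dim assumed to be
/// d1): an index computed as
///   %i = linalg.index 1 : index
///   %idx = arith.addi %off, %i : index
/// (where %off is loop invariant) qualifies, whereas %idx = arith.muli %i, %c2
/// does not, since it produces a stride-2 access.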
996 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
997  bool &foundIndexOp, VectorType resType) {
998 
999  assert(((llvm::count_if(resType.getShape(),
1000  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1001  "n-D vectors are not yet supported");
1002 
1003  // Blocks outside _this_ linalg.generic are effectively loop invariant.
1004  // However, analysing block arguments for _this_ linalg.generic Op is a bit
1005  // tricky. Just bail out in the latter case.
1006  // TODO: We could try analysing the corresponding affine map here.
1007  auto *block = linalgOp.getBlock();
1008  if (isa<BlockArgument>(val))
1009  return !llvm::is_contained(block->getArguments(), val);
1010 
1011  Operation *defOp = val.getDefiningOp();
1012  assert(defOp && "This is neither a block argument nor an operation result");
1013 
1014  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1015  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1016 
1017  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1018  return true;
1019  }
1020 
1021  auto *ancestor = block->findAncestorOpInBlock(*defOp);
1022 
1023  if (!ancestor)
1024  return false;
1025 
1026  // Conservatively reject Ops that could lead to indices with stride other
1027  // than 1.
1028  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1029  return false;
1030 
1031  bool result = false;
1032  for (auto op : ancestor->getOperands())
1033  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1034 
1035  return result;
1036 }
1037 
1038 /// Infer the memory access pattern for the input ExtractOp
1039 ///
1040 /// Based on the ExtractOp result shape and the access indices, decides whether
1041 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1042 /// or a gather load. When analysing the ExtractOp indices (to identify
1043 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1044 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1045 /// Op).
1046 ///
1047 /// Note that it is always safe to use gather load operations for contiguous
1048 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1049 /// that `extractOp` is a gather load.
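/// Illustrative sketch (hypothetical IR): inside a linalg.generic producing an
/// effectively 1-D vector along d1,
///   %v = tensor.extract %src[%c0, %i] : tensor<4x8xf32>
/// with %i = linalg.index 1 is classified as a contiguous load, whereas
///   %v = tensor.extract %src[%i, %c0] : tensor<4x8xf32>
/// is classified as a gather load (the trailing index does not advance with
/// the trailing loop dimension).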
1050 static VectorMemoryAccessKind
1051 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1052  LinalgOp &linalgOp, VectorType resType) {
1053 
1054  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1055 
1056  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1057  if (inputShape.getShape().empty())
1058  return VectorMemoryAccessKind::ScalarBroadcast;
1059 
1060  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1061  // otherwise.
1062  bool isOutput1DVector =
1063  (llvm::count_if(resType.getShape(),
1064  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1065  // 1. Assume that it's a gather load when reading non-1D vector.
1066  if (!isOutput1DVector)
1067  return VectorMemoryAccessKind::Gather;
1068 
1069  bool leadingIdxsLoopInvariant = true;
1070 
1071  // 2. Analyze the leading indices of `extractOp`.
1072  // Look at the way each index is calculated and decide whether it is suitable
1073  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1074  // gather load.
1075  auto indices = extractOp.getIndices();
1076  auto leadIndices = indices.drop_back(1);
1077 
1078  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1079  if (inputShape.getShape()[i] == 1)
1080  continue;
1081 
1082  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1083  }
1084 
1085  if (!leadingIdxsLoopInvariant) {
1086  LDBG() << "Found gather load: " << extractOp;
1087  return VectorMemoryAccessKind::Gather;
1088  }
1089 
1090  // 3. Analyze the trailing index for `extractOp`.
1091  // At this point we know that the leading indices are loop invariant. This
1092  // means that is potentially a scalar or a contiguous load. We can decide
1093  // based on the trailing idx.
1094  auto extractOpTrailingIdx = indices.back();
1095 
1096  // 3a. Scalar broadcast load
1097  // If the trailing index is loop invariant then this is a scalar load.
1098  if (leadingIdxsLoopInvariant &&
1099  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1100  LDBG() << "Found scalar broadcast load: " << extractOp;
1101 
1102  return VectorMemoryAccessKind::ScalarBroadcast;
1103  }
1104 
1105  // 3b. Contiguous loads
1106  // The trailing `extractOp` index should increment with every loop iteration.
1107  // This effectively means that it must be based on the trailing loop index.
1108  // This is what the following bool captures.
1109  bool foundIndexOp = false;
1110  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1111  foundIndexOp, resType);
1112  // TODO: Support generating contiguous loads for column vectors - that will
1113  // require adding a permutation map to transfer_read Ops.
1114  bool isRowVector = resType.getShape().back() != 1;
1115  isContiguousLoad &= (foundIndexOp && isRowVector);
1116 
1117  if (isContiguousLoad) {
1118  LDBG() << "Found contiguous load: " << extractOp;
1119  return VectorMemoryAccessKind::Contiguous;
1120  }
1121 
1122  // 4. Fallback case - gather load.
1123  LDBG() << "Found gather load: " << extractOp;
1124  return VectorMemoryAccessKind::Gather;
1125 }
1126 
1127 /// Helper function to vectorize the tensor.extract operations. Returns
1128 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1129 /// should map the produced operations. This function is meant to be used as a
1130 /// CustomVectorizationHook.
1131 static VectorizationHookResult
1132 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1133  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1134  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1135  if (!extractOp)
1136  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1137  auto loc = extractOp.getLoc();
1138 
1139  // Compute the static loop sizes of the extract op.
1140  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1141  auto maskConstantOp = arith::ConstantOp::create(
1142  rewriter, loc,
1143  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1144  /*value=*/true));
1145  auto passThruConstantOp = arith::ConstantOp::create(
1146  rewriter, loc, rewriter.getZeroAttr(resultType));
1147 
1148  // Base indices are currently set to 0. We will need to re-visit if more
1149  // generic scenarios are to be supported.
1150  SmallVector<Value> baseIndices(
1151  extractOp.getIndices().size(),
1152  arith::ConstantIndexOp::create(rewriter, loc, 0));
1153 
1154  VectorMemoryAccessKind memAccessKind =
1155  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1156 
1157  // 1. Handle gather access
1158  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1159  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1160 
1161  // Generate the gather load
1162  Operation *gatherOp = vector::GatherOp::create(
1163  rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1164  maskConstantOp, passThruConstantOp);
1165  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1166 
1167  LDBG() << "Vectorised as gather load: " << extractOp;
1168  return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1169  }
1170 
1171  // 2. Handle:
1172  // a. scalar loads + broadcast,
1173  // b. contiguous loads.
1174  // Both cases use vector.transfer_read.
1175 
1176  // Collect indices for `vector.transfer_read`. At this point, the indices will
1177  // either be scalars or would have been broadcast to vectors matching the
1178  // result type. For indices that are vectors, there are two options:
1179  // * for non-trailing indices, all elements are identical (contiguous
1180  // loads are identified by looking for non-trailing indices that are
1181  // invariant with respect to the corresponding linalg.generic), or
1182  // * for trailing indices, the index vector will contain values with stride
1183  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1184  // needed.
1185  // This means that
1186  // * for scalar indices - just re-use it,
1187  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1188  // (0th) element and use that.
1189  SmallVector<Value> transferReadIdxs;
1190  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1191  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1192  if (idx.getType().isIndex()) {
1193  transferReadIdxs.push_back(idx);
1194  continue;
1195  }
1196 
1197  auto indexAs1dVector = vector::ShapeCastOp::create(
1198  rewriter, loc,
1199  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1200  resultType.getScalableDims().back()),
1201  idx);
1202  transferReadIdxs.push_back(
1203  vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1204  }
1205 
1206  // `tensor.extract_element` is always in-bounds, hence the following holds.
1207  auto dstRank = resultType.getRank();
1208  auto srcRank = extractOp.getTensor().getType().getRank();
1209  SmallVector<bool> inBounds(dstRank, true);
1210 
1211  // 2a. Handle scalar broadcast access.
1212  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1213  MLIRContext *ctx = rewriter.getContext();
1214  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1215  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1216 
1217  auto transferReadOp = vector::TransferReadOp::create(
1218  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1219  /*padding=*/std::nullopt, permutationMap, inBounds);
1220 
1221  // Mask this broadcasting xfer_read here rather than relying on the generic
1222  // path (the generic path assumes identity masking map, which wouldn't be
1223  // valid here).
1224  SmallVector<int64_t> readMaskShape = {1};
1225  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1226  auto allTrue = vector::ConstantMaskOp::create(
1227  rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1228  auto *maskedReadOp =
1229  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1230 
1231  LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1232  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1233  maskedReadOp};
1234  }
1235 
1236  // 2b. Handle contiguous access.
1237  auto permutationMap = AffineMap::getMinorIdentityMap(
1238  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1239 
1240  int32_t rankDiff = dstRank - srcRank;
1241  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1242  // dims so that the ranks match. This is done by extending the map with 0s.
1243  // For example, for dstRank = 3, srcRank = 2, the following map created
1244  // above:
1245  // (d0, d1) --> (d0, d1)
1246  // is extended as:
1247  // (d0, d1) --> (0, d0, d1)
1248  while (rankDiff > 0) {
1249  permutationMap = permutationMap.insertResult(
1250  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1251  rankDiff--;
1252  }
1253 
1254  auto transferReadOp = vector::TransferReadOp::create(
1255  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1256  /*padding=*/std::nullopt, permutationMap, inBounds);
1257 
1258  LDBG() << "Vectorised as contiguous load: " << extractOp;
1259  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1260  transferReadOp};
1261 }
1262 
1263 /// Emit reduction operations if the shape of the value to reduce is different
1264 /// from the result shape.
1265 // Note: this is a true builder that notifies the OpBuilder listener.
1266 // TODO: Consider moving as a static helper on the ReduceOp.
1267 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1268  Value reduceValue, Value initialValue,
1269  const IRMapping &bvm) {
1270  Value reduceVec = bvm.lookup(reduceValue);
1271  Value outputVec = bvm.lookup(initialValue);
1272  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1273  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1274  // Reduce only if needed as the value may already have been reduced for
1275  // contraction vectorization.
1276  if (!reduceType ||
1277  (outputType && reduceType.getShape() == outputType.getShape()))
1278  return nullptr;
1279  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1280  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1281 }
1282 
1283 /// Generic vectorization for a single operation `op`, given already vectorized
1284 /// operands carried by `bvm`. Vectorization occurs as follows:
1285 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1286 /// result on success.
1287 /// 2. Clone any constant in the current scope without vectorization: each
1288 /// consumer of the constant will later determine the shape to which the
1289 /// constant needs to be broadcast to.
1290 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1291 /// of the `customVectorizationHooks` to cover such cases.
1292 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1293 /// operand of maximal rank. Other operands have smaller rank and are
1294 /// broadcast accordingly. It is assumed this broadcast is always legal,
1295 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1296 ///
1297 /// This function assumes all operands of `op` have been vectorized and are in
1298 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1299 /// a topologically-sorted list of ops.
1300 /// This function does not update `bvm` but returns a VectorizationHookStatus
1301 /// that instructs the caller what `bvm` update needs to occur.
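/// Illustrative sketch (hypothetical example): for an elementwise body op
///   %0 = arith.addf %a, %b : f32
/// where `bvm` maps %a to a vector<4x8xf32> value and %b to a scalar f32, the
/// generic path broadcasts %b to vector<4x8xf32> and re-creates arith.addf on
/// the vector operands, producing a vector<4x8xf32> result.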
1302 static VectorizationHookResult
1303 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1304  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1305  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1306  LDBG() << "vectorize op " << *op;
1307 
1308  // 1. Try to apply any CustomVectorizationHook.
1309  if (!customVectorizationHooks.empty()) {
1310  for (auto &customFunc : customVectorizationHooks) {
1311  VectorizationHookResult result = customFunc(op, bvm);
1312  if (result.status == VectorizationHookStatus::Failure)
1313  continue;
1314  return result;
1315  }
1316  }
1317 
1318  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1319  // Clone so that the constant is not confined to the linalgOp block .
1320  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1321  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1322  rewriter.clone(*op)};
1323 
1324  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1325  if (!OpTrait::hasElementwiseMappableTraits(op))
1326  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1327 
1328  // 4 . Check if the operation is a reduction.
1329  SmallVector<std::pair<Value, Value>> reductionOperands;
1330  for (Value operand : op->getOperands()) {
1331  auto blockArg = dyn_cast<BlockArgument>(operand);
1332  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1333  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1334  continue;
1335  SmallVector<Operation *> reductionOps;
1336  Value reduceValue = matchReduction(
1337  linalgOp.getRegionOutputArgs(),
1338  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1339  if (!reduceValue)
1340  continue;
1341  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1342  }
1343  if (!reductionOperands.empty()) {
1344  assert(reductionOperands.size() == 1);
1345  Operation *reduceOp =
1346  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1347  reductionOperands[0].second, bvm);
1348  if (reduceOp)
1349  return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1350  }
1351 
1352  // 5. Generic vectorization path for ElementwiseMappable ops.
1353  // a. Get the first max ranked shape.
1354  VectorType firstMaxRankedType;
1355  for (Value operand : op->getOperands()) {
1356  auto vecOperand = bvm.lookup(operand);
1357  assert(vecOperand && "Vector operand couldn't be found");
1358 
1359  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1360  if (vecType && (!firstMaxRankedType ||
1361  firstMaxRankedType.getRank() < vecType.getRank()))
1362  firstMaxRankedType = vecType;
1363  }
1364  // b. Broadcast each op if needed.
1365  SmallVector<Value> vecOperands;
1366  for (Value scalarOperand : op->getOperands()) {
1367  Value vecOperand = bvm.lookup(scalarOperand);
1368  assert(vecOperand && "Vector operand couldn't be found");
1369 
1370  if (firstMaxRankedType) {
1371  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1372  getElementTypeOrSelf(vecOperand.getType()),
1373  firstMaxRankedType.getScalableDims());
1374  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1375  } else {
1376  vecOperands.push_back(vecOperand);
1377  }
1378  }
1379  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1380  SmallVector<Type> resultTypes;
1381  for (Type resultType : op->getResultTypes()) {
1382  resultTypes.push_back(
1383  firstMaxRankedType
1384  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1385  firstMaxRankedType.getScalableDims())
1386  : resultType);
1387  }
1388  // d. Build and return the new op.
1389  return VectorizationHookResult{
1390  VectorizationHookStatus::NewOp,
1391  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1392  resultTypes, op->getAttrs())};
1393 }
1394 
1395 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1396 /// vector form. Generic vectorization proceeds as follows:
1397 /// 1. Verify the `linalgOp` has one non-empty region.
1398 /// 2. Values defined above the region are mapped to themselves and will be
1399 /// broadcasted on a per-need basis by their consumers.
1400 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1401 /// load).
1402 /// TODO: Reuse opportunities for RAR dependencies.
1403 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1404 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1405 /// iteration indices.
1406 /// 5. Iteratively call vectorizeOneOp on the region operations.
1407 ///
1408 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1409 /// performed to the maximal common vector size implied by the `linalgOp`
1410 /// iteration space. This eager broadcasting is introduced in the
1411 /// permutation_map of the vector.transfer_read operations. The eager
1412 /// broadcasting makes it trivial to determine where broadcast, transposes and
1413 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1414 /// the absence of good canonicalizations, the amount of work increases.
1415 /// This is not deemed a problem as we expect canonicalizations and foldings to
1416 /// aggressively clean up the useless work.
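///
/// For illustration only (a hypothetical elementwise example, not taken from
/// the tests), with #map = affine_map<(d0, d1) -> (d0, d1)>:
///
///   %0 = linalg.generic {indexing_maps = [#map, #map, #map],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
///       outs(%init : tensor<4x8xf32>) {
///     ^bb0(%in: f32, %in_0: f32, %out: f32):
///       %1 = arith.addf %in, %in_0 : f32
///       linalg.yield %1 : f32
///   } -> tensor<4x8xf32>
///
/// would, roughly, be vectorized to:
///
///   %lhs = vector.transfer_read %a[%c0, %c0], %cst
///       {in_bounds = [true, true]} : tensor<4x8xf32>, vector<4x8xf32>
///   %rhs = vector.transfer_read %b[%c0, %c0], %cst
///       {in_bounds = [true, true]} : tensor<4x8xf32>, vector<4x8xf32>
///   %add = arith.addf %lhs, %rhs : vector<4x8xf32>
///   %res = vector.transfer_write %add, %init[%c0, %c0]
///       {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x8xf32>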
1417 static LogicalResult
1418 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1419  LinalgOp linalgOp,
1420  SmallVectorImpl<Value> &newResults) {
1421  LDBG() << "Vectorizing operation as linalg generic";
1422  Block *block = linalgOp.getBlock();
1423 
1424  // 2. Values defined above the region can only be broadcast for now. Make them
1425  // map to themselves.
1426  IRMapping bvm;
1427  SetVector<Value> valuesSet;
1428  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1429  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1430 
1431  if (linalgOp.getNumDpsInits() == 0)
1432  return failure();
1433 
1434  // 3. Turn all BBArgs into vector.transfer_read / load.
1435  Location loc = linalgOp.getLoc();
1436  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1437  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1438  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1439  if (linalgOp.isScalar(opOperand)) {
1440  bvm.map(bbarg, opOperand->get());
1441  continue;
1442  }
1443 
1444  // 3.a. Convert the indexing map for this input/output to a transfer read
1445  // permutation map and masking map.
1446  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1447 
1448  AffineMap readMap;
1449  VectorType readType;
1450  Type elemType = getElementTypeOrSelf(opOperand->get());
1451  if (linalgOp.isDpsInput(opOperand)) {
1452  // 3.a.i. For input reads we use the canonical vector shape.
1453  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1454  readType = state.getCanonicalVecType(elemType);
1455  } else {
1456  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1457  // reductions), the vector shape is computed by mapping the canonical
1458  // vector shape to the output domain and back to the canonical domain.
1459  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1460  readType =
1461  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1462  }
1463 
1464  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1465 
1466  Operation *read = vector::TransferReadOp::create(
1467  rewriter, loc, readType, opOperand->get(), indices,
1468  /*padding=*/std::nullopt, readMap);
1469  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1470  Value readValue = read->getResult(0);
1471 
1472  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1473  // will be in-bounds.
1474  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1475  SmallVector<bool> inBounds(readType.getRank(), true);
1476  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1477  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1478  }
1479 
1480  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1481  // TODO: remove this.
1482  if (readType.getRank() == 0)
1483  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1484  ArrayRef<int64_t>());
1485 
1486  LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1487  << "): " << readValue;
1488  bvm.map(bbarg, readValue);
1489  bvm.map(opOperand->get(), readValue);
1490  }
1491 
1492  SmallVector<CustomVectorizationHook> hooks;
1493  // 4a. Register CustomVectorizationHook for yieldOp.
1494  CustomVectorizationHook vectorizeYield =
1495  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1496  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1497  };
1498  hooks.push_back(vectorizeYield);
1499 
1500  // 4b. Register CustomVectorizationHook for indexOp.
1501  CustomVectorizationHook vectorizeIndex =
1502  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1503  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1504  };
1505  hooks.push_back(vectorizeIndex);
1506 
1507  // 4c. Register CustomVectorizationHook for extractOp.
1508  CustomVectorizationHook vectorizeExtract =
1509  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1510  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1511  };
1512  hooks.push_back(vectorizeExtract);
1513 
1514  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1515  for (Operation &op : block->getOperations()) {
1516  VectorizationHookResult result =
1517  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1518  if (result.status == VectorizationHookStatus::Failure) {
1519  LDBG() << "failed to vectorize: " << op;
1520  return failure();
1521  }
1522  if (result.status == VectorizationHookStatus::NewOp) {
1523  Operation *maybeMaskedOp =
1524  state.maskOperation(rewriter, result.newOp, linalgOp);
1525  LDBG() << "New vector op: " << *maybeMaskedOp;
1526  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1527  }
1528  }
1529 
1530  return success();
1531 }
1532 
1533 /// Given a linalg::PackOp, return the `dest` shape before any packing
1534 /// permutations.
1535 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1536  ArrayRef<int64_t> destShape) {
1537  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1538 }
1539 
1540 /// Determines whether a mask for xfer_write is trivially "all true"
1541 ///
1542 /// Given all the inputs required to generate a mask (mask sizes and shapes),
1543 /// and an xfer_write operation (write indices and the destination tensor
1544 /// shape), determines whether the corresponding mask would be trivially
1545 /// foldable (i.e., trivially "all true").
1546 ///
1547 /// Use this method to avoid generating spurious masks and relying on
1548 /// vectorization post-processing to remove them.
1549 ///
1550 /// Pre-conditions for a mask to be trivially foldable:
1551 /// * All involved shapes (mask + destination tensor) are static.
1552 /// * All write indices are constant.
1553 /// * All mask sizes are constant (including `arith.constant`).
1554 ///
1555 /// If the pre-conditions are met, the method requires, for each destination
1556 /// dimension `d`:
1557 /// (1) maskShape[d] <= destDimSize[rankDiff + d]
1558 /// (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
1559 ///
1560 /// rankDiff = rank(dest) - rank(mask).
1561 ///
1562 /// This method takes a conservative view: it may return false even if the mask
1563 /// is technically foldable.
1564 ///
1565 /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1566 /// of the dest tensor):
1567 /// %c0 = arith.constant 0 : index
1568 /// %mask = vector.create_mask 5, 1
1569 /// vector.mask %mask {
1570 /// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1571 /// {in_bounds = [true, true]}
1572 /// : vector<5x1xi32>, tensor<5x1xi32>
1573 /// }
1574 ///
1575 /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1576 /// mask is required to avoid out-of-bounds write):
1577 /// %c0 = arith.constant 0 : index
1578 /// %mask = vector.create_mask 5, 1
1579 /// vector.mask %mask {
1580 /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1581 /// {in_bounds = [true, true]}
1582 /// : vector<8x1xi32>, tensor<5x1xi32>
1583 /// }
1584 ///
1585 /// TODO: Re-use in createReadOrMaskedRead
1586 static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1587  SmallVector<Value> &writeIdxs,
1588  ArrayRef<int64_t> destShape,
1589  ArrayRef<int64_t> maskShape) {
1590  // Masking is unavoidable in the case of dynamic tensors.
1591  if (ShapedType::isDynamicShape(destShape))
1592  return false;
1593 
1594  // Collect all constant mask sizes.
1595  SmallVector<int64_t, 4> cstMaskSizes;
1596  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1597  if (auto intSize = getConstantIntValue(dimSize)) {
1598  cstMaskSizes.push_back(*intSize);
1599  }
1600  }
1601 
1602  // If any of the mask sizes is non-constant, bail out.
1603  if (cstMaskSizes.size() != maskShape.size())
1604  return false;
1605 
1606  // Collect all constant write indices.
1607  SmallVector<int64_t, 4> cstWriteIdxs;
1608  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1609  APSInt intVal;
1610  if (matchPattern(idx, m_ConstantInt(&intVal))) {
1611  cstWriteIdxs.push_back(intVal.getSExtValue());
1612  }
1613  }
1614 
1615  // If any of the write indices is non-constant, bail out.
1616  if (cstWriteIdxs.size() != destShape.size())
1617  return false;
1618 
1619  // Go over all destination dims and check (1) and (2). Take into account that:
1620  // * The number of mask sizes will match the rank of the vector to store.
1621  // This could be lower than the rank of the destination tensor.
1622  // * Mask sizes could be larger than the corresponding mask shape (hence
1623  // `clamp`).
1624  // TODO: The 2nd item should be rejected by the verifier.
1625  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1626  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1627  if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1628  /*(2)*/ destShape[rankDiff + i] <
1629  (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1630  cstWriteIdxs[i]))
1631  return false;
1632  }
1633 
1634  return true;
1635 }
1636 
1637 /// Creates an optionally masked TransferWriteOp
1638 ///
1639 /// Generates the following operation:
1640 /// %res = vector.transfer_write %vecToStore into %dest
1641 ///
1642 /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1643 ///
1644 /// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1645 /// %res = vector.mask %mask {
1646 /// vector.transfer_write %vecToStore into %dest
1647 /// }
1648 ///
1649 /// The mask shape is identical to `vecToStore` (with the element type ==
1650 /// i1), and the mask values are based on the shape of the `dest` tensor.
1651 ///
1652 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1653 /// is used instead of masking:
1654 ///
1655 /// %write = vector.transfer_write %vecToStore into %dest
1656 /// {in_bounds = in_bounds_flags}
1657 /// where in_bounds_flags[i] is true iff the i-th vector dim is statically
1658 /// known to fit within the corresponding destination dim.
1659 ///
1660 /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1661 /// are set to 0.
1662 static Operation *
1663 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1664  Value dest, SmallVector<Value> writeIndices = {},
1665  bool useInBoundsInsteadOfMasking = false) {
1666 
1667  ShapedType destType = cast<ShapedType>(dest.getType());
1668  int64_t destRank = destType.getRank();
1669  auto destShape = destType.getShape();
1670 
1671  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1672  int64_t vecToStoreRank = vecToStoreType.getRank();
1673  auto vecToStoreShape = vecToStoreType.getShape();
1674 
1675  // Compute the in_bounds attribute
1676  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1677  if (useInBoundsInsteadOfMasking) {
1678  // Update the inBounds attribute.
1679  // FIXME: This computation is too weak - it ignores the write indices.
1680  for (unsigned i = 0; i < vecToStoreRank; i++)
1681  inBoundsVal[i] =
1682  (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1683  ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1684  }
1685 
1686  // If missing, initialize the write indices to 0.
1687  assert((writeIndices.empty() ||
1688  writeIndices.size() == static_cast<size_t>(destRank)) &&
1689  "Invalid number of write indices!");
1690  if (writeIndices.empty()) {
1691  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1692  writeIndices.assign(destRank, zero);
1693  }
1694 
1695  // Generate the xfer_write Op
1696  Operation *write = vector::TransferWriteOp::create(builder, loc,
1697  /*vector=*/vecToStore,
1698  /*source=*/dest,
1699  /*indices=*/writeIndices,
1700  /*inBounds=*/inBoundsVal);
1701 
1702  // If masking is disabled, exit.
1703  if (useInBoundsInsteadOfMasking)
1704  return write;
1705 
1706  // Check if masking is needed. If not, exit.
1707  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1708  return write;
1709 
1710  // Compute the mask and mask the write Op.
1711  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1712  vecToStoreType.getScalableDims());
1713 
1714  SmallVector<OpFoldResult> destSizes =
1715  isa<MemRefType>(dest.getType())
1716  ? memref::getMixedSizes(builder, loc, dest)
1717  : tensor::getMixedSizes(builder, loc, dest);
1718  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1719  destSizes.end());
1720 
1721  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1722  vecToStoreShape))
1723  return write;
1724 
1725  Value maskForWrite =
1726  builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1727  return mlir::vector::maskOperation(builder, write, maskForWrite);
1728 }
1729 
1730 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1731 /// padding value and (3) input vector sizes into:
1732 ///
1733 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1734 ///
1735 /// As in the following example:
1736 /// %pack = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [2, 1]
1737 /// inner_tiles = [16, 2] into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1738 ///
1739 /// This pack would be vectorized to:
1740 ///
1741 /// %load = vector.mask %mask {
1742 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1743 /// {in_bounds = [true, true, true]} :
1744 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1745 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1746 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1747 /// to vector<32x4x2x1x16xf32>
1748 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1749 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1750 /// %write = vector.transfer_write %transpose,
1751 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1752 /// {in_bounds = [true, true, true, true, true]}
1753 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1754 ///
1755 /// If the (3) input vector sizes are not provided, the vector sizes are
1756 /// determined by the result tensor shape and the `in_bounds`
1757 /// attribute is used instead of masking to mark out-of-bounds accesses.
1758 ///
1759 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1760 /// outer dimensions of the output tensor. The remaining dimensions are
1761 /// computed based on, e.g., the static inner tiles.
1762 /// Supporting dynamic inner tiles will require the user to specify the
1763 /// missing vector sizes. This is left as a TODO.
1764 static LogicalResult
1765 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1766  ArrayRef<int64_t> inputVectorSizes,
1767  SmallVectorImpl<Value> &newResults) {
1768  // TODO: Introduce a parent class that will handle the insertion point update.
1769  OpBuilder::InsertionGuard g(rewriter);
1770  rewriter.setInsertionPoint(packOp);
1771 
1772  Location loc = packOp.getLoc();
1773  auto padValue = packOp.getPaddingValue();
1774  if (!padValue) {
1775  padValue = arith::ConstantOp::create(
1776  rewriter, loc,
1777  rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1778  }
1779 
1780  // If the input vector sizes are not provided, then the vector sizes are
1781  // determined by the result tensor shape. In case the vector sizes aren't
1782  // provided, we update the inBounds attribute instead of masking.
1783  bool useInBoundsInsteadOfMasking = false;
1784  if (inputVectorSizes.empty()) {
1785  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1786  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1787  useInBoundsInsteadOfMasking = true;
1788  }
1789 
1790  // Create masked TransferReadOp.
1791  SmallVector<int64_t> inputShape(inputVectorSizes);
1792  auto innerTiles = packOp.getStaticInnerTiles();
1793  auto innerDimsPos = packOp.getInnerDimsPos();
1794  auto outerDimsPerm = packOp.getOuterDimsPerm();
1795  if (!outerDimsPerm.empty())
1796  applyPermutationToVector(inputShape,
1797  invertPermutationVector(outerDimsPerm));
1798  for (auto [idx, size] : enumerate(innerTiles))
1799  inputShape[innerDimsPos[idx]] *= size;
1800  auto maskedRead = vector::createReadOrMaskedRead(
1801  rewriter, loc, packOp.getSource(), inputShape, padValue,
1802  useInBoundsInsteadOfMasking,
1803  /*inputScalableVecSizes=*/{});
1804 
1805  // Create ShapeCastOp.
1806  SmallVector<int64_t> destShape(inputVectorSizes);
1807  destShape.append(innerTiles.begin(), innerTiles.end());
1808  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1809  packOp.getDestType().getElementType());
1810  auto shapeCastOp =
1811  vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
1812 
1813  // Create TransposeOp.
1814  auto destPermutation =
1815  invertPermutationVector(getPackInverseDestPerm(packOp));
1816  auto transposeOp = vector::TransposeOp::create(
1817  rewriter, loc, shapeCastOp.getResult(), destPermutation);
1818 
1819  // Create TransferWriteOp.
1820  Operation *write = createWriteOrMaskedWrite(
1821  rewriter, loc, transposeOp.getResult(), packOp.getDest());
1822  newResults.push_back(write->getResult(0));
1823  return success();
1824 }
1825 
1826 /// Given the re-associations, "collapses" the input Vector type
1827 ///
1828 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1829 /// differences:
1830 /// * We can safely assume that there are no dynamic sizes.
1831 /// * Scalable flags are updated alongside regular dims.
1832 ///
1833 /// When collapsing scalable flags, this conservatively rejects cases with more
1834 /// than one scalable dim. We could re-visit this in the future.
1835 ///
1836 /// EXAMPLE:
1837 /// type = vector<4x16x[8]x16xf32>
1838 /// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1839 /// (d0, d1, d2, d3) -> (d2, d3)]
1840 /// Result:
1841 /// vector<64x[128]xf32>
1842 static VectorType getCollapsedVecType(VectorType type,
1843  ArrayRef<AffineMap> reassociation) {
1844  assert(type.getNumScalableDims() < 2 &&
1845  "Collapsing more than 1 scalable dim is not supported ATM");
1846 
1847  // Use the fact that reassociation is valid to simplify the logic: only use
1848  // each map's rank.
1849  assert(isReassociationValid(reassociation) && "invalid reassociation");
1850 
1851  auto shape = type.getShape();
1852  auto scalableFlags = type.getScalableDims();
1853  SmallVector<int64_t> newShape;
1854  SmallVector<bool> newScalableFlags;
1855 
1856  unsigned currentDim = 0;
1857  for (AffineMap m : reassociation) {
1858  unsigned dim = m.getNumResults();
1859  int64_t size = 1;
1860  bool flag = false;
1861  for (unsigned d = 0; d < dim; ++d) {
1862  size *= shape[currentDim + d];
1863  flag |= scalableFlags[currentDim + d];
1864  }
1865  newShape.push_back(size);
1866  newScalableFlags.push_back(flag);
1867  currentDim += dim;
1868  }
1869 
1870  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1871 }
1872 
1873 /// Vectorize `linalg.unpack` as:
1874 /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1875 ///
1876 /// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1877 /// for the xfer_read operation). This is sufficient to infer the other vector
1878 /// sizes required here.
1879 ///
1880 /// If the vector sizes are not provided:
1881 /// * the vector sizes are determined from the input tensor static shape.
1882 /// * the inBounds attribute is used instead of masking.
1883 ///
1884 /// EXAMPLE (no vector sizes):
1885 /// ```
1886 /// %unpack = linalg.unpack %src
1887 /// inner_dims_pos = [0, 1]
1888 /// inner_tiles = [8, 8]
1889 /// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1890 /// ```
1891 /// is vectorized as:
1892 /// ```
1893 /// %read = vector.transfer_read %src
1894 /// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1895 /// %tr = vector.transpose %read, [0, 2, 1, 3]
1896 /// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1897 /// %sc = vector.shape_cast %tr
1898 /// : vector<1x8x1x8xf32> to vector<8x8xf32>
1899 /// %vector = vector.transfer_write %sc into %dest
1900 /// : vector<8x8xf32>, tensor<8x8xf32>
1901 /// ```
1902 static LogicalResult
1903 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1904  ArrayRef<int64_t> inputVectorSizes,
1905  ArrayRef<bool> inputScalableVecDims,
1906  SmallVectorImpl<Value> &newResults) {
1907  if (!inputVectorSizes.empty()) {
1908  assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1909  "Invalid number of input vector sizes!");
1910  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1911  "Incompatible number of vector sizes and vector scalable flags!");
1912  }
1913 
1914  // TODO: Introduce a parent class that will handle the insertion point update.
1915  OpBuilder::InsertionGuard g(rewriter);
1916  rewriter.setInsertionPoint(unpackOp);
1917 
1918  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1919 
1920  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1921  bool useInBoundsInsteadOfMasking = false;
1922 
1923  Location loc = unpackOp->getLoc();
1924 
1925  // Obtain vector sizes for the read operation.
1926  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1927  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1928 
1929  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1930  if (inputVectorSizes.empty()) {
1931  if (ShapedType::isDynamicShape(sourceShape))
1932  return failure();
1933 
1934  readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1935  useInBoundsInsteadOfMasking = true;
1936  }
1937 
1938  // -- Generate the read operation --
1939  auto padValue = arith::ConstantOp::create(
1940  rewriter, loc,
1941  rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1942  Value readResult = vector::createReadOrMaskedRead(
1943  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1944  useInBoundsInsteadOfMasking, readScalableVectorFlags);
1945 
1946  // -- Generate the transpose operation --
1947  PackingMetadata packMetadata;
1948  SmallVector<int64_t> lastDimToInsertPosPerm =
1949  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1950  vector::TransposeOp transposeOp = vector::TransposeOp::create(
1951  rewriter, loc, readResult, lastDimToInsertPosPerm);
1952 
1953  // -- Generate the shape_cast operation --
1954  VectorType collapsedVecType = getCollapsedVecType(
1955  transposeOp.getType(),
1956  getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1957  rewriter.getContext(), packMetadata.reassociations)));
1958  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1959  rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1960 
1961  // -- Generate the write operation --
1962  Operation *write = createWriteOrMaskedWrite(
1963  rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1964  /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1965 
1966  newResults.push_back(write->getResult(0));
1967  return success();
1968 }
1969 
1970 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1971 /// and (3) all-zero lowPad to
1972 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
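///
/// A rough sketch (hypothetical shapes, vector sizes [8, 8]):
///
///   %pad = tensor.pad %src low[0, 0] high[2, 1] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///   } : tensor<6x7xf32> to tensor<8x8xf32>
///
/// becomes, approximately:
///
///   %mask = vector.create_mask %c6, %c7 : vector<8x8xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///         {in_bounds = [true, true]} : tensor<6x7xf32>, vector<8x8xf32>
///   } : vector<8x8xi1> -> vector<8x8xf32>
///   %empty = tensor.empty() : tensor<8x8xf32>
///   %res = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>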
1973 static LogicalResult
1974 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1975  ArrayRef<int64_t> inputVectorSizes,
1976  SmallVectorImpl<Value> &newResults) {
1977  auto padValue = padOp.getConstantPaddingValue();
1978  Location loc = padOp.getLoc();
1979 
1980  // TODO: Introduce a parent class that will handle the insertion point update.
1981  OpBuilder::InsertionGuard g(rewriter);
1982  rewriter.setInsertionPoint(padOp);
1983 
1984  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1985  LogicalResult status =
1986  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1987  .reifyResultShapes(rewriter, reifiedReturnShapes);
1988  (void)status; // prevent unused variable warning on non-assert builds
1989  assert(succeeded(status) && "failed to reify result shapes");
1990  auto maskedRead = vector::createReadOrMaskedRead(
1991  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1992  /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
1993 
1994  // Create Xfer write Op
1995  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
1996  padOp.getResultType().getElementType());
1997  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
1998  newResults.push_back(write->getResult(0));
1999  return success();
2000 }
2001 
2002 // TODO: probably need some extra checks for reduction followed by consumer
2003 // ops that may not commute (e.g. linear reduction + non-linear instructions).
2004 static LogicalResult reductionPreconditions(LinalgOp op) {
2005  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2006  LDBG() << "reduction precondition failed: no reduction iterator";
2007  return failure();
2008  }
2009  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2010  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2011  if (indexingMap.isPermutation())
2012  continue;
2013 
2014  Operation *reduceOp = matchLinalgReduction(&opOperand);
2015  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2016  LDBG() << "reduction precondition failed: reduction detection failed";
2017  return failure();
2018  }
2019  }
2020  return success();
2021 }
2022 
2023 static LogicalResult
2024 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2025  bool flatten1DDepthwiseConv) {
2026  if (flatten1DDepthwiseConv) {
2027  LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2028  "supported";
2029  return failure();
2030  }
2031 
2032  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2033  LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2034  return failure();
2035  }
2036 
2037  // Support dynamic shapes in 1D depthwise convolution, but only in the
2038  // _channel_ dimension.
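  // E.g. (hypothetical shapes): an input of type tensor<1x8x?xf32> is accepted
  // here, since only the trailing channel dim is dynamic, whereas
  // tensor<1x?x3xf32> is rejected.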
2039  Value lhs = conv.getDpsInputOperand(0)->get();
2040  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2041  auto shapeWithoutCh = lhsShape.drop_back(1);
2042  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2043  LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2044  "channel dim can be dynamic";
2045  return failure();
2046  }
2047 
2048  return success();
2049 }
2050 
2051 static LogicalResult
2052 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2053  bool flatten1DDepthwiseConv) {
2054  if (isa<ConvolutionOpInterface>(op.getOperation()))
2055  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2056 
2057  if (hasReductionIterator(op))
2058  return reductionPreconditions(op);
2059 
2060  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2061  // linalg.copy ops and ops that implement ContractionOpInterface for now.
2062  if (!isElementwise(op) &&
2063  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2064  op.getOperation()))
2065  return failure();
2066 
2067  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2068  return success();
2069 }
2070 
2071 /// This hook considers two cases:
2072 /// (1) If the input-vector-sizes are empty, then the vector sizes will be
2073 /// inferred. This is only possible when all shapes are static.
2074 /// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2075 /// carry out basic sanity-checking.
2076 static LogicalResult
2077 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2078  ArrayRef<int64_t> inputVectorSizes) {
2079  // If there are no input vector sizes and all shapes are static, there is
2080  // nothing left to check.
2081  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2082  unpackOp.getSourceType().hasStaticShape())
2083  return success();
2084 
2085  // The number of input vector sizes must be equal to:
2086  // * read-vector-rank
2087  if (!inputVectorSizes.empty() &&
2088  (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2089  LDBG() << "Incorrect number of input vector sizes";
2090  return failure();
2091  }
2092 
2093  // Check the vector sizes for the read operation.
2094  if (failed(vector::isValidMaskedInputVector(
2095  unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2096  LDBG() << "Invalid vector sizes for the read operation";
2097  return failure();
2098  }
2099 
2100  return success();
2101 }
2102 
2103 static LogicalResult
2104 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2105  ArrayRef<int64_t> inputVectorSizes) {
2106 
2107  TypedValue<RankedTensorType> source = sliceOp.getSource();
2108  auto sourceType = source.getType();
2109  if (!VectorType::isValidElementType(sourceType.getElementType()))
2110  return failure();
2111 
2112  // Get the pad value.
2113  // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2114  // scalar padding value. Note that:
2115  // * for in-bounds accesses,
2116  // the value is actually irrelevant. There are 2 cases in which xfer.read
2117  // accesses are known to be in-bounds:
2118  // 1. The source shape is static (output vector sizes would be based on
2119  // the source shape and hence all memory accesses would be in-bounds),
2120  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2121  // this case it is safe to assume that all memory accesses are in-bounds.
2122  //
2123  // When the value is not known and not needed, use 0. Otherwise, bail out.
2124  Value padValue = getStaticPadVal(sliceOp);
2125  bool isOutOfBoundsRead =
2126  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2127 
2128  if (!padValue && isOutOfBoundsRead) {
2129  LDBG() << "Failed to get a pad value for out-of-bounds read access";
2130  return failure();
2131  }
2132  return success();
2133 }
2134 
2135 /// Vectorize a named linalg contraction op into:
2136 /// vector::TransferReadOp - Reads vectors from the operands
2137 /// vector::ContractionOp - Performs contraction
2138 /// vector::TransferWriteOp - Writes the result vector back to the
2139 /// destination.
2140 /// The operands' shapes are preserved and loaded directly into vectors.
2141 /// Any further permutations or numerical casts remain within the contraction op.
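///
/// A rough sketch (hypothetical static shapes):
///
///   %res = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
///                        outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
///
/// is vectorized, approximately, to:
///
///   %va = vector.transfer_read %A[%c0, %c0], %cst : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %B[%c0, %c0], %cst : tensor<8x16xf32>, vector<8x16xf32>
///   %vc = vector.transfer_read %C[%c0, %c0], %cst : tensor<4x16xf32>, vector<4x16xf32>
///   %ct = vector.contract {
///       indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                        affine_map<(m, n, k) -> (k, n)>,
///                        affine_map<(m, n, k) -> (m, n)>],
///       iterator_types = ["parallel", "parallel", "reduction"],
///       kind = #vector.kind<add>}
///     %va, %vb, %vc : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   %res = vector.transfer_write %ct, %C[%c0, %c0] : vector<4x16xf32>, tensor<4x16xf32>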
2142 static LogicalResult
2143 vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2144  LinalgOp linalgOp,
2145  SmallVectorImpl<Value> &newResults) {
2146  Location loc = linalgOp.getLoc();
2147  MLIRContext *ctx = linalgOp.getContext();
2148 
2149  // For simplicity, contraction vectorization is limited to linalg named ops.
2150  // Generic op is ignored as not every arbitrary contraction body can be
2151  // expressed by a vector.contract.
2152  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2153  return failure();
2154 
2155  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2156  Operation *reduceOp = matchLinalgReduction(outOperand);
2157  auto maybeKind = getCombinerOpKind(reduceOp);
2158  if (!maybeKind) {
2159  LDBG() << "Failed to determine contraction combining kind.";
2160  return failure();
2161  }
2162 
2163  // Check that all dimensions are present in the input operands.
2164  // Arbitrary broadcasts are not supported by the vector contraction.
2165  // Broadcasts are expected to be decomposed before vectorization.
2166  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2167  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2168  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2169  LDBG() << "Contractions with broadcasts are not supported.";
2170  return failure();
2171  }
2172 
2173  // Load operands.
2174  SmallVector<Value> vecOperands;
2175  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2176  // The operand vector shape is computed by mapping the canonical vector
2177  // shape to the operand's domain. Further permutations are left as a part of
2178  // the contraction.
2179  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2180  AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2181  indexingMap.getNumResults(), rewriter.getContext());
2182  Type elemType = getElementTypeOrSelf(opOperand.get());
2183  VectorType readType =
2184  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2185 
2186  Value read = mlir::vector::createReadOrMaskedRead(
2187  rewriter, loc, opOperand.get(), readType.getShape(),
2188  /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2189  /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
2190  vecOperands.push_back(read);
2191  }
2192 
2193  // Remap iterators from linalg to vector.
2194  SmallVector<Attribute> iterAttrs;
2195  auto iterators = linalgOp.getIteratorTypesArray();
2196  for (utils::IteratorType iter : iterators) {
2197  auto vecIter = iter == utils::IteratorType::parallel
2198  ? vector::IteratorType::parallel
2199  : vector::IteratorType::reduction;
2200  iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2201  }
2202 
2203  // Create contraction.
2204  Operation *contractOp = vector::ContractionOp::create(
2205  rewriter, loc, /*lhs=*/vecOperands[0],
2206  /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2207  linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2208  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2209 
2210  // Store result.
2211  Operation *write = createWriteOrMaskedWrite(
2212  rewriter, loc, contractOp->getResult(0), outOperand->get());
2213 
2214  // Finalize.
2215  if (!write->getResults().empty())
2216  newResults.push_back(write->getResult(0));
2217 
2218  return success();
2219 }
2220 
2221 namespace {
2222 enum class ConvOperationKind { Conv, Pool };
2223 } // namespace
2224 
2225 static bool isCastOfBlockArgument(Operation *op) {
2226  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2227  isa<BlockArgument>(op->getOperand(0));
2228 }
2229 
2230 // Returns the ConvOperationKind of the op using reduceOp of the generic
2231 // payload. If it is neither a convolution nor a pooling, it returns
2232 // std::nullopt.
2233 //
2234 // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2235 // + yield) and rhs is not used) then it is the body of a pooling
2236 // If conv, check for single `mul` predecessor. The `mul` operands must be
2237 // block arguments or extension of block arguments.
2238 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2239 // must be block arguments or extension of block arguments.
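// For instance (hypothetical generic bodies), a convolution reduction looks
// like:
//   ^bb0(%in: f32, %flt: f32, %acc: f32):
//     %m = arith.mulf %in, %flt : f32
//     %r = arith.addf %acc, %m : f32
//     linalg.yield %r : f32
// whereas a (sum-)pooling body has no multiplication and leaves the rhs block
// argument unused:
//   ^bb0(%in: f32, %kernel: f32, %acc: f32):
//     %r = arith.addf %acc, %in : f32
//     linalg.yield %r : f32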
2240 static std::optional<ConvOperationKind>
2241 getConvOperationKind(Operation *reduceOp) {
2242  int numBlockArguments =
2243  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2244 
2245  switch (numBlockArguments) {
2246  case 1: {
2247  // Will be convolution if feeder is a MulOp.
2248  // A strength reduced version of MulOp for i1 type is AndOp which is also
2249  // supported. Otherwise, it can be pooling. This strength reduction logic
2250  // is in `buildBinaryFn` helper in the Linalg dialect.
2251  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2252  llvm::IsaPred<BlockArgument>);
2253  assert(feedValIt != reduceOp->operand_end() &&
2254  "Expected a non-block argument operand");
2255  Operation *feedOp = (*feedValIt).getDefiningOp();
2256  if (isCastOfBlockArgument(feedOp)) {
2257  return ConvOperationKind::Pool;
2258  }
2259 
2260  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2261  (isa<arith::AndIOp>(feedOp) &&
2262  feedOp->getResultTypes()[0].isInteger(1))) &&
2263  llvm::all_of(feedOp->getOperands(), [](Value v) {
2264  if (isa<BlockArgument>(v))
2265  return true;
2266  if (Operation *op = v.getDefiningOp())
2267  return isCastOfBlockArgument(op);
2268  return false;
2269  }))) {
2270  return std::nullopt;
2271  }
2272 
2273  return ConvOperationKind::Conv;
2274  }
2275  case 2:
2276  // Must be pooling
2277  return ConvOperationKind::Pool;
2278  default:
2279  return std::nullopt;
2280  }
2281 }
2282 
2283 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2284  switch (kind) {
2285  case vector::CombiningKind::ADD:
2286  case vector::CombiningKind::MAXNUMF:
2287  case vector::CombiningKind::MAXIMUMF:
2288  case vector::CombiningKind::MAXSI:
2289  case vector::CombiningKind::MAXUI:
2290  case vector::CombiningKind::MINNUMF:
2291  case vector::CombiningKind::MINIMUMF:
2292  case vector::CombiningKind::MINSI:
2293  case vector::CombiningKind::MINUI:
2294  return true;
2295  default:
2296  return false;
2297  }
2298 }
2299 
2300 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2301  auto getOperandType = [&](auto operand) {
2302  return dyn_cast<ShapedType>((operand->get()).getType());
2303  };
2304  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2305  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2306  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2307  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2308  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2309  // Note that this also ensures 2D and 3D convolutions are rejected.
2310  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2311  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2312  return failure();
2313 
2314  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2315  if (!reduceOp)
2316  return failure();
2317 
2318  auto maybeOper = getConvOperationKind(reduceOp);
2319  if (!maybeOper.has_value())
2320  return failure();
2321 
2322  auto maybeKind = getCombinerOpKind(reduceOp);
2323  // Typically convolution will have a `Add` CombiningKind but for i1 type it
2324  // can get strength reduced to `OR` which is also supported. This strength
2325  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2326  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2327  *maybeKind != vector::CombiningKind::OR) &&
2328  (*maybeOper != ConvOperationKind::Pool ||
2329  !isSupportedPoolKind(*maybeKind)))) {
2330  return failure();
2331  }
2332 
2333  auto rhsRank = rhsShapedType.getRank();
2334  if (*maybeOper == ConvOperationKind::Pool) {
2335  if (rhsRank != 1)
2336  return failure();
2337  } else {
2338  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2339  return failure();
2340  }
2341 
2342  return success();
2343 }
2344 
2345 static LogicalResult vectorizeLinalgOpPrecondition(
2346  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2347  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2348  // Tensors with a dimension of size 0 cannot be vectorized.
2349  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2350  return llvm::is_contained(linalgOp.getShape(&operand), 0);
2351  }))
2352  return failure();
2353  // Check API contract for input vector sizes.
2354  if (!inputVectorSizes.empty() &&
2355  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2356  inputVectorSizes)))
2357  return failure();
2358 
2359  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2360  linalgOp, flatten1DDepthwiseConv))) {
2361  LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2362  return failure();
2363  }
2364 
2365  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2366 
2367  // Register CustomVectorizationPrecondition for extractOp.
2368  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2369 
2370  // All types in the body should be a supported element type for VectorType.
2371  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2372  // Check if any custom hook can vectorize the inner op.
2373  if (llvm::any_of(
2374  customPreconditions,
2375  [&](const CustomVectorizationPrecondition &customPrecondition) {
2376  return succeeded(
2377  customPrecondition(&innerOp, vectorizeNDExtract));
2378  })) {
2379  continue;
2380  }
2381  if (!llvm::all_of(innerOp.getOperandTypes(),
2382  VectorType::isValidElementType)) {
2383  return failure();
2384  }
2385  if (!llvm::all_of(innerOp.getResultTypes(),
2386  VectorType::isValidElementType)) {
2387  return failure();
2388  }
2389  }
2390  if (isElementwise(linalgOp))
2391  return success();
2392 
2393  // TODO: isaConvolutionOpInterface that can also infer from generic
2394  // features. But we will still need stride/dilation attributes that will be
2395  // annoying to reverse-engineer...
2396  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2397  return vectorizeConvOpPrecondition(linalgOp);
2398 
2399  // TODO: the common vector shape is equal to the static loop sizes only when
2400  // all indexing maps are projected permutations. For convs and stencils the
2401  // logic will need to evolve.
2402  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2403  LDBG() << "precondition failed: not projected permutations";
2404  return failure();
2405  }
2406  if (failed(reductionPreconditions(linalgOp))) {
2407  LDBG() << "precondition failed: reduction preconditions";
2408  return failure();
2409  }
2410  return success();
2411 }
2412 
2413 static LogicalResult
2414 vectorizePackOpPrecondition(linalg::PackOp packOp,
2415  ArrayRef<int64_t> inputVectorSizes) {
2416  auto padValue = packOp.getPaddingValue();
2417  Attribute cstAttr;
2418  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2419  LDBG() << "pad value is not constant: " << packOp;
2420  return failure();
2421  }
2422 
2423  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2424  bool satisfyEmptyCond = true;
2425  if (inputVectorSizes.empty()) {
2426  if (!packOp.getDestType().hasStaticShape() ||
2427  !packOp.getSourceType().hasStaticShape())
2428  satisfyEmptyCond = false;
2429  }
2430 
2431  if (!satisfyEmptyCond &&
2432  failed(vector::isValidMaskedInputVector(
2433  resultTensorShape.take_front(packOp.getSourceRank()),
2434  inputVectorSizes)))
2435  return failure();
2436 
2437  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2438  return !getConstantIntValue(v).has_value();
2439  })) {
2440  LDBG() << "inner_tiles must be constant: " << packOp;
2441  return failure();
2442  }
2443 
2444  return success();
2445 }
2446 
2447 static LogicalResult
2448 vectorizePadOpPrecondition(tensor::PadOp padOp,
2449  ArrayRef<int64_t> inputVectorSizes) {
2450  auto padValue = padOp.getConstantPaddingValue();
2451  if (!padValue) {
2452  LDBG() << "pad value is not constant: " << padOp;
2453  return failure();
2454  }
2455 
2456  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2457  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2458  inputVectorSizes)))
2459  return failure();
2460 
2461  // Padding with non-zero low pad values is not supported, unless the
2462  // corresponding result dim is 1 as this would require shifting the results to
2463  // the right for the low padded dims by the required amount of low padding.
2464  // However, we do support low padding if the dims being low padded have result
2465  // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2466  // input size of that dimension will be dynamically zero (as the sum of the
2467  // low pad and input dim size has to be one) and hence we will create a zero
2468  // mask as the lowering logic just makes the mask one for the input dim size -
2469  // which is zero here. Hence we will load the pad value which is what we want
2470  // in this case. If the low pad is dynamically zero then the lowering is
2471  // correct as well as no shifts are necessary.
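  // For instance (hypothetical): tensor.pad %src low[1] high[0]
  //     : tensor<?xf32> to tensor<1xf32>
  // has a unit result dim, so the (dynamic) input size must be 0; the read
  // mask ends up all-false and the transfer_read produces the pad value,
  // which is exactly the desired result.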
2472  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2473  Value padValue = en.value();
2474  unsigned pos = en.index();
2475  std::optional<int64_t> pad = getConstantIntValue(padValue);
2476  return (!pad.has_value() || pad.value() != 0) &&
2477  resultTensorShape[pos] != 1;
2478  })) {
2479  LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2480  return failure();
2481  }
2482 
2483  return success();
2484 }
2485 
2486 /// Preconditions for scalable vectors.
2487 ///
2488 /// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2489 /// models the fact that in practice we would only make selected dimensions
2490 /// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2491 /// unconditionally - we are yet to identify meaningful conditions.
2492 static LogicalResult
2493 vectorizeScalableVectorPrecondition(Operation *op,
2494  ArrayRef<int64_t> inputVectorSizes,
2495  ArrayRef<bool> inputScalableVecDims) {
2496  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2497  "Number of input vector sizes and scalable dims doesn't match");
2498 
2499  size_t numOfScalableDims =
2500  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2501 
2502  if (numOfScalableDims == 0)
2503  return success();
2504 
2505  auto linalgOp = dyn_cast<LinalgOp>(op);
2506 
2507  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2508  // exception of UnpackOp for which there is a dedicated hook.
2509  if (!linalgOp) {
2510  return success(isa<linalg::UnPackOp>(op));
2511  }
2512 
2513  // Cond 2: There's been no need for more than 2 scalable dims so far
2514  if (numOfScalableDims > 2)
2515  return failure();
2516 
2517  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2518  // it matches one of the supported cases:
2519  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2520  // (*).
2521  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2522  // parallel dims.
2523  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2524  // dim.
2525  // The 2nd restriction above means that only Matmul-like Ops are supported
2526  // when 2 dims are scalable, e.g. :
2527  // * iterators = [parallel, parallel, reduction]
2528  // * scalable flags = [true, true, false]
2529  //
2530  // (*) Unit dims get folded away in practice.
2531  // TODO: Relax these conditions as good motivating examples are identified.
2532 
2533  // Find the first scalable flag.
2534  bool seenNonUnitParallel = false;
2535  auto iterators = linalgOp.getIteratorTypesArray();
2536  SmallVector<bool> scalableFlags(inputScalableVecDims);
2537  int64_t idx = scalableFlags.size() - 1;
2538  while (!scalableFlags[idx]) {
2539  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2540  seenNonUnitParallel |=
2541  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2542 
2543  iterators.pop_back();
2544  scalableFlags.pop_back();
2545  --idx;
2546  }
2547 
2548  // Analyze the iterator corresponding to the first scalable dim.
2549  switch (iterators.back()) {
2550  case utils::IteratorType::reduction: {
2551  // Check 3. above is met.
2552  if (iterators.size() != inputVectorSizes.size()) {
2553  LDBG() << "Non-trailing reduction dim requested for scalable "
2554  "vectorization";
2555  return failure();
2556  }
2557  if (isa<linalg::MatmulOp>(op)) {
2558  LDBG()
2559  << "Scalable vectorization of the reduction dim in Matmul-like ops "
2560  "is not supported";
2561  return failure();
2562  }
2563  break;
2564  }
2565  case utils::IteratorType::parallel: {
2566  // Check 1. and 2. above are met.
2567  if (seenNonUnitParallel) {
2568  LDBG() << "Inner parallel dim not requested for scalable "
2569  "vectorization";
2570  return failure();
2571  }
2572  break;
2573  }
2574  }
2575 
2576  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2577  // supported, for which we expect the following config:
2578  // * iterators = [parallel, parallel, reduction]
2579  // * scalable flags = [true, true, false]
2580  if (numOfScalableDims == 2) {
2581  // Disallow the case below, which breaks 3. above:
2582  // * iterators = [..., parallel, reduction]
2583  // * scalable flags = [..., true, true]
2584  if (iterators.back() == utils::IteratorType::reduction) {
2585  LDBG() << "Higher dim than the trailing reduction dim requested for "
2586  "scalable "
2587  "vectorization";
2588  return failure();
2589  }
2590  scalableFlags.pop_back();
2591  iterators.pop_back();
2592 
2593  if (!scalableFlags.back() ||
2594  (iterators.back() != utils::IteratorType::parallel))
2595  return failure();
2596  }
2597 
2598  // Cond 4: Only the following ops are supported in the
2599  // presence of scalable vectors
2600  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2601  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2602  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2603  isa<linalg::BatchMmt4DOp>(op) ||
2604  hasReductionIterator(linalgOp));
2605 }
2606 
2607 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2608  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2609  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2610  bool flatten1DDepthwiseConv) {
2611 
2612  if (!hasVectorizationImpl(op))
2613  return failure();
2614 
2615  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2616  inputScalableVecDims)))
2617  return failure();
2618 
2619  return TypeSwitch<Operation *, LogicalResult>(op)
2620  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2621  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2622  vectorizeNDExtract,
2623  flatten1DDepthwiseConv);
2624  })
2625  .Case<tensor::PadOp>([&](auto padOp) {
2626  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2627  })
2628  .Case<linalg::PackOp>([&](auto packOp) {
2629  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2630  })
2631  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2632  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2633  })
2634  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2635  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2636  })
2637  .Default([](auto) { return failure(); });
2638 }
2639 
2640 /// Converts affine.apply Ops to arithmetic operations.
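/// For example (hypothetical), an op such as
///   %r = affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 4)>(%i)[%j]
/// is expanded, roughly, to:
///   %c4 = arith.constant 4 : index
///   %m = arith.muli %j, %c4 : index
///   %r = arith.addi %i, %m : index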
2641 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2642  OpBuilder::InsertionGuard g(rewriter);
2643  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2644 
2645  for (auto op : make_early_inc_range(toReplace)) {
2646  rewriter.setInsertionPoint(op);
2647  auto expanded = affine::expandAffineExpr(
2648  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2649  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2650  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2651  rewriter.replaceOp(op, expanded);
2652  }
2653 }
2654 
2655 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2656  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2657  tensor::InsertSliceOp>(op);
2658 }
2659 
2660 FailureOr<VectorizationResult> mlir::linalg::vectorize(
2661  RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2662  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2663  bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2664  bool createNamedContraction) {
2665  LDBG() << "Attempting to vectorize: " << *op;
2666  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2667  LDBG() << "Input scalable vector dims: "
2668  << llvm::interleaved(inputScalableVecDims);
2669 
2670  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2671  vectorizeNDExtract,
2672  flatten1DDepthwiseConv))) {
2673  LDBG() << "Vectorization pre-conditions failed";
2674  return failure();
2675  }
2676 
2677  // Initialize vectorization state.
2678  VectorizationState state(rewriter);
2679  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2680  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2681  inputScalableVecDims,
2682  assumeDynamicDimsMatchVecSizes))) {
2683  LDBG() << "Vectorization state couldn't be initialized";
2684  return failure();
2685  }
2686  }
2687 
2688  SmallVector<Value> results;
2689  auto vectorizeResult =
2690  TypeSwitch<Operation *, LogicalResult>(op)
2691  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2692  // TODO: isaConvolutionOpInterface that can also infer from
2693  // generic features. Will require stride/dilation attributes
2694  // inference.
2695  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2696  FailureOr<Operation *> convOr = vectorizeConvolution(
2697  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2698  flatten1DDepthwiseConv);
2699  if (succeeded(convOr)) {
2700  llvm::append_range(results, (*convOr)->getResults());
2701  return success();
2702  }
2703 
2704  LDBG() << "Unsupported convolution can't be vectorized.";
2705  return failure();
2706  }
2707 
2708  if (createNamedContraction &&
2709  isa<ContractionOpInterface>(linalgOp.getOperation()))
2710  return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2711  results);
2712 
2713  LDBG()
2714  << "Vectorize generic by broadcasting to the canonical vector "
2715  "shape";
2716 
2717  // Pre-process before proceeding.
2718  convertAffineApply(rewriter, linalgOp);
2719 
2720  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2721  // to 'OpBuilder' when it is passed over to some methods like
2722  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2723  // erase an op within these methods, the actual rewriter won't be
2724  // notified and we will end up with read-after-free issues!
2725  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2726  })
2727  .Case<tensor::PadOp>([&](auto padOp) {
2728  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2729  results);
2730  })
2731  .Case<linalg::PackOp>([&](auto packOp) {
2732  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2733  results);
2734  })
2735  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2736  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2737  inputVectorSizes,
2738  inputScalableVecDims, results);
2739  })
2740  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2741  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2742  results);
2743  })
2744  .Default([](auto) { return failure(); });
2745 
2746  if (failed(vectorizeResult)) {
2747  LDBG() << "Vectorization failed";
2748  return failure();
2749  }
2750 
2751  return VectorizationResult{results};
2752 }
2753 
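/// Vectorizes a statically-shaped memref.copy into a full-size
/// transfer_read / transfer_write pair. A rough sketch (hypothetical shapes):
///
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
///
/// becomes, approximately:
///
///   %v = vector.transfer_read %src[%c0, %c0], %pad : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>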
2754 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2755  memref::CopyOp copyOp) {
2756  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2757  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2758  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2759  return failure();
2760 
2761  auto srcElementType = getElementTypeOrSelf(srcType);
2762  auto dstElementType = getElementTypeOrSelf(dstType);
2763  if (!VectorType::isValidElementType(srcElementType) ||
2764  !VectorType::isValidElementType(dstElementType))
2765  return failure();
2766 
2767  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2768  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2769 
2770  Location loc = copyOp->getLoc();
2771  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2772  SmallVector<Value> indices(srcType.getRank(), zero);
2773 
2774  Value readValue = vector::TransferReadOp::create(
2775  rewriter, loc, readType, copyOp.getSource(), indices,
2776  /*padding=*/std::nullopt,
2777  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2778  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2779  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2780  ArrayRef<int64_t>());
2781  readValue =
2782  vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2783  }
2784  Operation *writeValue = vector::TransferWriteOp::create(
2785  rewriter, loc, readValue, copyOp.getTarget(), indices,
2786  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2787  rewriter.replaceOp(copyOp, writeValue->getResults());
2788  return success();
2789 }
2790 
2791 //----------------------------------------------------------------------------//
2792 // Misc. vectorization patterns.
2793 //----------------------------------------------------------------------------//
2794 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2795 /// given operation type OpTy.
2796 template <typename OpTy>
2797 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2798  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2799 
2800  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2801  PatternRewriter &rewriter) const final {
2802  bool changed = false;
2803  // Collect the users into a vector first, because some of them may be replaced/removed below.
2804  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2805  if (auto op = dyn_cast<OpTy>(user))
2806  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2807  return success(changed);
2808  }
2809 
2810 protected:
2811  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2812  tensor::PadOp padOp, OpTy op) const = 0;
2813 };
2814 
2815 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2816 /// ```
2817 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2818 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2819 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2820 /// ```
2821 /// is rewritten to:
2822 /// ```
2823 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2824 /// {in_bounds = [true, true]}
2825 /// : tensor<?x?xf32>, vector<17x5xf32>
2826 /// ```
2827 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2828 /// sure that the original padding value %cst was never used.
2829 ///
2830 /// This rewrite is possible if:
2831 /// - `xferOp` has no out-of-bounds dims or mask.
2832 /// - Low padding is static 0.
2833 /// - Single, scalar padding value.
2834 struct PadOpVectorizationWithTransferReadPattern
2835  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2836  using VectorizePadOpUserPattern<
2837  vector::TransferReadOp>::VectorizePadOpUserPattern;
2838 
2839  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2840  vector::TransferReadOp xferOp) const override {
2841  // Low padding must be static 0.
2842  if (!padOp.hasZeroLowPad())
2843  return failure();
2844  // Pad value must be a constant.
2845  auto padValue = padOp.getConstantPaddingValue();
2846  if (!padValue)
2847  return failure();
2848  // Padding value of existing `xferOp` is unused.
2849  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2850  return failure();
2851 
2852  rewriter.modifyOpInPlace(xferOp, [&]() {
2853  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2854  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2855  rewriter.getBoolArrayAttr(inBounds));
2856  xferOp.getBaseMutable().assign(padOp.getSource());
2857  xferOp.getPaddingMutable().assign(padValue);
2858  });
2859 
2860  return success();
2861  }
2862 };
2863 
2864 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2865 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2866 /// value, where the same amount of padding is immediately removed again after
2867 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2868 /// tensor value and apply out-of-bounds masking. E.g.:
2869 /// ```
2870 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2871 /// : tensor<...> to tensor<?x?xf32>
2872 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2873 /// %2 = vector.transfer_write %vec, %1[...]
2874 /// : vector<17x5xf32>, tensor<17x5xf32>
2875 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2876 /// : tensor<17x5xf32> to tensor<?x?xf32>
2877 /// ```
2878 /// is rewritten to:
2879 /// ```
2880 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2881 /// : tensor<...> to tensor<?x?xf32>
2882 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2883 /// tensor<?x?xf32>
2884 /// ```
2885 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2886 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2887 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2888 /// from %r's old dimensions.
2889 ///
2890 /// This rewrite is possible if:
2891 /// - Low padding is static 0.
2892 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2893 /// ExtractSliceOp trims the same amount of padding that was added
2894 /// beforehand.
2895 /// - Single, scalar padding value.
2896 struct PadOpVectorizationWithTransferWritePattern
2897  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2898  using VectorizePadOpUserPattern<
2899  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2900 
2901  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2902  vector::TransferWriteOp xferOp) const override {
2903  // TODO: support 0-d corner case.
2904  if (xferOp.getTransferRank() == 0)
2905  return failure();
2906 
2907  // Low padding must be static 0.
2908  if (!padOp.hasZeroLowPad())
2909  return failure();
2910  // Pad value must be a constant.
2911  auto padValue = padOp.getConstantPaddingValue();
2912  if (!padValue)
2913  return failure();
2914  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2915  if (!xferOp->hasOneUse())
2916  return failure();
2917  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2918  if (!trimPadding)
2919  return failure();
2920  // Only static zero offsets supported when trimming padding.
2921  if (!trimPadding.hasZeroOffset())
2922  return failure();
2923  // trimPadding must remove the amount of padding that was added earlier.
2924  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2925  return failure();
2926 
2927  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2928  rewriter.setInsertionPoint(xferOp);
2929 
2930  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2931  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2932  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2933  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2934  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2935  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2936 
2937  return success();
2938  }
2939 
2940  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2941  /// i.e., same dimensions.
2942  ///
2943  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2944  /// dimensions, this function tries to infer the (static) tensor size by
2945  /// looking at the defining op and utilizing op-specific knowledge.
2946  ///
2947  /// This is a conservative analysis. In case equal tensor sizes cannot be
2948  /// proven statically, this analysis returns `false` even though the tensor
2949  /// sizes may turn out to be equal at runtime.
2950  bool hasSameTensorSize(Value beforePadding,
2951  tensor::ExtractSliceOp afterTrimming) const {
2952  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2953  // result and CastOp operand.
2954  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2955  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2956  return true;
2957 
2958  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2959  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2960  // Only RankedTensorType supported.
2961  if (!t1 || !t2)
2962  return false;
2963  // Rank of both values must be the same.
2964  if (t1.getRank() != t2.getRank())
2965  return false;
2966 
2967  // All static dimensions must be the same. Mixed cases (e.g., dimension
2968  // static in `t1` but dynamic in `t2`) are not supported.
2969  for (unsigned i = 0; i < t1.getRank(); ++i) {
2970  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2971  return false;
2972  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2973  return false;
2974  }
2975 
2976  // Nothing more to check if all dimensions are static.
2977  if (t1.getNumDynamicDims() == 0)
2978  return true;
2979 
2980  // All dynamic sizes must be the same. The only supported case at the
2981  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2982  // thereof).
2983 
2984  // Apart from CastOp, only ExtractSliceOp is supported.
2985  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2986  if (!beforeSlice)
2987  return false;
2988 
2989  assert(static_cast<size_t>(t1.getRank()) ==
2990  beforeSlice.getMixedSizes().size());
2991  assert(static_cast<size_t>(t2.getRank()) ==
2992  afterTrimming.getMixedSizes().size());
2993 
2994  for (unsigned i = 0; i < t1.getRank(); ++i) {
2995  // Skip static dimensions.
2996  if (!t1.isDynamicDim(i))
2997  continue;
2998  auto size1 = beforeSlice.getMixedSizes()[i];
2999  auto size2 = afterTrimming.getMixedSizes()[i];
3000 
3001  // Case 1: Same value or same constant int.
3002  if (isEqualConstantIntOrValue(size1, size2))
3003  continue;
3004 
3005  // Other cases: Take a deeper look at defining ops of values.
3006  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3007  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3008  if (!v1 || !v2)
3009  return false;
3010 
3011  // Case 2: Both values are identical AffineMinOps. (Should not happen if
3012  // CSE is run.)
3013  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3014  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3015  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3016  minOp1.getOperands() == minOp2.getOperands())
3017  continue;
3018 
3019  // Add additional cases as needed.
3020  }
3021 
3022  // All tests passed.
3023  return true;
3024  }
3025 };
3026 
3027 /// Returns the effective Pad value for the input op, provided it's a scalar.
3028 ///
3029 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3030 /// this Op performs padding, retrieve the padding value provided that it's
3031 /// a scalar and static/fixed for all the padded values. Returns an empty value
3032 /// otherwise.
3033 ///
3034 /// TODO: This is used twice (when checking vectorization pre-conditions and
3035 /// when vectorizing). Cache results instead of re-running.
3036 static Value getStaticPadVal(Operation *op) {
3037  if (!op)
3038  return {};
3039 
3040  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
3041  // being broadcast, provided that it's a scalar.
3042  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3043  auto source = bcast.getSource();
3044  if (llvm::dyn_cast<VectorType>(source.getType()))
3045  return {};
3046 
3047  return source;
3048  }
3049 
3050  // 2. linalg.fill - use the scalar input value that used to fill the output
3051  // tensor.
3052  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3053  return fill.getInputs()[0];
3054  }
3055 
3056  // 3. tensor.generate - we can't guarantee that the value is fixed without
3057  // further analysis; bail out.
3058  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3059  return {};
3060  }
3061 
3062  // 4. vector.transfer_write - inspect the vector that is being written. If
3063  // it contains a single value that has been broadcast (e.g. via
3064  // vector.broadcast), extract it; fail otherwise.
3065  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3066  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3067 
3068  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3069  // than the input tensor, then, provided it's constant, we'll extract the
3070  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3071  // TODO: Clarify the semantics when the input tensor is larger than the
3072  // destination.
3073  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3074  return getStaticPadVal(slice.getDest().getDefiningOp());
3075 
3076  return {};
3077 }
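// Illustrative example (hypothetical IR): for the chain below, calling
// getStaticPadVal on the tensor.insert_slice walks to the linalg.fill that
// initialises the destination and returns %pad:
// ```
//   %pad = arith.constant 0.0 : f32
//   %empty = tensor.empty() : tensor<8x8xf32>
//   %filled = linalg.fill ins(%pad : f32) outs(%empty : tensor<8x8xf32>)
//               -> tensor<8x8xf32>
//   %ins = tensor.insert_slice %src into %filled[0, 0] [4, 4] [1, 1]
//            : tensor<4x4xf32> into tensor<8x8xf32>
// ```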
3078 
3079 static LogicalResult
3080 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3081  ArrayRef<int64_t> inputVectorSizes,
3082  SmallVectorImpl<Value> &newResults) {
3083  // TODO: Introduce a parent class that will handle the insertion point update.
3084  OpBuilder::InsertionGuard g(rewriter);
3085  rewriter.setInsertionPoint(sliceOp);
3086 
3087  TypedValue<RankedTensorType> source = sliceOp.getSource();
3088  auto sourceType = source.getType();
3089  auto resultType = sliceOp.getResultType();
3090 
3091  Value padValue = getStaticPadVal(sliceOp);
3092 
3093  if (!padValue) {
3094  auto elemType = sourceType.getElementType();
3095  padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3096  rewriter.getZeroAttr(elemType));
3097  }
3098 
3099  // 2. Get the vector shape
3100  SmallVector<int64_t> vecShape;
3101  size_t rankDiff = resultType.getRank() - sourceType.getRank();
3102  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3103  if (!inputVectorSizes.empty()) {
3104  vecShape.push_back(inputVectorSizes[i]);
3105  } else if (!sourceType.isDynamicDim(i)) {
3106  vecShape.push_back(sourceType.getDimSize(i));
3107  } else if (!resultType.isDynamicDim(i)) {
3108  // Source shape is not statically known, but result shape is.
3109  // Vectorize with size of result shape. This may be larger than the
3110  // source size.
3111  // FIXME: Using rankDiff implies that the source tensor is inserted at
3112  // the end of the destination tensor. However, that's not required.
3113  vecShape.push_back(resultType.getDimSize(rankDiff + i));
3114  } else {
3115  // Neither the source nor the result dim is static. Cannot vectorize
3116  // the copy.
3117  return failure();
3118  }
3119  }
3120  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3121 
3122  // 3. Generate TransferReadOp + TransferWriteOp
3123  auto loc = sliceOp.getLoc();
3124 
3125  // Create read
3126  SmallVector<Value> readIndices(
3127  vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3128  Value read = vector::createReadOrMaskedRead(
3129  rewriter, loc, source, vecType.getShape(), padValue,
3130  /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
3131  /*inputScalableVecSizes=*/{});
3132 
3133  // Create write
3134  auto writeIndices =
3135  getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3136  Operation *write =
3137  createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3138  writeIndices, inputVectorSizes.empty());
3139 
3140  // 4. Finalize
3141  newResults.push_back(write->getResult(0));
3142 
3143  return success();
3144 }
3145 
3146 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3147 /// ```
3148 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3149 /// %r = tensor.insert_slice %0
3150 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3151 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3152 /// ```
3153 /// is rewritten to:
3154 /// ```
3155 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
3156 /// : tensor<?x?xf32>, vector<17x5xf32>
3157 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3158 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3159 /// ```
3160 ///
3161 /// This rewrite is possible if:
3162 /// - Low padding is static 0.
3163 /// - `padOp` result shape is static.
3164 /// - The entire padded tensor is inserted.
3165 /// (Implies that sizes of `insertOp` are all static.)
3166 /// - Only unit strides in `insertOp`.
3167 /// - Single, scalar padding value.
3168 /// - `padOp` result not used as destination.
3169 struct PadOpVectorizationWithInsertSlicePattern
3170  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3171  using VectorizePadOpUserPattern<
3172  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3173 
3174  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3175  tensor::InsertSliceOp insertOp) const override {
3176  // Low padding must be static 0.
3177  if (!padOp.hasZeroLowPad())
3178  return failure();
3179  // Only unit stride supported.
3180  if (!insertOp.hasUnitStride())
3181  return failure();
3182  // Pad value must be a constant.
3183  auto padValue = padOp.getConstantPaddingValue();
3184  if (!padValue)
3185  return failure();
3186  // Dynamic shapes not supported.
3187  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3188  return failure();
3189  // Pad result not used as destination.
3190  if (insertOp.getDest() == padOp.getResult())
3191  return failure();
3192 
3193  auto vecType = VectorType::get(padOp.getType().getShape(),
3194  padOp.getType().getElementType());
3195  unsigned vecRank = vecType.getRank();
3196  unsigned tensorRank = insertOp.getType().getRank();
3197 
3198  // Check if sizes match: Insert the entire tensor into most minor dims.
3199  // (No permutations allowed.)
3200  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3201  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3202  if (!llvm::all_of(
3203  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3204  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3205  }))
3206  return failure();
3207 
3208  // Insert the TransferReadOp and TransferWriteOp at the position of the
3209  // InsertSliceOp.
3210  rewriter.setInsertionPoint(insertOp);
3211 
3212  // Generate TransferReadOp: Read entire source tensor and add high
3213  // padding.
3214  SmallVector<Value> readIndices(
3215  vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3216  auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3217  vecType, padOp.getSource(),
3218  readIndices, padValue);
3219 
3220  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3221  // specified offsets. The write is fully in-bounds because an InsertSliceOp's
3222  // source must fit into the destination at the specified offsets.
3223  auto writeIndices = getValueOrCreateConstantIndexOp(
3224  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3225  SmallVector<bool> inBounds(vecRank, true);
3226  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3227  insertOp, read, insertOp.getDest(), writeIndices,
3228  ArrayRef<bool>{inBounds});
3229 
3230  return success();
3231  }
3232 };
3233 
3234 void mlir::linalg::populatePadOpVectorizationPatterns(
3235  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3236  patterns.add<PadOpVectorizationWithTransferReadPattern,
3237  PadOpVectorizationWithTransferWritePattern,
3238  PadOpVectorizationWithInsertSlicePattern>(
3239  patterns.getContext(), baseBenefit.getBenefit() + 1);
3240 }
3241 
3242 //----------------------------------------------------------------------------//
3243 // Forwarding patterns
3244 //----------------------------------------------------------------------------//
3245 
3246 /// Check whether there is any interleaved use of any `values` between
3247 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3248 /// is in a different block.
3249 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3250  ValueRange values) {
3251  if (firstOp->getBlock() != secondOp->getBlock() ||
3252  !firstOp->isBeforeInBlock(secondOp)) {
3253  LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3254  << ", second op: " << *secondOp;
3255  return true;
3256  }
3257  for (auto v : values) {
3258  for (auto &u : v.getUses()) {
3259  Operation *owner = u.getOwner();
3260  if (owner == firstOp || owner == secondOp)
3261  continue;
3262  // TODO: this is too conservative, use dominance info in the future.
3263  if (owner->getBlock() == firstOp->getBlock() &&
3264  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3265  continue;
3266  LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3267  << ", second op: " << *secondOp;
3268  return true;
3269  }
3270  }
3271  return false;
3272 }
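// Illustrative example (hypothetical IR): below, the unrelated use of %view
// between the copy and the read counts as an interleaved use of one of
// `values`, so this helper returns true and the forwarding patterns bail out:
// ```
//   memref.copy %in, %subview
//     : memref<4xf32> to memref<4xf32, strided<[1], offset: ?>>
//   "some.op"(%view) : (memref<16xf32>) -> ()
//   %v = vector.transfer_read %view[%c0], %pad
//     : memref<16xf32>, vector<16xf32>
// ```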
3273 
3274 /// Return the unique subview use of `v` if it is indeed unique, null
3275 /// otherwise.
3276 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3277  memref::SubViewOp subViewOp;
3278  for (auto &u : v.getUses()) {
3279  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3280  if (subViewOp)
3281  return memref::SubViewOp();
3282  subViewOp = newSubViewOp;
3283  }
3284  }
3285  return subViewOp;
3286 }
3287 
3288 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3289 /// when available.
3290 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3291  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3292 
3293  // TODO: support mask.
3294  if (xferOp.getMask())
3295  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3296 
3297  // Transfer into `view`.
3298  Value viewOrAlloc = xferOp.getBase();
3299  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3300  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3301  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3302 
3303  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3304  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3305  if (!subViewOp)
3306  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3307  Value subView = subViewOp.getResult();
3308 
3309  // Find the copy into `subView` without interleaved uses.
3310  memref::CopyOp copyOp;
3311  for (auto &u : subView.getUses()) {
3312  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3313  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3314  if (newCopyOp.getTarget() != subView)
3315  continue;
3316  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3317  continue;
3318  copyOp = newCopyOp;
3319  break;
3320  }
3321  }
3322  if (!copyOp)
3323  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3324 
3325  // Find the fill into `viewOrAlloc` without interleaved uses before the
3326  // copy.
3327  FillOp maybeFillOp;
3328  for (auto &u : viewOrAlloc.getUses()) {
3329  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3330  assert(isa<MemRefType>(newFillOp.output().getType()));
3331  if (newFillOp.output() != viewOrAlloc)
3332  continue;
3333  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3334  continue;
3335  maybeFillOp = newFillOp;
3336  break;
3337  }
3338  }
3339  // Ensure padding matches.
3340  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3341  return rewriter.notifyMatchFailure(xferOp,
3342  "padding value does not match fill");
3343 
3344  // `in` is the source that memref.copy reads from; the transfer_read is redirected to it.
3345  Value in = copyOp.getSource();
3346 
3347  // memref.copy + linalg.fill can be used to create a padded local buffer.
3348  // The `masked` attribute is only valid on this padded buffer.
3349  // When forwarding to vector.transfer_read, the attribute must be reset
3350  // conservatively.
3351  auto vectorType = xferOp.getVectorType();
3352  Value res = vector::TransferReadOp::create(
3353  rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3354  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3355  rewriter.getBoolArrayAttr(
3356  SmallVector<bool>(vectorType.getRank(), false)));
3357 
3358  if (maybeFillOp)
3359  rewriter.eraseOp(maybeFillOp);
3360  rewriter.eraseOp(copyOp);
3361  rewriter.replaceOp(xferOp, res);
3362 
3363  return success();
3364 }
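// Illustrative sketch (hypothetical IR) of the forwarding above: a
// fill + copy + read sequence through a local buffer such as
// ```
//   %sv = memref.subview %alloc[0] [%sz] [1]
//           : memref<32xf32> to memref<?xf32, strided<[1]>>
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<32xf32>)
//   memref.copy %in, %sv : memref<?xf32> to memref<?xf32, strided<[1]>>
//   %v = vector.transfer_read %alloc[%c0], %pad
//           : memref<32xf32>, vector<32xf32>
// ```
// is folded so that the transfer_read loads directly from %in (with the
// in_bounds flags conservatively reset), and the fill and copy are erased.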
3365 
3366 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3367 /// when available.
3368 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3369  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3370  // TODO: support mask.
3371  if (xferOp.getMask())
3372  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3373 
3374  // Transfer into `viewOrAlloc`.
3375  Value viewOrAlloc = xferOp.getBase();
3376  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3377  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3378  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3379 
3380  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3381  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3382  if (!subViewOp)
3383  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3384  Value subView = subViewOp.getResult();
3385 
3386  // Find the copy from `subView` without interleaved uses.
3387  memref::CopyOp copyOp;
3388  for (auto &u : subViewOp.getResult().getUses()) {
3389  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3390  if (newCopyOp.getSource() != subView)
3391  continue;
3392  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3393  continue;
3394  copyOp = newCopyOp;
3395  break;
3396  }
3397  }
3398  if (!copyOp)
3399  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3400 
3401  // `out` is the copy's target; the transfer_write is redirected to it.
3402  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3403  Value out = copyOp.getTarget();
3404 
3405  // Forward vector.transfer into copy.
3406  // memref.copy + linalg.fill can be used to create a padded local buffer.
3407  // The `masked` attribute is only valid on this padded buffer.
3408  // When forwarding to vector.transfer_write, the attribute must be reset
3409  // conservatively.
3410  auto vector = xferOp.getVector();
3411  vector::TransferWriteOp::create(
3412  rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3413  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3414  rewriter.getBoolArrayAttr(SmallVector<bool>(
3415  dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3416 
3417  rewriter.eraseOp(copyOp);
3418  rewriter.eraseOp(xferOp);
3419 
3420  return success();
3421 }
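// Illustrative sketch (hypothetical IR) of the forwarding above: a
// write + copy sequence through a local buffer such as
// ```
//   %sv = memref.subview %alloc[0] [%sz] [1]
//           : memref<32xf32> to memref<?xf32, strided<[1]>>
//   vector.transfer_write %vec, %alloc[%c0]
//           : vector<32xf32>, memref<32xf32>
//   memref.copy %sv, %out
//           : memref<?xf32, strided<[1]>> to memref<?xf32>
// ```
// is folded so that the transfer_write stores directly into %out (again with
// conservative in_bounds flags) and the copy is erased.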
3422 
3423 //===----------------------------------------------------------------------===//
3424 // Convolution vectorization patterns
3425 //===----------------------------------------------------------------------===//
3426 
3427 template <int N>
3428 static void bindShapeDims(ShapedType shapedType) {}
3429 
3430 template <int N, typename IntTy, typename... IntTy2>
3431 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3432  val = shapedType.getShape()[N];
3433  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3434 }
3435 
3436 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3437 template <typename... IntTy>
3438 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3439  bindShapeDims<0>(shapedType, vals...);
3440 }
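// Usage example (as in the conv generators below): for a result of shape
// {n, w, f}, `bindShapeDims(resShapedType, nSize, wSize, fSize)` binds nSize,
// wSize and fSize to shape[0], shape[1] and shape[2], in that order.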
3441 
3442 namespace {
3443 /// Generate a vector implementation for either:
3444 /// ```
3445 /// Op def: ( w, kw )
3446 /// Iters: ({Par(), Red()})
3447 /// Layout: {{w + kw}, {kw}, {w}}
3448 /// ```
3449 /// kw is unrolled.
3450 ///
3451 /// or
3452 ///
3453 /// ```
3454 /// Op def: ( n, w, c, kw, f )
3455 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3456 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3457 /// ```
3458 /// kw is unrolled, w is unrolled iff dilationW > 1.
3459 ///
3460 /// or
3461 ///
3462 /// ```
3463 /// Op def: ( n, c, w, f, kw )
3464 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3465 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3466 /// ```
3467 /// kw is unrolled, w is unrolled iff dilationW > 1.
3468 ///
3469 /// or
3470 ///
3471 /// ```
3472 /// Op def: ( n, w, c, kw )
3473 /// Iters: ({Par(), Par(), Par(), Red()})
3474 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3475 /// ```
3476 /// kw is unrolled, w is unrolled iff dilationW > 1.
3477 struct Conv1DGenerator
3478  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3479  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3480  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3481 
3482  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3483  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3484  resShaped = linalgOp.getDpsInitOperand(0)->get();
3485  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3486  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3487  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3488 
3489  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3490  redOp = reduceOp->getName().getIdentifier();
3491 
3492  setConvOperationKind(reduceOp);
3493 
3494  auto maybeKind = getCombinerOpKind(reduceOp);
3495  reductionKind = maybeKind.value();
3496 
3497  // The ConvolutionOpInterface gives us guarantees of existence for
3498  // strides/dilations. However, we do not need to rely on those, we can
3499  // simply use them if present, otherwise use the default and let the generic
3500  // conv. matcher in the ConvGenerator succeed or fail.
3501  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3502  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3503  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3504  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3505  }
3506 
3507  /// Generate a vector implementation for:
3508  /// ```
3509  /// Op def: ( w, kw )
3510  /// Iters: ({Par(), Red()})
3511  /// Layout: {{w + kw}, {kw}, {w}}
3512  /// ```
3513  /// kw is always unrolled.
3514  ///
3515  /// or
3516  ///
3517  /// ```
3518  /// Op def: ( n, w, c, kw, f )
3519  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3520  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3521  /// ```
3522  /// kw is always unrolled.
3523  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3524  /// > 1.
3525  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3526  int64_t nSize, wSize, cSize, kwSize, fSize;
3527  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3528  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3529  switch (conv1DOpOrder) {
3530  case Conv1DOpOrder::W:
3531  // Initialize unused dimensions
3532  nSize = fSize = cSize = 0;
3533  // out{W}
3534  bindShapeDims(resShapedType, wSize);
3535  // kernel{kw}
3536  bindShapeDims(rhsShapedType, kwSize);
3537  lhsShape = {// iw = ow + kw - 1
3538  // (i.e. 16 convolved with 3 -> 14)
3539  (wSize + kwSize - 1)};
3540  rhsShape = {kwSize};
3541  resShape = {wSize};
3542  break;
3543  case Conv1DOpOrder::Nwc:
3544  // out{n, w, f}
3545  bindShapeDims(resShapedType, nSize, wSize, fSize);
3546  switch (oper) {
3547  case ConvOperationKind::Conv:
3548  // kernel{kw, c, f}
3549  bindShapeDims(rhsShapedType, kwSize, cSize);
3550  break;
3551  case ConvOperationKind::Pool:
3552  // kernel{kw}
3553  bindShapeDims(rhsShapedType, kwSize);
3554  cSize = fSize;
3555  break;
3556  }
3557  lhsShape = {nSize,
3558  // iw = ow * sw + kw * dw - 1
3559  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3560  // Perform the proper inclusive -> exclusive -> inclusive.
3561  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3562  1,
3563  cSize};
3564  switch (oper) {
3565  case ConvOperationKind::Conv:
3566  rhsShape = {kwSize, cSize, fSize};
3567  break;
3568  case ConvOperationKind::Pool:
3569  rhsShape = {kwSize};
3570  break;
3571  }
3572  resShape = {nSize, wSize, fSize};
3573  break;
3574  case Conv1DOpOrder::Ncw:
3575  // out{n, f, w}
3576  bindShapeDims(resShapedType, nSize, fSize, wSize);
3577  switch (oper) {
3578  case ConvOperationKind::Conv:
3579  // kernel{f, c, kw}
3580  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3581  break;
3582  case ConvOperationKind::Pool:
3583  // kernel{kw}
3584  bindShapeDims(rhsShapedType, kwSize);
3585  cSize = fSize;
3586  break;
3587  }
3588  lhsShape = {nSize, cSize,
3589  // iw = ow * sw + kw * dw - 1
3590  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3591  // Perform the proper inclusive -> exclusive -> inclusive.
3592  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3593  1};
3594  switch (oper) {
3595  case ConvOperationKind::Conv:
3596  rhsShape = {fSize, cSize, kwSize};
3597  break;
3598  case ConvOperationKind::Pool:
3599  rhsShape = {kwSize};
3600  break;
3601  }
3602  resShape = {nSize, fSize, wSize};
3603  break;
3604  }
3605 
3606  vector::TransferWriteOp write;
3607  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3608 
3609  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3610  // When strideW == 1, we can batch the contiguous loads and avoid
3611  // unrolling
3612  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3613 
3614  Type lhsEltType = lhsShapedType.getElementType();
3615  Type rhsEltType = rhsShapedType.getElementType();
3616  Type resEltType = resShapedType.getElementType();
3617  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3618  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3619  auto resType = VectorType::get(resShape, resEltType);
3620  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3621  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3622  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3623  SmallVector<Value> resPadding(resShape.size(), zero);
3624 
3625  // Read the whole lhs, rhs and res in one shot (with zero padding).
3626  Value lhs = vector::TransferReadOp::create(
3627  rewriter, loc, lhsType, lhsShaped, lhsPadding,
3628  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3629  // This is needed only for Conv.
3630  Value rhs = nullptr;
3631  if (oper == ConvOperationKind::Conv)
3632  rhs = vector::TransferReadOp::create(
3633  rewriter, loc, rhsType, rhsShaped, rhsPadding,
3634  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3635  Value res = vector::TransferReadOp::create(
3636  rewriter, loc, resType, resShaped, resPadding,
3637  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3638 
3639  // The base vectorization case for channeled convolution is input:
3640  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3641  // vectorization case, we pre-transpose the input, weight, and output.
3642  switch (conv1DOpOrder) {
3643  case Conv1DOpOrder::W:
3644  case Conv1DOpOrder::Nwc:
3645  // Base case, so no transposes necessary.
3646  break;
3647  case Conv1DOpOrder::Ncw: {
3648  // To match base vectorization case, we pre-transpose current case.
3649  // ncw -> nwc
3650  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3651  lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3652  // fcw -> wcf
3653  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3654 
3655  // This is needed only for Conv.
3656  if (oper == ConvOperationKind::Conv)
3657  rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3658  // nfw -> nwf
3659  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3660  res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3661  break;
3662  }
3663  }
3664 
3665  //===------------------------------------------------------------------===//
3666  // Begin vector-only rewrite part
3667  //===------------------------------------------------------------------===//
3668  // Unroll along kw and read slices of lhs and rhs.
3669  SmallVector<Value> lhsVals, rhsVals, resVals;
3670  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3671  kwSize, strideW, dilationW, wSizeStep,
3672  isSingleChanneled);
3673  // Do not do for pooling.
3674  if (oper == ConvOperationKind::Conv)
3675  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3676  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3677  wSizeStep, isSingleChanneled);
3678 
3679  auto linearIndex = [&](int64_t kw, int64_t w) {
3680  return kw * (wSize / wSizeStep) + w;
3681  };
3682 
3683  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3684  // or perform an outer product for non-channeled convolution, or a simple
3685  // arith operation for pooling.
3686  for (int64_t kw = 0; kw < kwSize; ++kw) {
3687  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3688  switch (oper) {
3689  case ConvOperationKind::Conv:
3690  if (isSingleChanneled) {
3691  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3692  lhsVals[linearIndex(kw, w)],
3693  rhsVals[kw], resVals[w]);
3694  } else {
3695  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3696  lhsVals[linearIndex(kw, w)],
3697  rhsVals[kw], resVals[w]);
3698  }
3699  break;
3700  case ConvOperationKind::Pool:
3701  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3702  resVals[w]);
3703  break;
3704  }
3705  }
3706  }
3707 
3708  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3709  isSingleChanneled);
3710  //===------------------------------------------------------------------===//
3711  // End vector-only rewrite part
3712  //===------------------------------------------------------------------===//
3713 
3714  // The base vectorization case for channeled convolution is output:
3715  // {n,w,f}. To reuse the result from the base-pattern vectorization case, we
3716  // post-transpose the base case result.
3717  switch (conv1DOpOrder) {
3718  case Conv1DOpOrder::W:
3719  case Conv1DOpOrder::Nwc:
3720  // Base case, so no transposes necessary.
3721  break;
3722  case Conv1DOpOrder::Ncw: {
3723  // nwf -> nfw
3724  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3725  res = vector::TransposeOp::create(rewriter, loc, res, perm);
3726  break;
3727  }
3728  }
3729 
3730  return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3731  resPadding)
3732  .getOperation();
3733  }
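  // Illustrative sketch (hypothetical shapes) of the Nwc path above for a
  // stride-1, dilation-1 conv with a single kernel tap (kwSize == 1): lhs,
  // rhs and res are read whole, one slice pair is contracted, and the result
  // is written back, roughly:
  // ```
  //   %lhs = vector.transfer_read %in[%c0, %c0, %c0], %cst
  //            : tensor<1x4x3xf32>, vector<1x4x3xf32>
  //   %rhs = vector.transfer_read %filter[%c0, %c0, %c0], %cst
  //            : tensor<1x3x8xf32>, vector<1x3x8xf32>
  //   %res = vector.transfer_read %out[%c0, %c0, %c0], %cst
  //            : tensor<1x4x8xf32>, vector<1x4x8xf32>
  //   %rhs0 = vector.extract %rhs[0] : vector<3x8xf32> from vector<1x3x8xf32>
  //   %acc = vector.contract ... %lhs, %rhs0, %res ...
  //   vector.transfer_write %acc, %out[%c0, %c0, %c0]
  //            : vector<1x4x8xf32>, tensor<1x4x8xf32>
  // ```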
3734 
3735  // Take a value and widen to have the same element type as `ty`.
3736  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3737  const Type srcElementType = getElementTypeOrSelf(val.getType());
3738  const Type dstElementType = getElementTypeOrSelf(ty);
3739  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3740  if (srcElementType == dstElementType)
3741  return val;
3742 
3743  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3744  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3745  const Type dstType =
3746  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3747 
3748  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3749  return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3750  }
3751 
3752  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3753  srcWidth < dstWidth)
3754  return arith::ExtFOp::create(rewriter, loc, dstType, val);
3755 
3756  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3757  srcWidth < dstWidth)
3758  return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3759 
3760  assert(false && "unhandled promotion case");
3761  return nullptr;
3762  }
3763 
3764  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3765  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3766  Value lhs, Value rhs, Value res) {
3767  vector::IteratorType par = vector::IteratorType::parallel;
3768  vector::IteratorType red = vector::IteratorType::reduction;
3769  AffineExpr n, w, f, c;
3770  bindDims(ctx, n, w, f, c);
3771  lhs = promote(rewriter, loc, lhs, res.getType());
3772  rhs = promote(rewriter, loc, rhs, res.getType());
3773  auto contrationOp = vector::ContractionOp::create(
3774  rewriter, loc, lhs, rhs, res,
3775  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3776  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3777  contrationOp.setKind(reductionKind);
3778  return contrationOp;
3779  }
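  // Illustrative example (hypothetical shapes, n = 1, w = 2, c = 3, f = 8):
  // the contraction above is emitted roughly as
  // ```
  //   %acc = vector.contract {
  //       indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                        affine_map<(n, w, f, c) -> (c, f)>,
  //                        affine_map<(n, w, f, c) -> (n, w, f)>],
  //       iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  //       kind = #vector.kind<add>}
  //       %lhs, %rhs, %res
  //     : vector<1x2x3xf32>, vector<3x8xf32> into vector<1x2x8xf32>
  // ```
  // with the combining kind taken from `reductionKind`.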
3780 
3781  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3782  // convolution.
3783  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3784  Value lhs, Value rhs, Value res) {
3785  return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3786  rhs, res, vector::CombiningKind::ADD);
3787  }
3788 
3789  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3790  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3791  Value res) {
3792  if (isPoolExt)
3793  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3794  return rewriter
3795  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3796  ->getResult(0);
3797  }
3798 
3799  /// Generate a vector implementation for:
3800  /// ```
3801  /// Op def: ( n, w, c, kw)
3802  /// Iters: ({Par(), Par(), Par(), Red()})
3803  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3804  /// ```
3805  /// kw is always unrolled.
3806  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3807  /// > 1.
3808  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3809  bool channelDimScalableFlag,
3810  bool flatten) {
3811  bool scalableChDim = false;
3812  bool useMasking = false;
3813  int64_t nSize, wSize, cSize, kwSize;
3814  // kernel{kw, c}
3815  bindShapeDims(rhsShapedType, kwSize, cSize);
3816  if (ShapedType::isDynamic(cSize)) {
3817  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3818  cSize = channelDimVecSize;
3819  // Scalable vectors are only used when both conditions are met:
3820  // 1. channel dim is dynamic
3821  // 2. channelDimScalableFlag is set
3822  scalableChDim = channelDimScalableFlag;
3823  useMasking = true;
3824  }
3825 
3826  assert(!(useMasking && flatten) &&
3827  "Unsupported flattened conv with dynamic shapes");
3828 
3829  // out{n, w, c}
3830  bindShapeDims(resShapedType, nSize, wSize);
3831 
3832  vector::TransferWriteOp write;
3833  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3834 
3835  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3836  // When strideW == 1, we can batch the contiguous loads and avoid
3837  // unrolling
3838  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3839 
3840  Type lhsEltType = lhsShapedType.getElementType();
3841  Type rhsEltType = rhsShapedType.getElementType();
3842  Type resEltType = resShapedType.getElementType();
3843  VectorType lhsType = VectorType::get(
3844  {nSize,
3845  // iw = ow * sw + kw * dw - 1
3846  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3847  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3848  cSize},
3849  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3850  VectorType rhsType =
3851  VectorType::get({kwSize, cSize}, rhsEltType,
3852  /*scalableDims=*/{false, scalableChDim});
3853  VectorType resType =
3854  VectorType::get({nSize, wSize, cSize}, resEltType,
3855  /*scalableDims=*/{false, false, scalableChDim});
3856 
3857  // Masks the input xfer Op along the channel dim, iff masking is required
3858  // (i.e. the channel dim is dynamic).
3859  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3860  ArrayRef<bool> scalableDims,
3861  Operation *opToMask) {
3862  if (!useMasking)
3863  return opToMask;
3864  auto maskType =
3865  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3866 
3867  SmallVector<bool> inBounds(maskShape.size(), true);
3868  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3869  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3870  rewriter.getBoolArrayAttr(inBounds));
3871 
3872  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3873  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3874 
3875  Value maskOp =
3876  vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3877 
3878  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3879  };
3880 
3881  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3882  // 0].
3883  Value lhs = vector::TransferReadOp::create(
3884  rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3885  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3886  auto maybeMaskedLhs = maybeMaskXferOp(
3887  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3888 
3889  // Read rhs slice of size {kw, c} @ [0, 0].
3890  Value rhs = vector::TransferReadOp::create(
3891  rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3892  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3893  auto maybeMaskedRhs = maybeMaskXferOp(
3894  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3895 
3896  // Read res slice of size {n, w, c} @ [0, 0, 0].
3897  Value res = vector::TransferReadOp::create(
3898  rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3899  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3900  auto maybeMaskedRes = maybeMaskXferOp(
3901  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3902 
3903  //===------------------------------------------------------------------===//
3904  // Begin vector-only rewrite part
3905  //===------------------------------------------------------------------===//
3906  // Unroll along kw and read slices of lhs and rhs.
3907  SmallVector<Value> lhsVals, rhsVals, resVals;
3908  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3909  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3910 
3911  // Extract lhs slice of size {n, wSizeStep, c}
3912  // @ [0, sw * w + dw * kw, 0].
3913  for (int64_t kw = 0; kw < kwSize; ++kw) {
3914  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3915  lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3916  rewriter, loc, maybeMaskedLhs->getResult(0),
3917  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3918  inOutSliceSizes, inOutStrides));
3919  }
3920  }
3921  // Extract rhs slice of size {c} @ [kw].
3922  for (int64_t kw = 0; kw < kwSize; ++kw) {
3923  rhsVals.push_back(
3924  vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3925  /*offsets=*/ArrayRef<int64_t>{kw}));
3926  }
3927  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3928  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3929  resVals.push_back(vector::ExtractStridedSliceOp::create(
3930  rewriter, loc, maybeMaskedRes->getResult(0),
3931  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3932  inOutStrides));
3933  }
3934 
3935  auto linearIndex = [&](int64_t kw, int64_t w) {
3936  return kw * (wSize / wSizeStep) + w;
3937  };
3938 
3939  // Note - the scalable flags are ignored as flattening combined with
3940  // scalable vectorization is not supported.
3941  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3942  auto lhsTypeAfterFlattening =
3943  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3944  auto resTypeAfterFlattening =
3945  VectorType::get(inOutFlattenSliceSizes, resEltType);
3946 
3947  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3948  for (int64_t kw = 0; kw < kwSize; ++kw) {
3949  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3950  Value lhsVal = lhsVals[linearIndex(kw, w)];
3951  Value resVal = resVals[w];
3952  if (flatten) {
3953  // Flatten the input and output vectors (collapse the channel
3954  // dimension)
3955  lhsVal =
3956  vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3957  lhsVals[linearIndex(kw, w)]);
3958  resVal = vector::ShapeCastOp::create(
3959  rewriter, loc, resTypeAfterFlattening, resVals[w]);
3960  }
3961  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3962  rhsVals[kw], resVal, flatten);
3963  if (flatten) {
3964  // Un-flatten the output vector (restore the channel dimension)
3965  resVals[w] = vector::ShapeCastOp::create(
3966  rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
3967  resVals[w]);
3968  }
3969  }
3970  }
3971 
3972  // It's possible we failed to create the FMA.
3973  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3974  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3975  for (auto &collection :
3976  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3977  for (Value v : collection)
3978  rewriter.eraseOp(v.getDefiningOp());
3979  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3980  }
3981 
3982  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3983  // This does not depend on kw.
3984  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3985  maybeMaskedRes = vector::InsertStridedSliceOp::create(
3986  rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
3987  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3988  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3989  }
3990  //===------------------------------------------------------------------===//
3991  // End vector-only rewrite part
3992  //===------------------------------------------------------------------===//
3993 
3994  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3995  Operation *resOut = vector::TransferWriteOp::create(
3996  rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
3997  ValueRange{zero, zero, zero});
3998  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3999  resOut);
4000  }
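  // Illustrative sketch (hypothetical IR) of the masked path above: with a
  // dynamic channel dim, channelDimVecSize = 4 and the scalable flag set,
  // each transfer is wrapped in vector.mask, e.g. for the lhs read:
  // ```
  //   %mask = vector.create_mask %c1, %c3, %ch : vector<1x3x[4]xi1>
  //   %lhs = vector.mask %mask {
  //     vector.transfer_read %in[%c0, %c0, %c0], %cst
  //       : memref<1x3x?xf32>, vector<1x3x[4]xf32>
  //   } : vector<1x3x[4]xi1> -> vector<1x3x[4]xf32>
  // ```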
4001 
4002  /// Lower:
4003  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4004  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4005  /// to MulAcc.
4006  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4007  Value lhs, Value rhs, Value res,
4008  bool flatten) {
4009  auto rhsTy = cast<ShapedType>(rhs.getType());
4010  auto resTy = cast<ShapedType>(res.getType());
4011 
4012  // TODO(suderman): Change this to use a vector.ima intrinsic.
4013  lhs = promote(rewriter, loc, lhs, resTy);
4014 
4015  if (flatten) {
4016  // NOTE: The following logic won't work for scalable vectors. For this
4017  // reason, "flattening" is not supported when shapes are dynamic (this
4018  // should be captured by one of the pre-conditions).
4019 
4020  // There are two options for handling the filter:
4021  // * shape_cast(broadcast(filter))
4022  // * broadcast(shuffle(filter))
4023  // Opt for the option without shape_cast to simplify the codegen.
4024  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4025  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4026 
4027  SmallVector<int64_t, 16> indices;
4028  for (int i = 0; i < resSize / rhsSize; ++i) {
4029  for (int j = 0; j < rhsSize; ++j)
4030  indices.push_back(j);
4031  }
4032 
4033  rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4034  }
4035  // Broadcast the filter to match the output vector
4036  rhs = vector::BroadcastOp::create(rewriter, loc,
4037  resTy.clone(rhsTy.getElementType()), rhs);
4038 
4039  rhs = promote(rewriter, loc, rhs, resTy);
4040 
4041  if (!lhs || !rhs)
4042  return nullptr;
4043 
4044  if (isa<FloatType>(resTy.getElementType()))
4045  return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4046 
4047  auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4048  return arith::AddIOp::create(rewriter, loc, mul, res);
4049  }
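  // Illustrative example (hypothetical sizes): with a filter rhs of type
  // vector<4xf32> and a flattened result row of 8 elements, the shuffle above
  // replicates the filter as
  // ```
  //   %rhs2 = vector.shuffle %rhs, %rhs [0, 1, 2, 3, 0, 1, 2, 3]
  //             : vector<4xf32>, vector<4xf32>
  // ```
  // before it is broadcast to the result shape and combined via vector.fma
  // (floats) or arith.muli + arith.addi (integers).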
4050 
4051  /// Entry point for non-channeled convolution:
4052  /// {{w + kw}, {kw}, {w}}
4053  FailureOr<Operation *> generateNonChanneledConv() {
4054  AffineExpr w, kw;
4055  bindDims(ctx, w, kw);
4056  if (!iters({Par(), Red()}))
4057  return rewriter.notifyMatchFailure(op,
4058  "failed to match conv::W 1-par 1-red");
4059 
4060  // No transposition needed.
4061  if (layout({/*lhsIndex*/ {w + kw},
4062  /*rhsIndex*/ {kw},
4063  /*resIndex*/ {w}}))
4064  return conv(Conv1DOpOrder::W);
4065 
4066  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4067  }
4068 
4069  /// Entry point that transposes into the common form:
4070  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4071  FailureOr<Operation *> generateNwcConv() {
4072  AffineExpr n, w, f, kw, c;
4073  bindDims(ctx, n, w, f, kw, c);
4074  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4075  return rewriter.notifyMatchFailure(
4076  op, "failed to match conv::Nwc 3-par 2-red");
4077 
4078  // No transposition needed.
4079  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4080  /*rhsIndex*/ {kw, c, f},
4081  /*resIndex*/ {n, w, f}}))
4082  return conv(Conv1DOpOrder::Nwc);
4083 
4084  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4085  }
4086 
4087  /// Entry point that transposes into the common form:
4088  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4089  FailureOr<Operation *> generateNcwConv() {
4090  AffineExpr n, w, f, kw, c;
4091  bindDims(ctx, n, f, w, c, kw);
4092  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4093  return rewriter.notifyMatchFailure(
4094  op, "failed to match conv::Ncw 3-par 2-red");
4095 
4096  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4097  /*rhsIndex*/ {f, c, kw},
4098  /*resIndex*/ {n, f, w}}))
4099  return conv(Conv1DOpOrder::Ncw);
4100 
4101  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4102  }
4103 
4104  /// Entry point that transposes into the common form:
4105  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4106  FailureOr<Operation *> generateNwcPooling() {
4107  AffineExpr n, w, c, kw;
4108  bindDims(ctx, n, w, c, kw);
4109  if (!iters({Par(), Par(), Par(), Red()}))
4110  return rewriter.notifyMatchFailure(op,
4111  "failed to match pooling 3-par 1-red");
4112 
4113  // No transposition needed.
4114  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4115  /*rhsIndex*/ {kw},
4116  /*resIndex*/ {n, w, c}}))
4117  return conv(Conv1DOpOrder::Nwc);
4118 
4119  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4120  }
4121 
4122  /// Entry point that transposes into the common form:
4123  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4124  FailureOr<Operation *> generateNcwPooling() {
4125  AffineExpr n, w, c, kw;
4126  bindDims(ctx, n, c, w, kw);
4127  if (!iters({Par(), Par(), Par(), Red()}))
4128  return rewriter.notifyMatchFailure(op,
4129  "failed to match pooling 3-par 1-red");
4130 
4131  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4132  /*rhsIndex*/ {kw},
4133  /*resIndex*/ {n, c, w}}))
4134  return conv(Conv1DOpOrder::Ncw);
4135 
4136  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4137  }
4138 
4139  /// Entry point that transposes into the common form:
4140  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4141  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4142  bool vecChDimScalableFlag = false,
4143  bool flatten = false) {
4144  AffineExpr n, w, c, kw;
4145  bindDims(ctx, n, w, c, kw);
4146  if (!iters({Par(), Par(), Par(), Red()}))
4147  return rewriter.notifyMatchFailure(
4148  op, "failed to match depthwise::Nwc conv 3-par 1-red");
4149 
4150  // No transposition needed.
4151  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4152  /*rhsIndex*/ {kw, c},
4153  /*resIndex*/ {n, w, c}}))
4154  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4155 
4156  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4157  }
4158 
4159 private:
4160  ConvOperationKind oper = ConvOperationKind::Conv;
4161  StringAttr redOp;
4162  StringAttr poolExtOp;
4163  bool isPoolExt = false;
4164  int strideW, dilationW;
4165  Value lhsShaped, rhsShaped, resShaped;
4166  ShapedType lhsShapedType, rhsShapedType, resShapedType;
4167  vector::CombiningKind reductionKind;
4168 
4169  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4170  void setConvOperationKind(Operation *reduceOp) {
4171  int numBlockArguments =
4172  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4173  if (numBlockArguments == 1) {
4174  // Will be convolution if feeder is a MulOp.
4175  // A strength reduced version of MulOp for i1 type is AndOp which is also
4176  // supported. Otherwise, it can be pooling. This strength reduction logic
4177  // is in `buildBinaryFn` helper in the Linalg dialect.
4178  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4179  llvm::IsaPred<BlockArgument>);
4180  Operation *feedOp = (*feedValIt).getDefiningOp();
4181  if (isCastOfBlockArgument(feedOp)) {
4182  oper = ConvOperationKind::Pool;
4183  isPoolExt = true;
4184  poolExtOp = feedOp->getName().getIdentifier();
4185  return;
4186  }
4187  oper = ConvOperationKind::Conv;
4188  return;
4189  }
4190  // numBlockArguments == 2 and this is a pooling op.
4191  oper = ConvOperationKind::Pool;
4192  isPoolExt = false;
4193  }
4194 };
4195 } // namespace
4196 
4197 /// Helper function to vectorize a LinalgOp with convolution semantics.
4198 // TODO: extend the generic vectorization to support windows and drop this.
4199 static FailureOr<Operation *> vectorizeConvolution(
4200  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4201  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4202  Conv1DGenerator conv1dGen(rewriter, op);
4203  auto res = conv1dGen.generateNonChanneledConv();
4204  if (succeeded(res))
4205  return res;
4206  res = conv1dGen.generateNwcConv();
4207  if (succeeded(res))
4208  return res;
4209  res = conv1dGen.generateNcwConv();
4210  if (succeeded(res))
4211  return res;
4212  res = conv1dGen.generateNwcPooling();
4213  if (succeeded(res))
4214  return res;
4215  res = conv1dGen.generateNcwPooling();
4216  if (succeeded(res))
4217  return res;
4218 
4219  // Only depthwise 1D convs (NWC or NCW) are left - these can be vectorized
4220  // using masks and scalable vectors. Note that ATM the only dim that can be
4221  // dynamic (i.e. masked/scalable) is the channel dim.
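 // For example (illustrative): for a DepthwiseConv1DNwcWcOp with
 // inputVecSizes = {1, 8, 4} and inputScalableVecDims = {false, false, true},
 // the channel dim is index 2, so vecChDimSize = 4 and
 // vecChDimScalableFlag = true below.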
4222  uint64_t vecChDimSize = ShapedType::kDynamic;
4223  bool vecChDimScalableFlag = false;
4224  if (!inputVecSizes.empty()) {
4225  // Only use the input vector size corresponding to the channel dim. Other
4226  // vector dims will be inferred from the Ops.
4227  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4228  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4229  "Not a 1D depthwise conv!");
4230  size_t chDimIdx =
4231  TypeSwitch<Operation *, size_t>(op)
4232  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4233  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4234 
4235  vecChDimSize = inputVecSizes[chDimIdx];
4236  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4237  }
4238  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4239  flatten1DDepthwiseConv);
4240 }
4241 
4242 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4243  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4244 
4245  LogicalResult matchAndRewrite(LinalgOp op,
4246  PatternRewriter &rewriter) const override {
4247  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4248  if (failed(resultOrFail))
4249  return failure();
4250  Operation *newOp = *resultOrFail;
4251  if (newOp->getNumResults() == 0) {
4252  rewriter.eraseOp(op.getOperation());
4253  return success();
4254  }
4255  assert(newOp->getNumResults() == 1 && "expected single result");
4256  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4257  return success();
4258  }
4259 };
4260 
4261 void mlir::linalg::populateConvolutionVectorizationPatterns(
4262  RewritePatternSet &patterns, PatternBenefit benefit) {
4263  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4264 }
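 // Usage sketch (illustrative): clients typically add this pattern through the
 // populate function above and run a greedy driver, e.g.:
 //   RewritePatternSet patterns(funcOp.getContext());
 //   linalg::populateConvolutionVectorizationPatterns(patterns);
 //   (void)applyPatternsGreedily(funcOp, std::move(patterns));
 // where `funcOp` is a hypothetical enclosing function-like op.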