Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Support/LLVM.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/DebugLog.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 /// Try to vectorize `convOp` as a convolution.
53 static FailureOr<Operation *>
54 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
55  ArrayRef<int64_t> inputVecSizes = {},
56  ArrayRef<bool> inputVecScalableFlags = {},
57  bool flatten1DDepthwiseConv = false);
58 
59 /// Vectorize tensor::InsertSliceOp with:
60 /// * vector::TransferReadOp + vector::TransferWriteOp
61 /// The vector sizes are either:
62 /// * user-provided in `inputVectorSizes`, or
63 /// * inferred from the static dims in the input and output tensors.
64 /// Bails out if:
65 /// * vector sizes are not user-provided, and
66 /// * at least one dim is dynamic (in both the input and output tensors).
67 ///
68 /// Before:
69 /// !t_in_type = tensor<1x2x3xf32>
70 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
71 /// !v_type = vector<1x2x3xf32>
72 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
73 /// into !t_out_type
74 /// After:
75 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
76 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
77 static LogicalResult
78 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
79  ArrayRef<int64_t> inputVectorSizes,
80  SmallVectorImpl<Value> &newResults);
81 
82 /// Returns the effective Pad value for the input op, provided it's a scalar.
83 ///
84 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
85 /// this Op performs padding, retrieve the padding value provided that it's
86 /// a scalar and static/fixed for all the padded values. Returns an empty value
87 /// otherwise.
88 static Value getStaticPadVal(Operation *op);
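// As an illustration, a minimal sketch of the kind of IR this helper inspects
// (a hypothetical `tensor.pad` whose region yields a static scalar):
//
//   %cst = arith.constant 0.0 : f32
//   %padded = tensor.pad %src low[0, 1] high[0, 1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x6xf32>
//
// Here the effective pad value is %cst; for non-scalar or non-constant pad
// values the helper is expected to return an empty Value.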
89 
90 /// Return the unique instance of OpType in `block` if it is indeed unique.
91 /// Return null if zero or more than one instance exists.
92 template <typename OpType>
93 static OpType getSingleOpOfType(Block &block) {
94  OpType res;
95  block.walk([&](OpType op) {
96  if (res) {
97  res = nullptr;
98  return WalkResult::interrupt();
99  }
100  res = op;
101  return WalkResult::advance();
102  });
103  return res;
104 }
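// A minimal usage sketch (hypothetical call site): fetch the unique
// linalg.yield terminator of a body block, or null if absent/ambiguous:
//
//   if (auto yieldOp = getSingleOpOfType<linalg::YieldOp>(body)) { ... }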
105 
106 /// Helper function to extract the input slices after filter is unrolled along
107 /// kw.
108 static SmallVector<Value>
109 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
110  int64_t nSize, int64_t wSize, int64_t cSize,
111  int64_t kwSize, int strideW, int dilationW,
112  int64_t wSizeStep, bool isSingleChanneled) {
113  SmallVector<Value> result;
114  if (isSingleChanneled) {
115  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
116  // convolution.
117  SmallVector<int64_t> sizes = {wSizeStep};
118  SmallVector<int64_t> strides = {1};
119  for (int64_t kw = 0; kw < kwSize; ++kw) {
120  for (int64_t w = 0; w < wSize; w += wSizeStep) {
121  result.push_back(vector::ExtractStridedSliceOp::create(
122  rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
123  strides));
124  }
125  }
126  } else {
127  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128  // for channeled convolution.
129  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130  SmallVector<int64_t> strides = {1, 1, 1};
131  for (int64_t kw = 0; kw < kwSize; ++kw) {
132  for (int64_t w = 0; w < wSize; w += wSizeStep) {
133  result.push_back(vector::ExtractStridedSliceOp::create(
134  rewriter, loc, input,
135  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136  sizes, strides));
137  }
138  }
139  }
140  return result;
141 }
142 
143 /// Helper function to extract the filter slices after filter is unrolled along
144 /// kw.
145 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
146  Location loc, Value filter,
147  int64_t kwSize) {
148  SmallVector<Value> result;
149  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150  // non-channeled convolution] @ [kw].
151  for (int64_t kw = 0; kw < kwSize; ++kw) {
152  result.push_back(vector::ExtractOp::create(
153  rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154  }
155  return result;
156 }
157 
158 /// Helper function to extract the result slices after filter is unrolled along
159 /// kw.
160 static SmallVector<Value>
161 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
162  int64_t nSize, int64_t wSize, int64_t fSize,
163  int64_t wSizeStep, bool isSingleChanneled) {
164  SmallVector<Value> result;
165  if (isSingleChanneled) {
166  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167  SmallVector<int64_t> sizes = {wSizeStep};
168  SmallVector<int64_t> strides = {1};
169  for (int64_t w = 0; w < wSize; w += wSizeStep) {
170  result.push_back(vector::ExtractStridedSliceOp::create(
171  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
172  strides));
173  }
174  } else {
175  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
176  // convolution.
177  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
178  SmallVector<int64_t> strides = {1, 1, 1};
179  for (int64_t w = 0; w < wSize; w += wSizeStep) {
180  result.push_back(vector::ExtractStridedSliceOp::create(
181  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
182  strides));
183  }
184  }
185  return result;
186 }
187 
188 /// Helper function to insert the computed result slices.
189 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190  Value res, int64_t wSize, int64_t wSizeStep,
191  SmallVectorImpl<Value> &resVals,
192  bool isSingleChanneled) {
193 
194  if (isSingleChanneled) {
195  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196  // This does not depend on kw.
197  SmallVector<int64_t> strides = {1};
198  for (int64_t w = 0; w < wSize; w += wSizeStep) {
199  res = vector::InsertStridedSliceOp::create(
200  rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
201  strides);
202  }
203  } else {
204  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205  // convolution. This does not depend on kw.
206  SmallVector<int64_t> strides = {1, 1, 1};
207  for (int64_t w = 0; w < wSize; w += wSizeStep) {
208  res = vector::InsertStridedSliceOp::create(
209  rewriter, loc, resVals[w], res,
210  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
211  }
212  }
213  return res;
214 }
215 
216 /// Contains the vectorization state and related methods used across the
217 /// vectorization process of a given operation.
218 struct VectorizationState {
219  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220 
221  /// Initializes the vectorization state, including the computation of the
222  /// canonical vector shape for vectorization.
223  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224  ArrayRef<int64_t> inputVectorSizes,
225  ArrayRef<bool> inputScalableVecDims,
226  bool assumeDynamicDimsMatchVecSizes = false);
227 
228  /// Returns the canonical vector shape used to vectorize the iteration space.
229  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
230 
231  /// Returns the vector dimensions that are scalable in the canonical vector
232  /// shape.
233  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
234 
235  /// Returns a vector type of the provided `elementType` with the canonical
236  /// vector shape and the corresponding fixed/scalable dimensions bit. If
237  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
238  /// accordingly.
239  VectorType getCanonicalVecType(
240  Type elementType,
241  std::optional<AffineMap> dimPermutation = std::nullopt) const {
242  SmallVector<int64_t> vectorShape;
243  SmallVector<bool> scalableDims;
244  if (dimPermutation.has_value()) {
245  vectorShape =
246  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
247  scalableDims =
248  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
249  } else {
250  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
252  }
253 
254  return VectorType::get(vectorShape, elementType, scalableDims);
255  }
256 
257  /// Masks an operation with the canonical vector mask if the operation needs
258  /// masking. Returns the masked operation or the original operation if masking
259  /// is not needed. If provided, the canonical mask for this operation is
260  /// permuted using `maybeIndexingMap`.
261  Operation *
262  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
263  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264 
265 private:
266  /// Initializes the iteration space static sizes using the Linalg op
267  /// information. This may become more complicated in the future.
268  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
270  }
271 
272  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
273  /// all the static and dynamic dimensions of the iteration space to be
274  /// vectorized and store them in `iterSpaceValueSizes`.
275  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276  LinalgOp linalgOp);
277 
278  /// Create or retrieve an existing mask value to mask `opToMask` in the
279  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
280  /// permuted using that permutation map. If a new mask is created, it will be
281  /// cached for future users.
282  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
283  LinalgOp linalgOp,
284  std::optional<AffineMap> maybeMaskingMap);
285 
286  /// Check whether this permutation map can be used for masking. At the
287  /// moment we only make sure that there are no broadcast dimensions, but this
288  /// might change if indexing maps evolve.
289  bool isValidMaskingMap(AffineMap maskingMap) {
290  return maskingMap.getBroadcastDims().size() == 0;
291  }
292 
293  /// Turn the input indexing map into a valid masking map.
294  ///
295  /// The input indexing map may contain "zero" results, e.g.:
296  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
297  /// Applying such maps to canonical vector shapes like this one:
298  /// (1, 16, 16, 4)
299  /// would yield an invalid vector shape like this:
300  /// (16, 16, 1, 0)
301  /// Instead, drop the broadcasting dims that make no sense for masking perm.
302  /// maps:
303  /// (d0, d1, d2, d3) -> (d2, d1, d0)
304  /// This way, the corresponding vector/mask type will be:
305  /// vector<16x16x1xty>
306  /// rather than this invalid Vector type:
307  /// vector<16x16x1x0xty>
308  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
309  return indexingMap.dropZeroResults();
310  }
311 
312  // Holds the compile-time static sizes of the iteration space to vectorize.
313  // Dynamic dimensions are represented using ShapedType::kDynamic.
314  SmallVector<int64_t> iterSpaceStaticSizes;
315 
316  /// Holds the value sizes of the iteration space to vectorize. Static
317  /// dimensions are represented by 'arith.constant' and dynamic
318  /// dimensions by 'tensor/memref.dim'.
319  SmallVector<Value> iterSpaceValueSizes;
320 
321  /// Holds the canonical vector shape used to vectorize the iteration space.
322  SmallVector<int64_t> canonicalVecShape;
323 
324  /// Holds the vector dimensions that are scalable in the canonical vector
325  /// shape.
326  SmallVector<bool> scalableVecDims;
327 
328  /// Holds the active masks for permutations of the canonical vector iteration
329  /// space.
330  DenseMap<AffineMap, Value> activeMaskCache;
331 
332  /// Global vectorization guard for the incoming rewriter. It's initialized
333  /// when the vectorization state is initialized.
334  OpBuilder::InsertionGuard rewriterGuard;
335 
336  /// Do all dynamic dims match the corresponding vector sizes?
337  ///
338  /// When a dynamic tensor/memref dimension matches the corresponding vector
339  /// dimension, masking can be safely skipped, despite the presence of dynamic
340  /// shapes. Use this flag with care and only for cases where you are
341  /// confident the assumption holds.
342  bool assumeDynamicDimsMatchVecSizes = false;
343 };
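// A minimal usage sketch of the state object (variable names are illustrative):
//
//   VectorizationState state(rewriter);
//   if (failed(state.initState(rewriter, linalgOp, vectorSizes, scalableDims)))
//     return failure();
//   VectorType vecType = state.getCanonicalVecType(rewriter.getF32Type());
//   Operation *maybeMasked = state.maskOperation(rewriter, newVecOp, linalgOp);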
344 
345 LogicalResult
346 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
347  LinalgOp linalgOp) {
348  // TODO: Support 0-d vectors.
349  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350  if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
351  // Create constant index op for static dimensions.
352  iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
353  rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
354  continue;
355  }
356 
357  // Find an operand defined on this dimension of the iteration space to
358  // extract the runtime dimension size.
359  Value operand;
360  unsigned operandDimPos;
361  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
362  operandDimPos)))
363  return failure();
364 
365  Value dynamicDim =
366  linalgOp.hasPureTensorSemantics()
367  ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
368  operandDimPos)
369  : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370  operandDimPos);
371  iterSpaceValueSizes.push_back(dynamicDim);
372  }
373 
374  return success();
375 }
376 
377 /// Initializes the vectorization state, including the computation of the
378 /// canonical vector shape for vectorization.
379 // TODO: Move this to the constructor when we can remove the failure cases.
380 LogicalResult VectorizationState::initState(RewriterBase &rewriter,
381  LinalgOp linalgOp,
382  ArrayRef<int64_t> inputVectorSizes,
383  ArrayRef<bool> inputScalableVecDims,
384  bool assumeDimsMatchVec) {
385  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
386  // Initialize the insertion point.
387  rewriter.setInsertionPoint(linalgOp);
388 
389  if (!inputVectorSizes.empty()) {
390  // Get the canonical vector shape from the input vector sizes provided. This
391  // path should be taken to vectorize code with dynamic shapes and when using
392  // vector sizes greater than the iteration space sizes.
393  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394  scalableVecDims.append(inputScalableVecDims.begin(),
395  inputScalableVecDims.end());
396  } else {
397  // Compute the canonical vector shape from the operation shape. If there are
398  // dynamic shapes, the operation won't be vectorized. We assume all the
399  // vector dimensions are fixed.
400  canonicalVecShape = linalgOp.getStaticLoopRanges();
401  scalableVecDims.append(linalgOp.getNumLoops(), false);
402  }
403 
404  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
406 
407  if (ShapedType::isDynamicShape(canonicalVecShape))
408  return failure();
409 
410  // Initialize iteration space static sizes.
411  initIterSpaceStaticSizes(linalgOp);
412 
413  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
414  // all the static and dynamic dimensions of the iteration space, needed to
415  // compute a mask during vectorization.
416  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
417  return failure();
418 
419  return success();
420 }
421 
422 /// Create or retrieve an existing mask value to mask `opToMask` in the
423 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
424 /// using that permutation map. If a new mask is created, it will be cached for
425 /// future users.
426 Value VectorizationState::getOrCreateMaskFor(
427  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
428  std::optional<AffineMap> maybeMaskingMap) {
429 
430  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431  "Ill-formed masking map.");
432 
433  // No mask is needed if the operation is not maskable.
434  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
435  if (!maskableOp)
436  return Value();
437 
438  assert(!maskableOp.isMasked() &&
439  "Masking an operation that is already masked");
440 
441  // If no masking map was provided, use an identity map with the loop dims.
442  assert((!maybeMaskingMap || *maybeMaskingMap) &&
443  "Unexpected null mask permutation map");
444  AffineMap maskingMap =
445  maybeMaskingMap ? *maybeMaskingMap
446  : AffineMap::getMultiDimIdentityMap(
447  linalgOp.getNumLoops(), rewriter.getContext());
448 
449  LDBG() << "Masking map: " << maskingMap;
450 
451  // Return the active mask for the masking map of this operation if it was
452  // already created.
453  auto activeMaskIt = activeMaskCache.find(maskingMap);
454  if (activeMaskIt != activeMaskCache.end()) {
455  Value mask = activeMaskIt->second;
456  LDBG() << "Reusing mask: " << mask;
457  return mask;
458  }
459 
460  // Compute permuted projection of the iteration space to be masked and the
461  // corresponding mask shape. If the resulting iteration space dimensions are
462  // static and identical to the mask shape, masking is not needed for this
463  // operation.
464  // TODO: Improve this check. Only projected permutation indexing maps are
465  // supported.
466  SmallVector<int64_t> permutedStaticSizes =
467  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
469  auto maskShape = maskType.getShape();
470 
471  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
472 
473  if (permutedStaticSizes == maskShape) {
474  LDBG() << "Masking is not needed for masking map: " << maskingMap;
475  activeMaskCache[maskingMap] = Value();
476  return Value();
477  }
478 
479  if (assumeDynamicDimsMatchVecSizes) {
480  // While for _dynamic_ dim sizes we can _assume_ that the corresponding
481  // vector sizes match, we still need to check the _static_ dim sizes. Only
482  // then we can be 100% sure that masking is not required.
483  if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
484  [](auto it) {
485  return std::get<0>(it) == ShapedType::kDynamic
486  ? true
487  : std::get<0>(it) == std::get<1>(it);
488  })) {
489  LDBG()
490  << "Dynamic + static dimensions match vector sizes, masking is not "
491  "required.";
492  activeMaskCache[maskingMap] = Value();
493  return Value();
494  }
495  }
496 
497  // Permute the iteration space value sizes to compute the mask upper bounds.
498  SmallVector<Value> upperBounds =
499  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
500  assert(!maskShape.empty() && !upperBounds.empty() &&
501  "Masked 0-d vectors are not supported yet");
502 
503  // Create the mask based on the dimension values.
504  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505  maskType, upperBounds);
506  LDBG() << "Creating new mask: " << mask;
507  activeMaskCache[maskingMap] = mask;
508  return mask;
509 }
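// For a canonical vector shape such as (8, [4]) with dynamic iteration-space
// sizes %d0 and %d1, the mask created above would look roughly like
// (illustrative SSA names):
//
//   %mask = vector.create_mask %d0, %d1 : vector<8x[4]xi1>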
510 
511 Operation *
512 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
513  LinalgOp linalgOp,
514  std::optional<AffineMap> maybeIndexingMap) {
515  LDBG() << "Trying to mask: " << *opToMask;
516 
517  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518  if (maybeIndexingMap)
519  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
520 
521  // Create or retrieve mask for this operation.
522  Value mask =
523  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
524 
525  if (!mask) {
526  LDBG() << "No mask required";
527  return opToMask;
528  }
529 
530  // Wrap the operation with a new `vector.mask` and update D-U chain.
531  assert(opToMask && "Expected a valid operation to mask");
532  auto maskOp = cast<vector::MaskOp>(
533  mlir::vector::maskOperation(rewriter, opToMask, mask));
534  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
535 
536  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
537  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
538  maskOpTerminator);
539 
540  LDBG() << "Masked operation: " << *maskOp;
541  return maskOp;
542 }
543 
544 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
545 /// projectedPermutation, compress the unused dimensions to serve as a
546 /// permutation_map for a vector transfer operation.
547 /// For example, given a linalg op such as:
548 ///
549 /// ```
550 /// %0 = linalg.generic {
551 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
552 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
553 /// }
554 /// ins(%0 : tensor<2x3x4xf32>)
555 /// outs(%1 : tensor<5x6xf32>)
556 /// ```
557 ///
558 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
559 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
560 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
561 static AffineMap reindexIndexingMap(AffineMap map) {
562  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
563  "expected projected permutation");
564  auto res = compressUnusedDims(map);
565  assert(res.getNumDims() ==
566  (res.getNumResults() - res.getNumOfZeroResults()) &&
567  "expected reindexed map with same number of dims and results");
568  return res;
569 }
570 
571 /// Helper enum to represent conv1d input traversal order.
572 enum class Conv1DOpOrder {
573  W, // Corresponds to non-channeled 1D convolution operation.
574  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
575  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
576 };
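// For example, linalg.conv_1d corresponds to W, linalg.conv_1d_ncw_fcw to Ncw,
// and linalg.conv_1d_nwc_wcf (as well as the nwc depthwise variants) to Nwc.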
577 
578 /// Helper data structure to represent the result of vectorization for a single
579 /// operation. In certain specific cases, like terminators, we do not want to
580 /// propagate.
581 enum VectorizationHookStatus {
582  /// Op failed to vectorize.
583  Failure = 0,
584  /// Op vectorized and custom function took care of replacement logic
585  NoReplace,
586  /// Op vectorized into a new Op whose results will replace original Op's
587  /// results.
588  NewOp
589  // TODO: support values if Op vectorized to Many-Ops whose results we need to
590  // aggregate for replacement.
591 };
592 /// VectorizationHookResult contains the vectorized op returned from a
593 /// CustomVectorizationHook. This is an internal implementation detail of
594 /// linalg vectorization, not to be confused with VectorizationResult.
595 struct VectorizationHookResult {
596  /// Return status from vectorizing the current op.
597  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
598  /// New vectorized operation to replace the current op.
599  /// Replacement behavior is specified by `status`.
600  Operation *newOp;
601 };
602 
603 std::optional<vector::CombiningKind>
604 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
605  using ::mlir::vector::CombiningKind;
606 
607  if (!combinerOp)
608  return std::nullopt;
609  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
610  .Case<arith::AddIOp, arith::AddFOp>(
611  [&](auto op) { return CombiningKind::ADD; })
612  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
613  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
614  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
615  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
616  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
617  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
618  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
619  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
620  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
621  .Case<arith::MulIOp, arith::MulFOp>(
622  [&](auto op) { return CombiningKind::MUL; })
623  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
624  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
625  .Default([&](auto op) { return std::nullopt; });
626 }
627 
628 /// Check whether `outputOperand` is a reduction with a single combiner
629 /// operation. Return the combiner operation of the reduction. Return
630 /// nullptr otherwise. Multiple reduction operations would impose an
631 /// ordering between reduction dimensions, which is currently unsupported in
632 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
633 /// max(min(X))
634 // TODO: use in LinalgOp verification, there is a circular dependency atm.
635 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
636  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
637  unsigned outputPos =
638  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
639  // Only single combiner operations are supported for now.
640  SmallVector<Operation *, 4> combinerOps;
641  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
642  combinerOps.size() != 1)
643  return nullptr;
644 
645  // Return the combiner operation.
646  return combinerOps[0];
647 }
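// For illustration, a region that this matcher accepts (a single arith.addf
// combiner feeding linalg.yield; SSA names are made up):
//
//   ^bb0(%in: f32, %acc: f32):
//     %sum = arith.addf %in, %acc : f32
//     linalg.yield %sum : f32
//
// The returned combiner here is the arith.addf operation.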
648 
649 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
650 /// otherwise.
651 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
652  auto dstVecType = dyn_cast<VectorType>(dstType);
653  // If no shape to broadcast to, just return `value`.
654  if (dstVecType.getRank() == 0)
655  return value;
656  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
657  vector::BroadcastableToResult::Success)
658  return value;
659  Location loc = b.getInsertionPoint()->getLoc();
660  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
661 }
662 
663 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
664 /// assumes that `reductionOp` has two operands and one of them is the reduction
665 /// initial value.
666 // Note: this is a true builder that notifies the OpBuilder listener.
667 // TODO: Consider moving as a static helper on the ReduceOp.
668 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
669  Value valueToReduce, Value acc,
670  ArrayRef<bool> dimsToMask) {
671  auto maybeKind = getCombinerOpKind(reduceOp);
672  assert(maybeKind && "Failed precondition: could not get reduction kind");
673  return vector::MultiDimReductionOp::create(
674  b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
675 }
676 
677 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
678  return llvm::to_vector(
679  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
680 }
681 
682 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
683 /// reduction iterator.
684 static bool hasReductionIterator(LinalgOp &op) {
685  return isa<linalg::ReduceOp>(op) ||
686  (isa<linalg::GenericOp>(op) &&
687  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
688 }
689 
690 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
691 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
692 /// currently being vectorized. If `dest` has null rank, build a memref.store.
693 /// Return the produced value or null if no value is produced.
694 // Note: this is a true builder that notifies the OpBuilder listener.
695 // TODO: Consider moving as a static helper on the ReduceOp.
696 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
697  OpOperand *outputOperand,
698  VectorizationState &state) {
699  Location loc = value.getLoc();
700  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
701  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
702 
703  // Compute the vector type of the value to store. This type should be an
704  // identity or projection of the canonical vector type without any permutation
705  // applied, given that any permutation in a transfer write happens as part of
706  // the write itself.
707  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
708  opOperandMap.getContext(), opOperandMap.getNumInputs(),
709  [&](AffineDimExpr dimExpr) -> bool {
710  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
711  });
712  auto vectorType = state.getCanonicalVecType(
713  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
714 
715  Operation *write;
716  if (vectorType.getRank() > 0) {
717  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
718  SmallVector<Value> indices(
719  linalgOp.getRank(outputOperand),
720  arith::ConstantIndexOp::create(rewriter, loc, 0));
721  value = broadcastIfNeeded(rewriter, value, vectorType);
722  assert(value.getType() == vectorType && "Incorrect type");
723  write = vector::TransferWriteOp::create(
724  rewriter, loc, value, outputOperand->get(), indices, writeMap);
725  } else {
726  // 0-d case is still special: do not invert the reindexing writeMap.
727  if (!isa<VectorType>(value.getType()))
728  value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
729  assert(value.getType() == vectorType && "Incorrect type");
730  write = vector::TransferWriteOp::create(rewriter, loc, value,
731  outputOperand->get(), ValueRange{});
732  }
733 
734  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
735 
736  // If masked, set in-bounds to true. Masking guarantees that the access will
737  // be in-bounds.
738  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
739  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
740  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
741  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
742  }
743 
744  LDBG() << "vectorized op: " << *write;
745  if (!write->getResults().empty())
746  return write->getResult(0);
747  return Value();
748 }
749 
750 // Custom vectorization precondition function type. This is intended to be used
751 // with CustomVectorizationHook. Returns success if the corresponding custom
752 // hook can vectorize the op.
753 using CustomVectorizationPrecondition =
754  std::function<LogicalResult(Operation *, bool)>;
755 
756 // Custom vectorization function type. Produce a vector form of Operation*
757 // assuming all its vectorized operands are already in the IRMapping.
758 // Return nullptr if the Operation cannot be vectorized.
759 using CustomVectorizationHook =
760  std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
761 
762 /// Helper function to vectorize the terminator of a `linalgOp`. New result
763 /// vector values are appended to `newResults`. Return
764 /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
765 /// that it should not try to map produced operations and instead return the
766 /// results using the `newResults` vector making them available to the
767 /// vectorization algorithm for RAUW. This function is meant to be used as a
768 /// CustomVectorizationHook.
769 static VectorizationHookResult
770 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
771  const IRMapping &bvm, VectorizationState &state,
772  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
773  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
774  if (!yieldOp)
775  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
776  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
777  // TODO: Scan for an opportunity for reuse.
778  // TODO: use a map.
779  Value vectorValue = bvm.lookup(output.value());
780  Value newResult =
781  buildVectorWrite(rewriter, vectorValue,
782  linalgOp.getDpsInitOperand(output.index()), state);
783  if (newResult)
784  newResults.push_back(newResult);
785  }
786 
787  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
788 }
789 
790 /// Helper function to vectorize the index operations of a `linalgOp`. Return
791 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
792 /// should map the produced operations. This function is meant to be used as a
793 /// CustomVectorizationHook.
794 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
795  VectorizationState &state,
796  Operation *op,
797  LinalgOp linalgOp) {
798  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
799  if (!indexOp)
800  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
801  auto loc = indexOp.getLoc();
802  // Compute the static loop sizes of the index op.
803  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
804  auto dim = indexOp.getDim();
805  // Compute a one-dimensional index vector for the index op dimension.
806  auto indexVectorType =
807  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
808  state.getScalableVecDims()[dim]);
809  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
810  // Return the one-dimensional index vector if it lives in the trailing
811  // dimension of the iteration space since the vectorization algorithm in this
812  // case can handle the broadcast.
813  if (dim == targetShape.size() - 1)
814  return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
815  // Otherwise permute the targetShape to move the index dimension last,
816  // broadcast the one-dimensional index vector to the permuted shape, and
817  // finally transpose the broadcasted index vector to undo the permutation.
818  auto permPattern =
819  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
820  std::swap(permPattern[dim], permPattern.back());
821  auto permMap =
822  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
823 
824  auto broadCastOp = vector::BroadcastOp::create(
825  rewriter, loc,
826  state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
827  SmallVector<int64_t> transposition =
828  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
829  std::swap(transposition.back(), transposition[dim]);
830  auto transposeOp =
831  vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
832  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
833 }
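// Sketch of the effect for a 2-D iteration space vectorized as
// vector<4x8xindex> (illustrative only): `linalg.index 1` (the trailing dim)
// maps to a single `vector.step : vector<8xindex>`, while `linalg.index 0` is
// materialized as a step vector that is broadcast along the permuted shape and
// transposed back, roughly:
//
//   %s = vector.step : vector<4xindex>
//   %b = vector.broadcast %s : vector<4xindex> to vector<8x4xindex>
//   %t = vector.transpose %b, [1, 0] : vector<8x4xindex> to vector<4x8xindex>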
834 
835 /// Helper function to check if the tensor.extract can be vectorized by the
836 /// custom hook vectorizeTensorExtract.
837 static LogicalResult
838 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
839  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
840  if (!extractOp)
841  return failure();
842 
843  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
844  return failure();
845 
846  // Check the index type, but only for non 0-d tensors (for which we do need
847  // access indices).
848  if (not extractOp.getIndices().empty()) {
849  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
850  return failure();
851  }
852 
853  if (!llvm::all_of(extractOp->getResultTypes(),
854  VectorType::isValidElementType)) {
855  return failure();
856  }
857 
858  return success();
859 }
860 
861 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
862 /// generated from `tensor.extract`. The offset is calculated as follows
863 /// (example using scalar values):
864 ///
865 /// offset = extractOp.indices[0]
866 /// for (i = 1; i < numIndices; i++)
867 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
868 ///
869 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
870 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
871 static Value calculateGatherOffset(RewriterBase &rewriter,
872  VectorizationState &state,
873  tensor::ExtractOp extractOp,
874  const IRMapping &bvm) {
875  // The vector of indices for GatherOp should be shaped as the output vector.
876  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
877  auto loc = extractOp.getLoc();
878 
879  Value offset = broadcastIfNeeded(
880  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
881 
882  const size_t numIndices = extractOp.getIndices().size();
883  for (size_t i = 1; i < numIndices; i++) {
884  Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
885 
886  auto dimSize = broadcastIfNeeded(
887  rewriter,
888  tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
889  indexVecType);
890 
891  offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
892 
893  auto extractOpIndex = broadcastIfNeeded(
894  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
895 
896  offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
897  }
898 
899  return offset;
900 }
901 
902 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
903 
904 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
905 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
906 /// represents a contiguous load operation.
907 ///
908 /// Note that when calling this hook, it is assumed that the output vector is
909 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
910 /// labelled as a gather load before entering this method.
911 ///
912 /// Following on from the above, it is assumed that:
913 /// * for statically shaped loops, when no masks are used, only one dim is !=
914 /// 1 (that's what the shape of the output vector is based on).
915 /// * for dynamically shaped loops, there might be more non-unit dims
916 /// as the output vector type is user-specified.
917 ///
918 /// TODO: Statically shaped loops + vector masking
919 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
920  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
921  assert(
922  (linalgOp.hasDynamicShape() ||
923  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
924  "For statically shaped Linalg Ops, only one "
925  "non-unit loop dim is expected");
926  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
927 
928  size_t idx = loopRanges.size() - 1;
929  for (; idx != 0; idx--)
930  if (loopRanges[idx] != 1)
931  break;
932 
933  return idx;
934 }
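// For example (illustrative), for static loop ranges (1, 8, 1) the trailing
// non-unit loop dim index returned above is 1.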
935 
936 /// Checks whether `val` can be used for calculating a loop invariant index.
937 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
938  VectorType resType) {
939 
940  assert(((llvm::count_if(resType.getShape(),
941  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942  "n-D vectors are not yet supported");
943 
944  // Blocks outside _this_ linalg.generic are effectively loop invariant.
945  // However, analysing block arguments for _this_ linalg.generic Op is a bit
946  // tricky. Just bail out in the latter case.
947  // TODO: We could try analysing the corresponding affine map here.
948  auto *block = linalgOp.getBlock();
949  if (isa<BlockArgument>(val))
950  return !llvm::is_contained(block->getArguments(), val);
951 
952  Operation *defOp = val.getDefiningOp();
953  assert(defOp && "This is neither a block argument nor an operation result");
954 
955  // IndexOp is loop invariant as long as its result remains constant across
956  // iterations. Note that for dynamic shapes, the corresponding dim will also
957  // be conservatively treated as != 1.
958  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
959  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
960  }
961 
962  auto *ancestor = block->findAncestorOpInBlock(*defOp);
963 
964  // Values defined outside `linalgOp` are loop invariant.
965  if (!ancestor)
966  return true;
967 
968  // Values defined inside `linalgOp`, which are constant, are loop invariant.
969  if (isa<arith::ConstantOp>(ancestor))
970  return true;
971 
972  bool result = true;
973  for (auto op : ancestor->getOperands())
974  result &= isLoopInvariantIdx(linalgOp, op, resType);
975 
976  return result;
977 }
978 
979 /// Check whether `val` could be used for calculating the trailing index for a
980 /// contiguous load operation.
981 ///
982 /// There are currently 3 types of values that are allowed here:
983 /// 1. loop-invariant values,
984 /// 2. values that increment by 1 with every loop iteration,
985 /// 3. results of basic arithmetic operations (linear and continuous)
986 /// involving 1., 2. and 3.
987 /// This method returns True if indeed only such values are used in calculating
988 /// `val`.
989 ///
990 /// Additionally, the trailing index for a contiguous load operation should
991 /// increment by 1 with every loop iteration, i.e. be based on:
992 /// * `linalg.index <dim>` ,
993 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
994 /// `linalg.index <dim>` increments by 1 with every loop iteration).
995 /// `foundIndexOp` is updated to `true` when such Op is found.
996 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
997  bool &foundIndexOp, VectorType resType) {
998 
999  assert(((llvm::count_if(resType.getShape(),
1000  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1001  "n-D vectors are not yet supported");
1002 
1003  // Blocks outside _this_ linalg.generic are effectively loop invariant.
1004  // However, analysing block arguments for _this_ linalg.generic Op is a bit
1005  // tricky. Just bail out in the latter case.
1006  // TODO: We could try analysing the corresponding affine map here.
1007  auto *block = linalgOp.getBlock();
1008  if (isa<BlockArgument>(val))
1009  return !llvm::is_contained(block->getArguments(), val);
1010 
1011  Operation *defOp = val.getDefiningOp();
1012  assert(defOp && "This is neither a block argument nor an operation result");
1013 
1014  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1015  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1016 
1017  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1018  return true;
1019  }
1020 
1021  auto *ancestor = block->findAncestorOpInBlock(*defOp);
1022 
1023  if (!ancestor)
1024  return false;
1025 
1026  // Conservatively reject Ops that could lead to indices with stride other
1027  // than 1.
1028  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1029  return false;
1030 
1031  bool result = false;
1032  for (auto op : ancestor->getOperands())
1033  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1034 
1035  return result;
1036 }
1037 
1038 /// Infer the memory access pattern for the input ExtractOp
1039 ///
1040 /// Based on the ExtractOp result shape and the access indices, decides whether
1041 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1042 /// or a gather load. When analysing the ExtractOp indices (to identify
1043 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1044 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1045 /// Op).
1046 ///
1047 /// Note that it is always safe to use gather load operations for contiguous
1048 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1049 /// that `extractOp` is a gather load.
1050 static VectorMemoryAccessKind
1051 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1052  LinalgOp &linalgOp, VectorType resType) {
1053 
1054  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1055 
1056  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1057  if (inputShape.getShape().empty())
1058  return VectorMemoryAccessKind::ScalarBroadcast;
1059 
1060  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1061  // otherwise.
1062  bool isOutput1DVector =
1063  (llvm::count_if(resType.getShape(),
1064  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1065  // 1. Assume that it's a gather load when reading non-1D vector.
1066  if (!isOutput1DVector)
1067  return VectorMemoryAccessKind::Gather;
1068 
1069  bool leadingIdxsLoopInvariant = true;
1070 
1071  // 2. Analyze the leading indices of `extractOp`.
1072  // Look at the way each index is calculated and decide whether it is suitable
1073  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1074  // gather load.
1075  auto indices = extractOp.getIndices();
1076  auto leadIndices = indices.drop_back(1);
1077 
1078  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1079  if (inputShape.getShape()[i] == 1)
1080  continue;
1081 
1082  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1083  }
1084 
1085  if (!leadingIdxsLoopInvariant) {
1086  LDBG() << "Found gather load: " << extractOp;
1087  return VectorMemoryAccessKind::Gather;
1088  }
1089 
1090  // 3. Analyze the trailing index for `extractOp`.
1091  // At this point we know that the leading indices are loop invariant. This
1092  // means this is potentially a scalar or a contiguous load. We can decide
1093  // based on the trailing idx.
1094  auto extractOpTrailingIdx = indices.back();
1095 
1096  // 3a. Scalar broadcast load
1097  // If the trailing index is loop invariant then this is a scalar load.
1098  if (leadingIdxsLoopInvariant &&
1099  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1100  LDBG() << "Found scalar broadcast load: " << extractOp;
1101 
1102  return VectorMemoryAccessKind::ScalarBroadcast;
1103  }
1104 
1105  // 3b. Contiguous loads
1106  // The trailing `extractOp` index should increment with every loop iteration.
1107  // This effectively means that it must be based on the trailing loop index.
1108  // This is what the following bool captures.
1109  bool foundIndexOp = false;
1110  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1111  foundIndexOp, resType);
1112  // TODO: Support generating contiguous loads for column vectors - that will
1113  // require adding a permutation map to tranfer_read Ops.
1114  bool isRowVector = resType.getShape().back() != 1;
1115  isContiguousLoad &= (foundIndexOp && isRowVector);
1116 
1117  if (isContiguousLoad) {
1118  LDBG() << "Found contiguous load: " << extractOp;
1119  return VectorMemoryAccessKind::Contiguous;
1120  }
1121 
1122  // 4. Fallback case - gather load.
1123  LDBG() << "Found gather load: " << extractOp;
1124  return VectorMemoryAccessKind::Gather;
1125 }
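// For illustration (SSA names made up), inside a linalg.generic whose output
// vector is effectively 1-D:
//
//   %i = linalg.index 1 : index   // trailing non-unit loop dim
//   %v = tensor.extract %src[%c0, %i] : tensor<3x32xf32>
//
// is classified as a contiguous load, whereas trailing indices that are neither
// loop-invariant nor unit-stride fall back to the gather case.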
1126 
1127 /// Helper function to vectorize the tensor.extract operations. Returns
1128 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1129 /// should map the produced operations. This function is meant to be used as a
1130 /// CustomVectorizationHook.
1131 static VectorizationHookResult
1132 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1133  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1134  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1135  if (!extractOp)
1136  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1137  auto loc = extractOp.getLoc();
1138 
1139  // Compute the static loop sizes of the extract op.
1140  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1141  auto maskConstantOp = arith::ConstantOp::create(
1142  rewriter, loc,
1143  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1144  /*value=*/true));
1145  auto passThruConstantOp = arith::ConstantOp::create(
1146  rewriter, loc, rewriter.getZeroAttr(resultType));
1147 
1148  // Base indices are currently set to 0. We will need to re-visit if more
1149  // generic scenarios are to be supported.
1150  SmallVector<Value> baseIndices(
1151  extractOp.getIndices().size(),
1152  arith::ConstantIndexOp::create(rewriter, loc, 0));
1153 
1154  VectorMemoryAccessKind memAccessKind =
1155  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1156 
1157  // 1. Handle gather access
1158  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1159  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1160 
1161  // Generate the gather load
1162  Operation *gatherOp = vector::GatherOp::create(
1163  rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1164  maskConstantOp, passThruConstantOp);
1165  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1166 
1167  LDBG() << "Vectorised as gather load: " << extractOp;
1168  return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1169  }
1170 
1171  // 2. Handle:
1172  // a. scalar loads + broadcast,
1173  // b. contiguous loads.
1174  // Both cases use vector.transfer_read.
1175 
1176  // Collect indices for `vector.transfer_read`. At this point, the indices will
1177  // either be scalars or would have been broadcast to vectors matching the
1178  // result type. For indices that are vectors, there are two options:
1179  // * for non-trailing indices, all elements are identical (contiguous
1180  // loads are identified by looking for non-trailing indices that are
1181  // invariant with respect to the corresponding linalg.generic), or
1182  // * for trailing indices, the index vector will contain values with stride
1183  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1184  // needed.
1185  // This means that
1186  // * for scalar indices - just re-use it,
1187  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1188  // (0th) element and use that.
1189  SmallVector<Value> transferReadIdxs;
1190  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1191  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1192  if (idx.getType().isIndex()) {
1193  transferReadIdxs.push_back(idx);
1194  continue;
1195  }
1196 
1197  auto indexAs1dVector = vector::ShapeCastOp::create(
1198  rewriter, loc,
1199  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1200  resultType.getScalableDims().back()),
1201  idx);
1202  transferReadIdxs.push_back(
1203  vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1204  }
1205 
1206  // `tensor.extract_element` is always in-bounds, hence the following holds.
1207  auto dstRank = resultType.getRank();
1208  auto srcRank = extractOp.getTensor().getType().getRank();
1209  SmallVector<bool> inBounds(dstRank, true);
1210 
1211  // 2a. Handle scalar broadcast access.
1212  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1213  MLIRContext *ctx = rewriter.getContext();
1214  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1215  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1216 
1217  auto transferReadOp = vector::TransferReadOp::create(
1218  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1219  /*padding=*/std::nullopt, permutationMap, inBounds);
1220 
1221  // Mask this broadcasting xfer_read here rather than relying on the generic
1222  // path (the generic path assumes identity masking map, which wouldn't be
1223  // valid here).
1224  SmallVector<int64_t> readMaskShape = {1};
1225  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1226  auto allTrue = vector::ConstantMaskOp::create(
1227  rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1228  auto *maskedReadOp =
1229  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1230 
1231  LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1232  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1233  maskedReadOp};
1234  }
1235 
1236  // 2b. Handle contiguous access.
1237  auto permutationMap = AffineMap::getMinorIdentityMap(
1238  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1239 
1240  int32_t rankDiff = dstRank - srcRank;
1241  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1242  // dims so that the ranks match. This is done by extending the map with 0s.
1243  // For example, for dstRank = 3, srcRank = 2, the following map created
1244  // above:
1245  // (d0, d1) --> (d0, d1)
1246  // is extended as:
1247  // (d0, d1) --> (0, d0, d1)
1248  while (rankDiff > 0) {
1249  permutationMap = permutationMap.insertResult(
1250  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1251  rankDiff--;
1252  }
1253 
1254  auto transferReadOp = vector::TransferReadOp::create(
1255  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1256  /*padding=*/std::nullopt, permutationMap, inBounds);
1257 
1258  LDBG() << "Vectorised as contiguous load: " << extractOp;
1259  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1260  transferReadOp};
1261 }
1262 
1263 /// Emit reduction operations if the shape of the value to reduce is different
1264 /// from the result shape.
1265 // Note: this is a true builder that notifies the OpBuilder listener.
1266 // TODO: Consider moving as a static helper on the ReduceOp.
1267 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1268  Value reduceValue, Value initialValue,
1269  const IRMapping &bvm) {
1270  Value reduceVec = bvm.lookup(reduceValue);
1271  Value outputVec = bvm.lookup(initialValue);
1272  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1273  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1274  // Reduce only if needed as the value may already have been reduced for
1275  // contraction vectorization.
1276  if (!reduceType ||
1277  (outputType && reduceType.getShape() == outputType.getShape()))
1278  return nullptr;
1279  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1280  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1281 }
1282 
1283 /// Generic vectorization for a single operation `op`, given already vectorized
1284 /// operands carried by `bvm`. Vectorization occurs as follows:
1285 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1286 /// result on success.
1287 /// 2. Clone any constant in the current scope without vectorization: each
1288 /// consumer of the constant will later determine the shape to which the
1289 /// constant needs to be broadcast.
1290 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1291 /// of the `customVectorizationHooks` to cover such cases.
1292 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1293 /// operand of maximal rank. Other operands have smaller rank and are
1294 /// broadcast accordingly. It is assumed this broadcast is always legal,
1295 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1296 ///
1297 /// This function assumes all operands of `op` have been vectorized and are in
1298 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1299 /// a topologically-sorted list of ops.
1300 /// This function does not update `bvm` but returns a VectorizationHookStatus
1301 /// that instructs the caller what `bvm` update needs to occur.
1302 static VectorizationHookResult
1303 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1304  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1305  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1306  LDBG() << "vectorize op " << *op;
1307 
1308  // 1. Try to apply any CustomVectorizationHook.
1309  if (!customVectorizationHooks.empty()) {
1310  for (auto &customFunc : customVectorizationHooks) {
1311  VectorizationHookResult result = customFunc(op, bvm);
1312  if (result.status == VectorizationHookStatus::Failure)
1313  continue;
1314  return result;
1315  }
1316  }
1317 
1318  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1319  // Clone so that the constant is not confined to the linalgOp block.
1320  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1321  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1322  rewriter.clone(*op)};
1323 
1324  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1325  if (!OpTrait::hasElementwiseMappableTraits(op))
1326  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1327 
1328  // 4. Check if the operation is a reduction.
1329  SmallVector<std::pair<Value, Value>> reductionOperands;
1330  for (Value operand : op->getOperands()) {
1331  auto blockArg = dyn_cast<BlockArgument>(operand);
1332  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1333  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1334  continue;
1335  SmallVector<Operation *> reductionOps;
1336  Value reduceValue = matchReduction(
1337  linalgOp.getRegionOutputArgs(),
1338  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1339  if (!reduceValue)
1340  continue;
1341  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1342  }
1343  if (!reductionOperands.empty()) {
1344  assert(reductionOperands.size() == 1);
1345  Operation *reduceOp =
1346  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1347  reductionOperands[0].second, bvm);
1348  if (reduceOp)
1349  return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1350  }
1351 
1352  // 5. Generic vectorization path for ElementwiseMappable ops.
1353  // a. Get the first max ranked shape.
1354  VectorType firstMaxRankedType;
1355  for (Value operand : op->getOperands()) {
1356  auto vecOperand = bvm.lookup(operand);
1357  assert(vecOperand && "Vector operand couldn't be found");
1358 
1359  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1360  if (vecType && (!firstMaxRankedType ||
1361  firstMaxRankedType.getRank() < vecType.getRank()))
1362  firstMaxRankedType = vecType;
1363  }
1364  // b. Broadcast each op if needed.
1365  SmallVector<Value> vecOperands;
1366  for (Value scalarOperand : op->getOperands()) {
1367  Value vecOperand = bvm.lookup(scalarOperand);
1368  assert(vecOperand && "Vector operand couldn't be found");
1369 
1370  if (firstMaxRankedType) {
1371  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1372  getElementTypeOrSelf(vecOperand.getType()),
1373  firstMaxRankedType.getScalableDims());
1374  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1375  } else {
1376  vecOperands.push_back(vecOperand);
1377  }
1378  }
1379  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1380  SmallVector<Type> resultTypes;
1381  for (Type resultType : op->getResultTypes()) {
1382  resultTypes.push_back(
1383  firstMaxRankedType
1384  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1385  firstMaxRankedType.getScalableDims())
1386  : resultType);
1387  }
1388  // d. Build and return the new op.
1389  return VectorizationHookResult{
1390  VectorizationHookStatus::NewOp,
1391  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1392  resultTypes, op->getAttrs())};
1393 }
1394 
1395 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1396 /// vector form. Generic vectorization proceeds as follows:
1397 /// 1. Verify the `linalgOp` has one non-empty region.
1398 /// 2. Values defined above the region are mapped to themselves and will be
1399 /// broadcasted on a per-need basis by their consumers.
1400 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1401 /// load).
1402 /// TODO: Reuse opportunities for RAR dependencies.
1403 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1404 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1405 /// iteration indices.
1406 /// 5. Iteratively call vectorizeOneOp on the region operations.
1407 ///
1408 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1409 /// performed to the maximal common vector size implied by the `linalgOp`
1410 /// iteration space. This eager broadcasting is introduced in the
1411 /// permutation_map of the vector.transfer_read operations. The eager
1412 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1413 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1414 /// the absence of good canonicalizations, the amount of work increases.
1415 /// This is not deemed a problem as we expect canonicalizations and foldings to
1416 /// aggressively clean up the useless work.
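/// Illustrative sketch (hypothetical IR, heavily simplified): an elementwise
///   linalg.generic ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///                  outs(%c : tensor<8xf32>)  // body: addf + yield
/// is roughly rewritten to
///   %va = vector.transfer_read %a[%c0], %pad ... : tensor<8xf32>, vector<8xf32>
///   %vb = vector.transfer_read %b[%c0], %pad ... : tensor<8xf32>, vector<8xf32>
///   %r  = arith.addf %va, %vb : vector<8xf32>
///   %w  = vector.transfer_write %r, %c[%c0] ... : vector<8xf32>, tensor<8xf32>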
1417 static LogicalResult
1418 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1419  LinalgOp linalgOp,
1420  SmallVectorImpl<Value> &newResults) {
1421  LDBG() << "Vectorizing operation as linalg generic/n";
1422  Block *block = linalgOp.getBlock();
1423 
1424  // 2. Values defined above the region can only be broadcast for now. Make them
1425  // map to themselves.
1426  IRMapping bvm;
1427  SetVector<Value> valuesSet;
1428  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1429  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1430 
1431  if (linalgOp.getNumDpsInits() == 0)
1432  return failure();
1433 
1434  // 3. Turn all BBArgs into vector.transfer_read / load.
1435  Location loc = linalgOp.getLoc();
1436  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1437  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1438  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1439  if (linalgOp.isScalar(opOperand)) {
1440  bvm.map(bbarg, opOperand->get());
1441  continue;
1442  }
1443 
1444  // 3.a. Convert the indexing map for this input/output to a transfer read
1445  // permutation map and masking map.
1446  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1447 
1448  AffineMap readMap;
1449  VectorType readType;
1450  Type elemType = getElementTypeOrSelf(opOperand->get());
1451  if (linalgOp.isDpsInput(opOperand)) {
1452  // 3.a.i. For input reads we use the canonical vector shape.
1453  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1454  readType = state.getCanonicalVecType(elemType);
1455  } else {
1456  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1457  // reductions), the vector shape is computed by mapping the canonical
1458  // vector shape to the output domain and back to the canonical domain.
1459  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1460  readType =
1461  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1462  }
1463 
1464  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1465 
1466  Operation *read = vector::TransferReadOp::create(
1467  rewriter, loc, readType, opOperand->get(), indices,
1468  /*padding=*/std::nullopt, readMap);
1469  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1470  Value readValue = read->getResult(0);
1471 
1472  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1473  // will be in-bounds.
1474  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1475  SmallVector<bool> inBounds(readType.getRank(), true);
1476  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1477  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1478  }
1479 
1480  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1481  // TODO: remove this.
1482  if (readType.getRank() == 0)
1483  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1484  ArrayRef<int64_t>());
1485 
1486  LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1487  << "): " << readValue;
1488  bvm.map(bbarg, readValue);
1489  bvm.map(opOperand->get(), readValue);
1490  }
1491 
1492  SmallVector<CustomVectorizationHook> hooks;
1493  // 4a. Register CustomVectorizationHook for yieldOp.
1494  CustomVectorizationHook vectorizeYield =
1495  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1496  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1497  };
1498  hooks.push_back(vectorizeYield);
1499 
1500  // 4b. Register CustomVectorizationHook for indexOp.
1501  CustomVectorizationHook vectorizeIndex =
1502  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1503  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1504  };
1505  hooks.push_back(vectorizeIndex);
1506 
1507  // 4c. Register CustomVectorizationHook for extractOp.
1508  CustomVectorizationHook vectorizeExtract =
1509  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1510  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1511  };
1512  hooks.push_back(vectorizeExtract);
1513 
1514  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1515  for (Operation &op : block->getOperations()) {
1516  VectorizationHookResult result =
1517  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1518  if (result.status == VectorizationHookStatus::Failure) {
1519  LDBG() << "failed to vectorize: " << op;
1520  return failure();
1521  }
1522  if (result.status == VectorizationHookStatus::NewOp) {
1523  Operation *maybeMaskedOp =
1524  state.maskOperation(rewriter, result.newOp, linalgOp);
1525  LDBG() << "New vector op: " << *maybeMaskedOp;
1526  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1527  }
1528  }
1529 
1530  return success();
1531 }
1532 
1533 /// Given a linalg::PackOp, return the `dest` shape before any packing
1534 /// permutations.
1535 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1536  ArrayRef<int64_t> destShape) {
1537  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1538 }
1539 
1540 /// Determines whether a mask for xfer_write is trivially "all true"
1541 ///
1542 /// Given all the inputs required to generate a mask (mask sizes and shapes),
1543 /// and an xfer_write operation (write indices and the destination tensor
1544 /// shape), determines whether the corresponding mask would be trivially
1545 /// foldable (i.e., trivially "all true").
1546 ///
1547 /// Use this method to avoid generating spurious masks and relying on
1548 /// vectorization post-processing to remove them.
1549 ///
1550 /// Pre-conditions for a mask to be trivially foldable:
1551 /// * All involved shapes (mask + destination tensor) are static.
1552 /// * All write indices are constant.
1553 /// * All mask sizes are constant (including `arith.constant`).
1554 ///
1555 /// If the pre-conditions are met, the method checks for each destination
1556 /// dimension `d`:
1557 /// (1) destDimSize[rankDiff + d] <= maskShape[d]
1558 /// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1559 ///
1560 /// rankDiff = rank(dest) - rank(mask).
1561 ///
1562 /// This method takes a conservative view: it may return false even if the mask
1563 /// is technically foldable.
1564 ///
1565 /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1566 /// of the dest tensor):
1567 /// %c0 = arith.constant 0 : index
1568 /// %mask = vector.create_mask 5, 1
1569 /// vector.mask %mask {
1570 /// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1571 /// {in_bounds = [true, true]}
1572 /// : vector<5x1xi32>, tensor<5x1xi32>
1573 /// }
1574 ///
1575 /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1576 /// mask is required to avoid out-of-bounds write):
1577 /// %c0 = arith.constant 0 : index
1578 /// %mask = vector.create_mask 5, 1
1579 /// vector.mask %mask {
1580 /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1581 /// {in_bounds = [true, true]}
1582 /// : vector<8x1xi32>, tensor<5x1xi32>
1583 /// }
1584 ///
1585 /// TODO: Re-use in createReadOrMaskedRead
1586 static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1587  SmallVector<Value> &writeIdxs,
1588  ArrayRef<int64_t> destShape,
1589  ArrayRef<int64_t> maskShape) {
1590  // Masking is unavoidable in the case of dynamic tensors.
1591  if (ShapedType::isDynamicShape(destShape))
1592  return false;
1593 
1594  // Collect all constant mask sizes.
1595  SmallVector<int64_t, 4> cstMaskSizes;
1596  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1597  if (auto intSize = getConstantIntValue(dimSize)) {
1598  cstMaskSizes.push_back(*intSize);
1599  }
1600  }
1601 
1602  // If any of the mask sizes is non-constant, bail out.
1603  if (cstMaskSizes.size() != maskShape.size())
1604  return false;
1605 
1606  // Collect all constant write indices.
1607  SmallVector<int64_t, 4> cstWriteIdxs;
1608  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1609  APSInt intVal;
1610  if (matchPattern(idx, m_ConstantInt(&intVal))) {
1611  cstWriteIdxs.push_back(intVal.getSExtValue());
1612  }
1613  }
1614 
1615  // If any of the write indices is non-constant, bail out.
1616  if (cstWriteIdxs.size() != destShape.size())
1617  return false;
1618 
1619  // Go over all destination dims and check (1) and (2). Take into account that:
1620  // * The number of mask sizes will match the rank of the vector to store.
1621  // This could be lower than the rank of the destination tensor.
1622  // * Mask sizes could be larger than the corresponding mask shape (hence
1623  // `clamp`).
1624  // TODO: The 2nd item should be rejected by the verifier.
1625  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1626  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1627  if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1628  /*(2)*/ destShape[rankDiff + i] <
1629  (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1630  cstWriteIdxs[i]))
1631  return false;
1632  }
1633 
1634  return true;
1635 }
1636 
1637 /// Creates an optionally masked TransferWriteOp
1638 ///
1639 /// Generates the following operation:
1640 /// %res = vector.transfer_write %vecToStore into %dest
1641 ///
1642 /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1643 ///
1644 /// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1645 /// %res = vector.mask %mask {
1646 /// vector.transfer_write %vecToStore into %dest
1647 /// }
1648 ///
1649 /// The mask shape is identical to `vecToStore` (with the element type ==
1650 /// i1), and the mask values are based on the shape of the `dest` tensor.
1651 ///
1652 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1653 /// is used instead of masking:
1654 ///
1655 /// %write = vector.transfer_write %vecToStore into %dest
1656 /// in_bounds_flags = (...)
1657 /// %res = vector.transfer_write %input into %dest
1658 /// {in_bounds = in_bounds_flags}
1659 ///
1660 /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1661 /// are set to 0.
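///
/// EXAMPLE (illustrative sketch, assumed shapes and value names): storing
/// vector<4x8xf32> into a dynamically shaped tensor<?x8xf32> would produce
/// roughly:
///   %d0 = tensor.dim %dest, %c0 : tensor<?x8xf32>
///   %mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0, %c0]
///         {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x8xf32>
///   } : vector<4x8xi1> -> tensor<?x8xf32>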
1662 static Operation *
1663 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1664  Value dest, SmallVector<Value> writeIndices = {},
1665  bool useInBoundsInsteadOfMasking = false) {
1666 
1667  ShapedType destType = cast<ShapedType>(dest.getType());
1668  int64_t destRank = destType.getRank();
1669  auto destShape = destType.getShape();
1670 
1671  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1672  int64_t vecToStoreRank = vecToStoreType.getRank();
1673  auto vecToStoreShape = vecToStoreType.getShape();
1674 
1675  // Compute the in_bounds attribute
1676  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1677  if (useInBoundsInsteadOfMasking) {
1678  // Update the inBounds attribute.
1679  // FIXME: This computation is too weak - it ignores the write indices.
1680  for (unsigned i = 0; i < vecToStoreRank; i++)
1681  inBoundsVal[i] =
1682  (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1683  ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1684  }
1685 
1686  // If missing, initialize the write indices to 0.
1687  assert((writeIndices.empty() ||
1688  writeIndices.size() == static_cast<size_t>(destRank)) &&
1689  "Invalid number of write indices!");
1690  if (writeIndices.empty()) {
1691  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1692  writeIndices.assign(destRank, zero);
1693  }
1694 
1695  // Generate the xfer_write Op
1696  Operation *write = vector::TransferWriteOp::create(builder, loc,
1697  /*vector=*/vecToStore,
1698  /*source=*/dest,
1699  /*indices=*/writeIndices,
1700  /*inBounds=*/inBoundsVal);
1701 
1702  // If masking is disabled, exit.
1703  if (useInBoundsInsteadOfMasking)
1704  return write;
1705 
1706  // Check if masking is needed. If not, exit.
1707  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1708  return write;
1709 
1710  // Compute the mask and mask the write Op.
1711  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1712  vecToStoreType.getScalableDims());
1713 
1714  SmallVector<OpFoldResult> destSizes =
1715  isa<MemRefType>(dest.getType())
1716  ? memref::getMixedSizes(builder, loc, dest)
1717  : tensor::getMixedSizes(builder, loc, dest);
1718  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1719  destSizes.end());
1720 
1721  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1722  vecToStoreShape))
1723  return write;
1724 
1725  Value maskForWrite =
1726  builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1727  return mlir::vector::maskOperation(builder, write, maskForWrite);
1728 }
1729 
1730 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1731 /// padding value and (3) input vector sizes into:
1732 ///
1733 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1734 ///
1735 /// As in the following example:
1736 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1737 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1738 ///
1739 /// This pack would be vectorized to:
1740 ///
1741 /// %load = vector.mask %mask {
1742 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1743 /// {in_bounds = [true, true, true]} :
1744 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1745 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1746 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1747 /// to vector<32x4x2x1x16xf32>
1748 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1749 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1750 /// %write = vector.transfer_write %transpose,
1751 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1752 /// {in_bounds = [true, true, true, true, true]}
1753 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1754 ///
1755 /// If the (3) input vector sizes are not provided, the vector sizes are
1756 /// determined by the result tensor shape and the `in_bounds`
1757 /// attribute is used instead of masking to mark out-of-bounds accesses.
1758 ///
1759 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1760 /// outer dimensions of the output tensor. The remaining dimensions are
1761 /// computed based on, e.g., the static inner tiles.
1762 /// Supporting dynamic inner tiles will require the user to specify the
1763 /// missing vector sizes. This is left as a TODO.
1764 static LogicalResult
1765 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1766  ArrayRef<int64_t> inputVectorSizes,
1767  SmallVectorImpl<Value> &newResults) {
1768  // TODO: Introduce a parent class that will handle the insertion point update.
1769  OpBuilder::InsertionGuard g(rewriter);
1770  rewriter.setInsertionPoint(packOp);
1771 
1772  Location loc = packOp.getLoc();
1773  auto padValue = packOp.getPaddingValue();
1774  if (!padValue) {
1775  padValue = arith::ConstantOp::create(
1776  rewriter, loc,
1777  rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1778  }
1779  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1780  LogicalResult status =
1781  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1782  .reifyResultShapes(rewriter, reifiedReturnShapes);
1783  (void)status; // prevent unused variable warning on non-assert builds.
1784  assert(succeeded(status) && "failed to reify result shapes");
1785 
1786  // If the input vector sizes are not provided, then the vector sizes are
1787  // determined by the result tensor shape. In case the vector sizes aren't
1788  // provided, we update the inBounds attribute instead of masking.
1789  bool useInBoundsInsteadOfMasking = false;
1790  if (inputVectorSizes.empty()) {
1791  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1792  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1793  useInBoundsInsteadOfMasking = true;
1794  }
1795 
1796  // Create masked TransferReadOp.
1797  SmallVector<int64_t> inputShape(inputVectorSizes);
1798  auto innerTiles = packOp.getStaticInnerTiles();
1799  auto innerDimsPos = packOp.getInnerDimsPos();
1800  auto outerDimsPerm = packOp.getOuterDimsPerm();
1801  if (!outerDimsPerm.empty())
1802  applyPermutationToVector(inputShape,
1803  invertPermutationVector(outerDimsPerm));
1804  for (auto [idx, size] : enumerate(innerTiles))
1805  inputShape[innerDimsPos[idx]] *= size;
1806  auto maskedRead = vector::createReadOrMaskedRead(
1807  rewriter, loc, packOp.getSource(), inputShape, padValue,
1808  useInBoundsInsteadOfMasking,
1809  /*inputScalableVecSizes=*/{});
1810 
1811  // Create ShapeCastOp.
1812  SmallVector<int64_t> destShape(inputVectorSizes);
1813  destShape.append(innerTiles.begin(), innerTiles.end());
1814  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1815  packOp.getDestType().getElementType());
1816  auto shapeCastOp =
1817  vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
1818 
1819  // Create TransposeOp.
1820  auto destPermutation =
1821  invertPermutationVector(getPackInverseDestPerm(packOp));
1822  auto transposeOp = vector::TransposeOp::create(
1823  rewriter, loc, shapeCastOp.getResult(), destPermutation);
1824 
1825  // Create TransferWriteOp.
1826  Value dest = tensor::EmptyOp::create(
1827  rewriter, loc, reifiedReturnShapes[0],
1828  transposeOp.getResult().getType().getElementType());
1829  Operation *write =
1830  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
1831  newResults.push_back(write->getResult(0));
1832  return success();
1833 }
1834 
1835 /// Given the re-associations, "collapses" the input Vector type
1836 ///
1837 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1838 /// differences:
1839 /// * We can safely assume that there are no dynamic sizes.
1840 /// * Scalable flags are updated alongside regular dims.
1841 ///
1842 /// When collapsing scalable flags, conservatively avoids cases with two
1843 /// scalable dims. We could re-visit this in the future.
1844 ///
1845 /// EXAMPLE:
1846 /// type = vector<4x16x[8]x16xf32>
1847 /// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1848 /// (d0, d1, d2, d3) -> (d2, d3)]
1849 /// Result:
1850 /// vector<64x[128]xf32>
1851 static VectorType getCollapsedVecType(VectorType type,
1852  ArrayRef<AffineMap> reassociation) {
1853  assert(type.getNumScalableDims() < 2 &&
1854  "Collapsing more than 1 scalable dim is not supported ATM");
1855 
1856  // Use the fact that reassociation is valid to simplify the logic: only use
1857  // each map's rank.
1858  assert(isReassociationValid(reassociation) && "invalid reassociation");
1859 
1860  auto shape = type.getShape();
1861  auto scalableFlags = type.getScalableDims();
1862  SmallVector<int64_t> newShape;
1863  SmallVector<bool> newScalableFlags;
1864 
1865  unsigned currentDim = 0;
1866  for (AffineMap m : reassociation) {
1867  unsigned dim = m.getNumResults();
1868  int64_t size = 1;
1869  bool flag = false;
1870  for (unsigned d = 0; d < dim; ++d) {
1871  size *= shape[currentDim + d];
1872  flag |= scalableFlags[currentDim + d];
1873  }
1874  newShape.push_back(size);
1875  newScalableFlags.push_back(flag);
1876  currentDim += dim;
1877  }
1878 
1879  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1880 }
1881 
1882 /// Vectorize `linalg.unpack` as:
1883 /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884 ///
1885 /// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886 /// for the xfer_read operation). This is sufficient to infer the other vector
1887 /// sizes required here.
1888 ///
1889 /// If the vector sizes are not provided:
1890 /// * the vector sizes are determined from the input tensor static shape.
1891 /// * the inBounds attribute is used instead of masking.
1892 ///
1893 /// EXAMPLE (no vector sizes):
1894 /// ```
1895 /// %unpack = linalg.unpack %src
1896 /// inner_dims_pos = [0, 1]
1897 /// inner_tiles = [8, 8]
1898 /// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1899 /// ```
1900 /// is vectorized as:
1901 /// ```
1902 /// %read = vector.transfer_read %src
1903 /// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904 /// %tr = vector.transpose %read, [0, 2, 1, 3]
1905 /// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906 /// %sc = vector.shape_cast %tr
1907 /// : vector<1x8x1x8xf32> to vector<8x8xf32>
1908 /// %vector = vector.transfer_write %sc into %dest
1909 /// : vector<8x8xf32>, tensor<8x8xf32>
1910 /// ```
1911 static LogicalResult
1912 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1913  ArrayRef<int64_t> inputVectorSizes,
1914  ArrayRef<bool> inputScalableVecDims,
1915  SmallVectorImpl<Value> &newResults) {
1916  if (!inputVectorSizes.empty()) {
1917  assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1918  "Invalid number of input vector sizes!");
1919  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1920  "Incompatible number of vector sizes and vector scalable flags!");
1921  }
1922 
1923  // TODO: Introduce a parent class that will handle the insertion point update.
1924  OpBuilder::InsertionGuard g(rewriter);
1925  rewriter.setInsertionPoint(unpackOp);
1926 
1927  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1928 
1929  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1930  bool useInBoundsInsteadOfMasking = false;
1931 
1932  Location loc = unpackOp->getLoc();
1933 
1934  // Obtain vector sizes for the read operation.
1935  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1936  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1937 
1938  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939  if (inputVectorSizes.empty()) {
1940  if (ShapedType::isDynamicShape(sourceShape))
1941  return failure();
1942 
1943  readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1944  useInBoundsInsteadOfMasking = true;
1945  }
1946 
1947  // -- Generate the read operation --
1948  auto padValue = arith::ConstantOp::create(
1949  rewriter, loc,
1950  rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1951  Value readResult = vector::createReadOrMaskedRead(
1952  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1953  useInBoundsInsteadOfMasking, readScalableVectorFlags);
1954 
1955  // -- Generate the transpose operation --
1956  PackingMetadata packMetadata;
1957  SmallVector<int64_t> lastDimToInsertPosPerm =
1958  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1959  vector::TransposeOp transposeOp = vector::TransposeOp::create(
1960  rewriter, loc, readResult, lastDimToInsertPosPerm);
1961 
1962  // -- Generate the shape_cast operation --
1963  VectorType collapsedVecType = getCollapsedVecType(
1964  transposeOp.getType(),
1965  getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1966  rewriter.getContext(), packMetadata.reassociations)));
1967  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1968  rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1969 
1970  // -- Generate the write operation --
1971  Operation *write = createWriteOrMaskedWrite(
1972  rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1973  /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1974 
1975  newResults.push_back(write->getResult(0));
1976  return success();
1977 }
1978 
1979 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1980 /// and (3) all-zero lowPad to
1981 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
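///
/// EXAMPLE (illustrative sketch, assumed shapes):
///   %pad = tensor.pad %src low[0, 0] high[2, 3] { ... tensor.yield %cst }
///       : tensor<?x?xf32> to tensor<8x16xf32>
/// with input vector sizes [8, 16] becomes, roughly:
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///   } : ... -> vector<8x16xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>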
1982 static LogicalResult
1983 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1984  ArrayRef<int64_t> inputVectorSizes,
1985  SmallVectorImpl<Value> &newResults) {
1986  auto padValue = padOp.getConstantPaddingValue();
1987  Location loc = padOp.getLoc();
1988 
1989  // TODO: Introduce a parent class that will handle the insertion point update.
1990  OpBuilder::InsertionGuard g(rewriter);
1991  rewriter.setInsertionPoint(padOp);
1992 
1993  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1994  LogicalResult status =
1995  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1996  .reifyResultShapes(rewriter, reifiedReturnShapes);
1997  (void)status; // prevent unused variable warning on non-assert builds
1998  assert(succeeded(status) && "failed to reify result shapes");
1999  auto maskedRead = vector::createReadOrMaskedRead(
2000  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2001  /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
2002 
2003  // Create Xfer write Op
2004  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2005  padOp.getResultType().getElementType());
2006  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2007  newResults.push_back(write->getResult(0));
2008  return success();
2009 }
2010 
2011 // TODO: probably need some extra checks for reduction followed by consumer
2012 // ops that may not commute (e.g. linear reduction + non-linear instructions).
2013 static LogicalResult reductionPreconditions(LinalgOp op) {
2014  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2015  LDBG() << "reduction precondition failed: no reduction iterator";
2016  return failure();
2017  }
2018  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2019  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2020  if (indexingMap.isPermutation())
2021  continue;
2022 
2023  Operation *reduceOp = matchLinalgReduction(&opOperand);
2024  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2025  LDBG() << "reduction precondition failed: reduction detection failed";
2026  return failure();
2027  }
2028  }
2029  return success();
2030 }
2031 
2032 static LogicalResult
2033 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2034  bool flatten1DDepthwiseConv) {
2035  if (flatten1DDepthwiseConv) {
2036  LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2037  "supported";
2038  return failure();
2039  }
2040 
2041  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2042  LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2043  return failure();
2044  }
2045 
2046  // Support dynamic shapes in 1D depthwise convolution, but only in the
2047  // _channel_ dimension.
2048  Value lhs = conv.getDpsInputOperand(0)->get();
2049  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2050  auto shapeWithoutCh = lhsShape.drop_back(1);
2051  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2052  LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2053  "channel dim can be dynamic";
2054  return failure();
2055  }
2056 
2057  return success();
2058 }
2059 
2060 static LogicalResult
2061 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2062  bool flatten1DDepthwiseConv) {
2063  if (isa<ConvolutionOpInterface>(op.getOperation()))
2064  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2065 
2066  if (hasReductionIterator(op))
2067  return reductionPreconditions(op);
2068 
2069  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2070  // linalg.copy ops and ops that implement ContractionOpInterface for now.
2071  if (!isElementwise(op) &&
2072  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2073  op.getOperation()))
2074  return failure();
2075 
2076  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2077  return success();
2078 }
2079 
2080 /// This hook considers two cases:
2081 /// (1) If the input-vector-sizes are empty, then the vector sizes will be
2082 /// inferred. This is only possible when all shapes are static.
2083 /// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2084 /// carry out basic sanity-checking.
2085 static LogicalResult
2086 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2087  ArrayRef<int64_t> inputVectorSizes) {
2088  // If there are no input vector sizes and all shapes are static, there is
2089  // nothing left to check.
2090  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2091  unpackOp.getSourceType().hasStaticShape())
2092  return success();
2093 
2094  // The number of input vector sizes must be equal to:
2095  // * read-vector-rank
2096  if (!inputVectorSizes.empty() &&
2097  (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2098  LDBG() << "Incorrect number of input vector sizes";
2099  return failure();
2100  }
2101 
2102  // Check the vector sizes for the read operation.
2103  if (failed(vector::isValidMaskedInputVector(
2104  unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2105  LDBG() << "Invalid vector sizes for the read operation";
2106  return failure();
2107  }
2108 
2109  return success();
2110 }
2111 
2112 static LogicalResult
2113 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2114  ArrayRef<int64_t> inputVectorSizes) {
2115 
2116  TypedValue<RankedTensorType> source = sliceOp.getSource();
2117  auto sourceType = source.getType();
2118  if (!VectorType::isValidElementType(sourceType.getElementType()))
2119  return failure();
2120 
2121  // Get the pad value.
2122  // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2123  // scalar padding value. Note that:
2124  // * for in-bounds accesses,
2125  // the value is actually irrelevant. There are 2 cases in which xfer.read
2126  // accesses are known to be in-bounds:
2127  // 1. The source shape is static (output vector sizes would be based on
2128  // the source shape and hence all memory accesses would be in-bounds),
2129  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2130  // this case it is safe to assume that all memory accesses are in-bounds.
2131  //
2132  // When the value is not known and not needed, use 0. Otherwise, bail out.
2133  Value padValue = getStaticPadVal(sliceOp);
2134  bool isOutOfBoundsRead =
2135  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2136 
2137  if (!padValue && isOutOfBoundsRead) {
2138  LDBG() << "Failed to get a pad value for out-of-bounds read access";
2139  return failure();
2140  }
2141  return success();
2142 }
2143 
2144 /// Vectorize a named linalg contraction op into:
2145 /// vector::TransferReadOp - Reads vectors from the operands
2146 /// vector::ContractionOp - Performs contraction
2147 /// vector::TransferWriteOp - Write the result vector back to the
2148 /// destination
2149 /// The operands shapes are preserved and loaded directly into vectors.
2150 /// Any further permutations or numerical casting remain within contraction op.
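///
/// EXAMPLE (illustrative sketch, assumed static shapes):
///   linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
///                 outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
/// becomes, roughly:
///   %a = vector.transfer_read %A ... : tensor<4x8xf32>, vector<4x8xf32>
///   %b = vector.transfer_read %B ... : tensor<8x16xf32>, vector<8x16xf32>
///   %c = vector.transfer_read %C ... : tensor<4x16xf32>, vector<4x16xf32>
///   %r = vector.contract {..., kind = #vector.kind<add>} %a, %b, %c
///       : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   vector.transfer_write %r, %C ... : vector<4x16xf32>, tensor<4x16xf32>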
2151 static LogicalResult
2152 vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2153  LinalgOp linalgOp,
2154  SmallVectorImpl<Value> &newResults) {
2155  Location loc = linalgOp.getLoc();
2156  MLIRContext *ctx = linalgOp.getContext();
2157 
2158  // For simplicity, contraction vectorization is limited to linalg named ops.
2159  // Generic op is ignored as not every arbitrary contraction body can be
2160  // expressed by a vector.contract.
2161  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2162  return failure();
2163 
2164  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2165  Operation *reduceOp = matchLinalgReduction(outOperand);
2166  auto maybeKind = getCombinerOpKind(reduceOp);
2167  if (!maybeKind) {
2168  LDBG() << "Failed to determine contraction combining kind.";
2169  return failure();
2170  }
2171 
2172  // Check that all dimensions are present in the input operands.
2173  // Arbitrary broadcasts are not supported by the vector contraction.
2174  // Broadcasts are expected to be decomposed before vectorization.
2175  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2176  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2177  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2178  LDBG() << "Contractions with broadcasts are not supported.";
2179  return failure();
2180  }
2181 
2182  // Load operands.
2183  SmallVector<Value> vecOperands;
2184  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2185  // The operand vector shape is computed by mapping the canonical vector
2186  // shape to the operand's domain. Further permutations are left as a part of
2187  // the contraction.
2188  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2189  AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2190  indexingMap.getNumResults(), rewriter.getContext());
2191  Type elemType = getElementTypeOrSelf(opOperand.get());
2192  VectorType readType =
2193  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2194 
2195  Value read = vector::createReadOrMaskedRead(
2196  rewriter, loc, opOperand.get(), readType.getShape(),
2197  /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2198  /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
2199  vecOperands.push_back(read);
2200  }
2201 
2202  // Remap iterators from linalg to vector.
2203  SmallVector<Attribute> iterAttrs;
2204  auto iterators = linalgOp.getIteratorTypesArray();
2205  for (utils::IteratorType iter : iterators) {
2206  auto vecIter = iter == utils::IteratorType::parallel
2207  ? vector::IteratorType::parallel
2208  : vector::IteratorType::reduction;
2209  iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2210  }
2211 
2212  // Create contraction.
2213  Operation *contractOp = vector::ContractionOp::create(
2214  rewriter, loc, /*lhs=*/vecOperands[0],
2215  /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2216  linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2217  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2218 
2219  // Store result.
2220  Operation *write = createWriteOrMaskedWrite(
2221  rewriter, loc, contractOp->getResult(0), outOperand->get());
2222 
2223  // Finalize.
2224  if (!write->getResults().empty())
2225  newResults.push_back(write->getResult(0));
2226 
2227  return success();
2228 }
2229 
2230 namespace {
2231 enum class ConvOperationKind { Conv, Pool };
2232 } // namespace
2233 
2234 static bool isCastOfBlockArgument(Operation *op) {
2235  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2236  isa<BlockArgument>(op->getOperand(0));
2237 }
2238 
2239 // Returns the ConvOperationKind of the op using reduceOp of the generic
2240 // payload. If it is neither a convolution nor a pooling, it returns
2241 // std::nullopt.
2242 //
2243 // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2244 // + yield) and rhs is not used) then it is the body of a pooling
2245 // If conv, check for single `mul` predecessor. The `mul` operands must be
2246 // block arguments or extension of block arguments.
2247 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2248 // must be block arguments or extension of block arguments.
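// For example (illustrative IR, not from a test): a region of the form
//   %mul = arith.mulf %in, %filter : f32
//   %add = arith.addf %out, %mul : f32
//   linalg.yield %add : f32
// is classified as Conv (one block-argument operand on the reduction, fed by
// a mul), whereas
//   %max = arith.maximumf %out, %in : f32
//   linalg.yield %max : f32
// is classified as Pool (both reduction operands are block arguments).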
2249 static std::optional<ConvOperationKind>
2250 getConvOperationKind(Operation *reduceOp) {
2251  int numBlockArguments =
2252  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2253 
2254  switch (numBlockArguments) {
2255  case 1: {
2256  // Will be convolution if feeder is a MulOp.
2257  // A strength reduced version of MulOp for i1 type is AndOp which is also
2258  // supported. Otherwise, it can be pooling. This strength reduction logic
2259  // is in `buildBinaryFn` helper in the Linalg dialect.
2260  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2261  llvm::IsaPred<BlockArgument>);
2262  assert(feedValIt != reduceOp->operand_end() &&
2263  "Expected a non-block argument operand");
2264  Operation *feedOp = (*feedValIt).getDefiningOp();
2265  if (isCastOfBlockArgument(feedOp)) {
2266  return ConvOperationKind::Pool;
2267  }
2268 
2269  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2270  (isa<arith::AndIOp>(feedOp) &&
2271  feedOp->getResultTypes()[0].isInteger(1))) &&
2272  llvm::all_of(feedOp->getOperands(), [](Value v) {
2273  if (isa<BlockArgument>(v))
2274  return true;
2275  if (Operation *op = v.getDefiningOp())
2276  return isCastOfBlockArgument(op);
2277  return false;
2278  }))) {
2279  return std::nullopt;
2280  }
2281 
2282  return ConvOperationKind::Conv;
2283  }
2284  case 2:
2285  // Must be pooling
2286  return ConvOperationKind::Pool;
2287  default:
2288  return std::nullopt;
2289  }
2290 }
2291 
2292 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2293  switch (kind) {
2294  case vector::CombiningKind::ADD:
2295  case vector::CombiningKind::MAXNUMF:
2296  case vector::CombiningKind::MAXIMUMF:
2297  case vector::CombiningKind::MAXSI:
2298  case vector::CombiningKind::MAXUI:
2299  case vector::CombiningKind::MINNUMF:
2300  case vector::CombiningKind::MINIMUMF:
2301  case vector::CombiningKind::MINSI:
2302  case vector::CombiningKind::MINUI:
2303  return true;
2304  default:
2305  return false;
2306  }
2307 }
2308 
2309 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2310  auto getOperandType = [&](auto operand) {
2311  return dyn_cast<ShapedType>((operand->get()).getType());
2312  };
2313  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2314  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2315  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2316  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2317  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2318  // Note that this also ensures 2D and 3D convolutions are rejected.
2319  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2320  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2321  return failure();
2322 
2323  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2324  if (!reduceOp)
2325  return failure();
2326 
2327  auto maybeOper = getConvOperationKind(reduceOp);
2328  if (!maybeOper.has_value())
2329  return failure();
2330 
2331  auto maybeKind = getCombinerOpKind(reduceOp);
2332  // Typically a convolution will have an `Add` CombiningKind but for i1 type it
2333  // can get strength reduced to `OR` which is also supported. This strength
2334  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2335  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2336  *maybeKind != vector::CombiningKind::OR) &&
2337  (*maybeOper != ConvOperationKind::Pool ||
2338  !isSupportedPoolKind(*maybeKind)))) {
2339  return failure();
2340  }
2341 
2342  auto rhsRank = rhsShapedType.getRank();
2343  if (*maybeOper == ConvOperationKind::Pool) {
2344  if (rhsRank != 1)
2345  return failure();
2346  } else {
2347  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2348  return failure();
2349  }
2350 
2351  return success();
2352 }
2353 
2354 static LogicalResult vectorizeLinalgOpPrecondition(
2355  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2356  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2357  // Tensors with a dimension of 0 cannot be vectorized.
2358  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2359  return llvm::is_contained(linalgOp.getShape(&operand), 0);
2360  }))
2361  return failure();
2362  // Check API contract for input vector sizes.
2363  if (!inputVectorSizes.empty() &&
2364  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2365  inputVectorSizes)))
2366  return failure();
2367 
2368  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2369  linalgOp, flatten1DDepthwiseConv))) {
2370  LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2371  return failure();
2372  }
2373 
2374  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2375 
2376  // Register CustomVectorizationPrecondition for extractOp.
2377  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2378 
2379  // All types in the body should be a supported element type for VectorType.
2380  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2381  // Check if any custom hook can vectorize the inner op.
2382  if (llvm::any_of(
2383  customPreconditions,
2384  [&](const CustomVectorizationPrecondition &customPrecondition) {
2385  return succeeded(
2386  customPrecondition(&innerOp, vectorizeNDExtract));
2387  })) {
2388  continue;
2389  }
2390  if (!llvm::all_of(innerOp.getOperandTypes(),
2391  VectorType::isValidElementType)) {
2392  return failure();
2393  }
2394  if (!llvm::all_of(innerOp.getResultTypes(),
2395  VectorType::isValidElementType)) {
2396  return failure();
2397  }
2398  }
2399  if (isElementwise(linalgOp))
2400  return success();
2401 
2402  // TODO: isaConvolutionOpInterface that can also infer from generic
2403  // features. But we will still need stride/dilation attributes that will be
2404  // annoying to reverse-engineer...
2405  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2406  return vectorizeConvOpPrecondition(linalgOp);
2407 
2408  // TODO: the common vector shape is equal to the static loop sizes only when
2409  // all indexing maps are projected permutations. For convs and stencils the
2410  // logic will need to evolve.
2411  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2412  LDBG() << "precondition failed: not projected permutations";
2413  return failure();
2414  }
2415  if (failed(reductionPreconditions(linalgOp))) {
2416  LDBG() << "precondition failed: reduction preconditions";
2417  return failure();
2418  }
2419  return success();
2420 }
2421 
2422 static LogicalResult
2423 vectorizePackOpPrecondition(linalg::PackOp packOp,
2424  ArrayRef<int64_t> inputVectorSizes) {
2425  auto padValue = packOp.getPaddingValue();
2426  Attribute cstAttr;
2427  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2428  LDBG() << "pad value is not constant: " << packOp;
2429  return failure();
2430  }
2431 
2432  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2433  bool satisfyEmptyCond = true;
2434  if (inputVectorSizes.empty()) {
2435  if (!packOp.getDestType().hasStaticShape() ||
2436  !packOp.getSourceType().hasStaticShape())
2437  satisfyEmptyCond = false;
2438  }
2439 
2440  if (!satisfyEmptyCond &&
2441  failed(vector::isValidMaskedInputVector(
2442  resultTensorShape.take_front(packOp.getSourceRank()),
2443  inputVectorSizes)))
2444  return failure();
2445 
2446  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2447  return !getConstantIntValue(v).has_value();
2448  })) {
2449  LDBG() << "inner_tiles must be constant: " << packOp;
2450  return failure();
2451  }
2452 
2453  return success();
2454 }
2455 
2456 static LogicalResult
2457 vectorizePadOpPrecondition(tensor::PadOp padOp,
2458  ArrayRef<int64_t> inputVectorSizes) {
2459  auto padValue = padOp.getConstantPaddingValue();
2460  if (!padValue) {
2461  LDBG() << "pad value is not constant: " << padOp;
2462  return failure();
2463  }
2464 
2465  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2466  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2467  inputVectorSizes)))
2468  return failure();
2469 
2470  // Padding with non-zero low pad values is not supported, unless the
2471  // corresponding result dim is 1 as this would require shifting the results to
2472  // the right for the low padded dims by the required amount of low padding.
2473  // However, we do support low padding if the dims being low padded have result
2474  // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2475  // input size of that dimension will be dynamically zero (as the sum of the
2476  // low pad and input dim size has to be one) and hence we will create a zero
2477  // mask as the lowering logic just makes the mask one for the input dim size -
2478  // which is zero here. Hence we will load the pad value which is what we want
2479  // in this case. If the low pad is dynamically zero then the lowering is
2480  // correct as well as no shifts are necessary.
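  // For example (illustrative): a low pad of [1, 0] on a result of type
  // tensor<1x8xf32> is accepted (the low-padded dim has a unit result size),
  // whereas the same low pad on tensor<4x8xf32> is rejected because the data
  // would have to be shifted right by one element.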
2481  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2482  Value padValue = en.value();
2483  unsigned pos = en.index();
2484  std::optional<int64_t> pad = getConstantIntValue(padValue);
2485  return (!pad.has_value() || pad.value() != 0) &&
2486  resultTensorShape[pos] != 1;
2487  })) {
2488  LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2489  return failure();
2490  }
2491 
2492  return success();
2493 }
2494 
2495 /// Preconditions for scalable vectors.
2496 ///
2497 /// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2498 /// models the fact that in practice we would only make selected dimensions
2499 /// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2500 /// unconditionally - we are yet to identify meaningful conditions.
2501 static LogicalResult
2502 vectorizeScalableVectorPrecondition(Operation *op,
2503  ArrayRef<int64_t> inputVectorSizes,
2504  ArrayRef<bool> inputScalableVecDims) {
2505  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2506  "Number of input vector sizes and scalable dims doesn't match");
2507 
2508  size_t numOfScalableDims =
2509  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2510 
2511  if (numOfScalableDims == 0)
2512  return success();
2513 
2514  auto linalgOp = dyn_cast<LinalgOp>(op);
2515 
2516  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2517  // exception of UnpackOp for which there is a dedicated hook.
2518  if (!linalgOp) {
2519  return success(isa<linalg::UnPackOp>(op));
2520  }
2521 
2522  // Cond 2: There's been no need for more than 2 scalable dims so far
2523  if (numOfScalableDims > 2)
2524  return failure();
2525 
2526  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2527  // it matches one of the supported cases:
2528  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2529  // (*).
2530  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2531  // parallel dims.
2532  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2533  // dim.
2534  // The 2nd restriction above means that only Matmul-like Ops are supported
2535  // when 2 dims are scalable, e.g. :
2536  // * iterators = [parallel, parallel, reduction]
2537  // * scalable flags = [true, true, false]
2538  //
2539  // (*) Non-unit dims get folded away in practice.
2540  // TODO: Relax these conditions as good motivating examples are identified.
2541 
2542  // Find the first scalable flag.
2543  bool seenNonUnitParallel = false;
2544  auto iterators = linalgOp.getIteratorTypesArray();
2545  SmallVector<bool> scalableFlags(inputScalableVecDims);
2546  int64_t idx = scalableFlags.size() - 1;
2547  while (!scalableFlags[idx]) {
2548  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2549  seenNonUnitParallel |=
2550  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2551 
2552  iterators.pop_back();
2553  scalableFlags.pop_back();
2554  --idx;
2555  }
2556 
2557  // Analyze the iterator corresponding to the first scalable dim.
2558  switch (iterators.back()) {
2559  case utils::IteratorType::reduction: {
2560  // Check 3. above is met.
2561  if (iterators.size() != inputVectorSizes.size()) {
2562  LDBG() << "Non-trailing reduction dim requested for scalable "
2563  "vectorization";
2564  return failure();
2565  }
2566  if (isa<linalg::MatmulOp>(op)) {
2567  LDBG()
2568  << "Scalable vectorization of the reduction dim in Matmul-like ops "
2569  "is not supported";
2570  return failure();
2571  }
2572  break;
2573  }
2574  case utils::IteratorType::parallel: {
2575  // Check 1. and 2. above are met.
2576  if (seenNonUnitParallel) {
2577  LDBG() << "Inner parallel dim not requested for scalable "
2578  "vectorization";
2579  return failure();
2580  }
2581  break;
2582  }
2583  }
2584 
2585  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2586  // supported, for which we expect the following config:
2587  // * iterators = [parallel, parallel, reduction]
2588  // * scalable flags = [true, true, false]
2589  if (numOfScalableDims == 2) {
2590  // Disallow below case which breaks 3. above:
2591  // * iterators = [..., parallel, reduction]
2592  // * scalable flags = [..., true, true]
2593  if (iterators.back() == utils::IteratorType::reduction) {
2594  LDBG() << "Higher dim than the trailing reduction dim requested for "
2595  "scalable "
2596  "vectorizatio";
2597  return failure();
2598  }
2599  scalableFlags.pop_back();
2600  iterators.pop_back();
2601 
2602  if (!scalableFlags.back() ||
2603  (iterators.back() != utils::IteratorType::parallel))
2604  return failure();
2605  }
2606 
2607  // Cond 4: Only the following ops are supported in the
2608  // presence of scalable vectors
2609  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2610  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2611  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2612  isa<linalg::BatchMmt4DOp>(op) ||
2613  hasReductionIterator(linalgOp));
2614 }
2615 
2616 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2617  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2618  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2619  bool flatten1DDepthwiseConv) {
2620 
2621  if (!hasVectorizationImpl(op))
2622  return failure();
2623 
2624  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2625  inputScalableVecDims)))
2626  return failure();
2627 
2628  return TypeSwitch<Operation *, LogicalResult>(op)
2629  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2630  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2631  vectorizeNDExtract,
2632  flatten1DDepthwiseConv);
2633  })
2634  .Case<tensor::PadOp>([&](auto padOp) {
2635  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2636  })
2637  .Case<linalg::PackOp>([&](auto packOp) {
2638  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2639  })
2640  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2641  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2642  })
2643  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2644  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2645  })
2646  .Default([](auto) { return failure(); });
2647 }
2648 
2649 /// Converts affine.apply Ops to arithmetic operations.
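/// For example (illustrative):
///   %r = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
/// is expanded to, roughly:
///   %c4 = arith.constant 4 : index
///   %r = arith.addi %i, %c4 : index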
2650 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2651  OpBuilder::InsertionGuard g(rewriter);
2652  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2653 
2654  for (auto op : make_early_inc_range(toReplace)) {
2655  rewriter.setInsertionPoint(op);
2656  auto expanded = affine::expandAffineExpr(
2657  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2658  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2659  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2660  rewriter.replaceOp(op, expanded);
2661  }
2662 }
2663 
2665  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2666  tensor::InsertSliceOp>(op);
2667 }
2668 
2669 FailureOr<VectorizationResult> mlir::linalg::vectorize(
2670  RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2671  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2672  bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2673  bool createNamedContraction) {
2674  LDBG() << "Attempting to vectorize: " << *op;
2675  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2676  LDBG() << "Input scalable vector dims: "
2677  << llvm::interleaved(inputScalableVecDims);
2678 
2679  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2680  vectorizeNDExtract,
2681  flatten1DDepthwiseConv))) {
2682  LDBG() << "Vectorization pre-conditions failed";
2683  return failure();
2684  }
2685 
2686  // Initialize vectorization state.
2687  VectorizationState state(rewriter);
2688  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2689  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2690  inputScalableVecDims,
2691  assumeDynamicDimsMatchVecSizes))) {
2692  LDBG() << "Vectorization state couldn't be initialized";
2693  return failure();
2694  }
2695  }
2696 
2697  SmallVector<Value> results;
2698  auto vectorizeResult =
2699  TypeSwitch<Operation *, LogicalResult>(op)
2700  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2701  // TODO: isaConvolutionOpInterface that can also infer from
2702  // generic features. Will require stride/dilation attributes
2703  // inference.
2704  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2705  FailureOr<Operation *> convOr = vectorizeConvolution(
2706  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2707  flatten1DDepthwiseConv);
2708  if (succeeded(convOr)) {
2709  llvm::append_range(results, (*convOr)->getResults());
2710  return success();
2711  }
2712 
2713  LDBG() << "Unsupported convolution can't be vectorized.";
2714  return failure();
2715  }
2716 
2717  if (createNamedContraction &&
2718  isa<ContractionOpInterface>(linalgOp.getOperation()))
2719  return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2720  results);
2721 
2722  LDBG()
2723  << "Vectorize generic by broadcasting to the canonical vector "
2724  "shape";
2725 
2726  // Pre-process before proceeding.
2727  convertAffineApply(rewriter, linalgOp);
2728 
2729  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2730  // to 'OpBuilder' when it is passed over to some methods like
2731  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2732  // erase an op within these methods, the actual rewriter won't be
2733  // notified and we will end up with read-after-free issues!
2734  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2735  })
2736  .Case<tensor::PadOp>([&](auto padOp) {
2737  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2738  results);
2739  })
2740  .Case<linalg::PackOp>([&](auto packOp) {
2741  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2742  results);
2743  })
2744  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2745  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2746  inputVectorSizes,
2747  inputScalableVecDims, results);
2748  })
2749  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2750  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2751  results);
2752  })
2753  .Default([](auto) { return failure(); });
2754 
2755  if (failed(vectorizeResult)) {
2756  LDBG() << "Vectorization failed";
2757  return failure();
2758  }
2759 
2760  return VectorizationResult{results};
2761 }
2762 
2763 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2764  memref::CopyOp copyOp) {
2765  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2766  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2767  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2768  return failure();
2769 
2770  auto srcElementType = getElementTypeOrSelf(srcType);
2771  auto dstElementType = getElementTypeOrSelf(dstType);
2772  if (!VectorType::isValidElementType(srcElementType) ||
2773  !VectorType::isValidElementType(dstElementType))
2774  return failure();
2775 
2776  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2777  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2778 
2779  Location loc = copyOp->getLoc();
2780  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2781  SmallVector<Value> indices(srcType.getRank(), zero);
2782 
2783  Value readValue = vector::TransferReadOp::create(
2784  rewriter, loc, readType, copyOp.getSource(), indices,
2785  /*padding=*/std::nullopt,
2786  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2787  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2788  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2789  ArrayRef<int64_t>());
2790  readValue =
2791  vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2792  }
2793  Operation *writeValue = vector::TransferWriteOp::create(
2794  rewriter, loc, readValue, copyOp.getTarget(), indices,
2795  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2796  rewriter.replaceOp(copyOp, writeValue->getResults());
2797  return success();
2798 }
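// Illustrative example (a sketch, not taken from the original source): a fully
// static copy such as
//   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
// is rewritten roughly into
//   %c0 = arith.constant 0 : index
//   %v  = vector.transfer_read %src[%c0, %c0], %pad
//           : memref<4x8xf32>, vector<4x8xf32>
//   vector.transfer_write %v, %dst[%c0, %c0]
//           : vector<4x8xf32>, memref<4x8xf32>
// where %pad stands for the default padding value materialized for the read.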
2799 
2800 //----------------------------------------------------------------------------//
2801 // Misc. vectorization patterns.
2802 //----------------------------------------------------------------------------//
2803 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2804 /// given operation type OpTy.
2805 template <typename OpTy>
2806 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
 2807  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
 2808 
2809  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2810  PatternRewriter &rewriter) const final {
2811  bool changed = false;
2812  // Insert users in vector, because some users may be replaced/removed.
2813  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2814  if (auto op = dyn_cast<OpTy>(user))
2815  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2816  return success(changed);
2817  }
2818 
2819 protected:
2820  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2821  tensor::PadOp padOp, OpTy op) const = 0;
2822 };
2823 
2824 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2825 /// ```
2826 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2827 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2828 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2829 /// ```
2830 /// is rewritten to:
2831 /// ```
2832 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2833 /// {in_bounds = [true, true]}
2834 /// : tensor<?x?xf32>, vector<17x5xf32>
2835 /// ```
2836 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2837 /// sure that the original padding value %cst was never used.
2838 ///
2839 /// This rewrite is possible if:
2840 /// - `xferOp` has no out-of-bounds dims or mask.
2841 /// - Low padding is static 0.
2842 /// - Single, scalar padding value.
 2843 struct PadOpVectorizationWithTransferReadPattern
 2844  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
 2845  using VectorizePadOpUserPattern<
 2846  vector::TransferReadOp>::VectorizePadOpUserPattern;
2847 
2848  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2849  vector::TransferReadOp xferOp) const override {
2850  // Low padding must be static 0.
2851  if (!padOp.hasZeroLowPad())
2852  return failure();
2853  // Pad value must be a constant.
2854  auto padValue = padOp.getConstantPaddingValue();
2855  if (!padValue)
2856  return failure();
2857  // Padding value of existing `xferOp` is unused.
2858  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2859  return failure();
2860 
2861  rewriter.modifyOpInPlace(xferOp, [&]() {
2862  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2863  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2864  rewriter.getBoolArrayAttr(inBounds));
2865  xferOp.getBaseMutable().assign(padOp.getSource());
2866  xferOp.getPaddingMutable().assign(padValue);
2867  });
2868 
2869  return success();
2870  }
2871 };
2872 
2873 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2874 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2875 /// value, where the same amount of padding is immediately removed again after
2876 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2877 /// tensor value and apply out-of-bounds masking. E.g.:
2878 /// ```
2879 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2880 /// : tensor<...> to tensor<?x?xf32>
2881 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2882 /// %2 = vector.transfer_write %vec, %1[...]
2883 /// : vector<17x5xf32>, tensor<17x5xf32>
2884 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2885 /// : tensor<17x5xf32> to tensor<?x?xf32>
2886 /// ```
2887 /// is rewritten to:
2888 /// ```
2889 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2890 /// : tensor<...> to tensor<?x?xf32>
2891 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2892 /// tensor<?x?xf32>
2893 /// ```
2894 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2895 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2896 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2897 /// from %r's old dimensions.
2898 ///
2899 /// This rewrite is possible if:
2900 /// - Low padding is static 0.
2901 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2902 /// ExtractSliceOp trims the same amount of padding that was added
2903 /// beforehand.
2904 /// - Single, scalar padding value.
 2905 struct PadOpVectorizationWithTransferWritePattern
 2906  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
 2907  using VectorizePadOpUserPattern<
 2908  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2909 
2910  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2911  vector::TransferWriteOp xferOp) const override {
2912  // TODO: support 0-d corner case.
2913  if (xferOp.getTransferRank() == 0)
2914  return failure();
2915 
2916  // Low padding must be static 0.
2917  if (!padOp.hasZeroLowPad())
2918  return failure();
2919  // Pad value must be a constant.
2920  auto padValue = padOp.getConstantPaddingValue();
2921  if (!padValue)
2922  return failure();
2923  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2924  if (!xferOp->hasOneUse())
2925  return failure();
2926  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2927  if (!trimPadding)
2928  return failure();
2929  // Only static zero offsets supported when trimming padding.
2930  if (!trimPadding.hasZeroOffset())
2931  return failure();
2932  // trimPadding must remove the amount of padding that was added earlier.
2933  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2934  return failure();
2935 
2936  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2937  rewriter.setInsertionPoint(xferOp);
2938 
2939  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2940  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2941  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2942  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2943  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2944  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2945 
2946  return success();
2947  }
2948 
2949  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2950  /// i.e., same dimensions.
2951  ///
2952  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2953  /// dimensions, this function tries to infer the (static) tensor size by
2954  /// looking at the defining op and utilizing op-specific knowledge.
2955  ///
2956  /// This is a conservative analysis. In case equal tensor sizes cannot be
2957  /// proven statically, this analysis returns `false` even though the tensor
2958  /// sizes may turn out to be equal at runtime.
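 ///
 /// For example (illustrative): two tensor<?x8xf32> values are treated as
 /// having the same size only if their static dims match and their dynamic
 /// sizes can be proven equal, e.g. they are the same Value, the same
 /// constant, or identical affine.min ops over identical operands.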
2959  bool hasSameTensorSize(Value beforePadding,
2960  tensor::ExtractSliceOp afterTrimming) const {
2961  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2962  // result and CastOp operand.
2963  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2964  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2965  return true;
2966 
2967  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2968  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2969  // Only RankedTensorType supported.
2970  if (!t1 || !t2)
2971  return false;
2972  // Rank of both values must be the same.
2973  if (t1.getRank() != t2.getRank())
2974  return false;
2975 
2976  // All static dimensions must be the same. Mixed cases (e.g., dimension
2977  // static in `t1` but dynamic in `t2`) are not supported.
2978  for (unsigned i = 0; i < t1.getRank(); ++i) {
2979  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2980  return false;
2981  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2982  return false;
2983  }
2984 
2985  // Nothing more to check if all dimensions are static.
2986  if (t1.getNumDynamicDims() == 0)
2987  return true;
2988 
2989  // All dynamic sizes must be the same. The only supported case at the
2990  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2991  // thereof).
2992 
2993  // Apart from CastOp, only ExtractSliceOp is supported.
2994  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2995  if (!beforeSlice)
2996  return false;
2997 
2998  assert(static_cast<size_t>(t1.getRank()) ==
2999  beforeSlice.getMixedSizes().size());
3000  assert(static_cast<size_t>(t2.getRank()) ==
3001  afterTrimming.getMixedSizes().size());
3002 
3003  for (unsigned i = 0; i < t1.getRank(); ++i) {
3004  // Skip static dimensions.
3005  if (!t1.isDynamicDim(i))
3006  continue;
3007  auto size1 = beforeSlice.getMixedSizes()[i];
3008  auto size2 = afterTrimming.getMixedSizes()[i];
3009 
3010  // Case 1: Same value or same constant int.
3011  if (isEqualConstantIntOrValue(size1, size2))
3012  continue;
3013 
3014  // Other cases: Take a deeper look at defining ops of values.
3015  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3016  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3017  if (!v1 || !v2)
3018  return false;
3019 
3020  // Case 2: Both values are identical AffineMinOps. (Should not happen if
3021  // CSE is run.)
3022  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3023  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3024  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3025  minOp1.getOperands() == minOp2.getOperands())
3026  continue;
3027 
3028  // Add additional cases as needed.
3029  }
3030 
3031  // All tests passed.
3032  return true;
3033  }
3034 };
3035 
3036 /// Returns the effective Pad value for the input op, provided it's a scalar.
3037 ///
3038 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3039 /// this Op performs padding, retrieve the padding value provided that it's
3040 /// a scalar and static/fixed for all the padded values. Returns an empty value
3041 /// otherwise.
3042 ///
3043 /// TODO: This is used twice (when checking vectorization pre-conditions and
3044 /// when vectorizing). Cache results instead of re-running.
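 ///
 /// Illustrative sketch (names are made up for the example):
 /// ```
 ///   %cst  = arith.constant 0.0 : f32
 ///   %fill = linalg.fill ins(%cst : f32)
 ///             outs(%dest : tensor<8x8xf32>) -> tensor<8x8xf32>
 ///   %res  = tensor.insert_slice %src into %fill ...
 /// ```
 /// Querying the insert_slice walks to the linalg.fill and returns %cst.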
 3045 static Value getStaticPadVal(Operation *op) {
 3046  if (!op)
3047  return {};
3048 
3049  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
3050  // being broadcast, provided that it's a scalar.
3051  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3052  auto source = bcast.getSource();
3053  if (llvm::dyn_cast<VectorType>(source.getType()))
3054  return {};
3055 
3056  return source;
3057  }
3058 
 3059  // 2. linalg.fill - use the scalar input value that was used to fill the output
3060  // tensor.
3061  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3062  return fill.getInputs()[0];
3063  }
3064 
 3065  // 3. tensor.generate - can't guarantee the value is fixed without
 3066  // analysing it, so bail out.
3067  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3068  return {};
3069  }
3070 
3071  // 4. vector.transfer_write - inspect the input vector that's written from. If
 3072  // it contains a single value that has been broadcast (e.g. via
3073  // vector.broadcast), extract it, fail otherwise.
3074  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3075  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3076 
3077  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3078  // than the input tensor, then, provided it's constant, we'll extract the
3079  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3080  // TODO: Clarify the semantics when the input tensor is larger than the
3081  // destination.
3082  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3083  return getStaticPadVal(slice.getDest().getDefiningOp());
3084 
3085  return {};
3086 }
3087 
3088 static LogicalResult
3089 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3090  ArrayRef<int64_t> inputVectorSizes,
3091  SmallVectorImpl<Value> &newResults) {
3092  // TODO: Introduce a parent class that will handle the insertion point update.
3093  OpBuilder::InsertionGuard g(rewriter);
3094  rewriter.setInsertionPoint(sliceOp);
3095 
3096  TypedValue<RankedTensorType> source = sliceOp.getSource();
3097  auto sourceType = source.getType();
3098  auto resultType = sliceOp.getResultType();
3099 
3100  Value padValue = getStaticPadVal(sliceOp);
3101 
3102  if (!padValue) {
3103  auto elemType = sourceType.getElementType();
3104  padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3105  rewriter.getZeroAttr(elemType));
3106  }
3107 
3108  // 2. Get the vector shape
3109  SmallVector<int64_t> vecShape;
3110  size_t rankDiff = resultType.getRank() - sourceType.getRank();
3111  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3112  if (!inputVectorSizes.empty()) {
3113  vecShape.push_back(inputVectorSizes[i]);
3114  } else if (!sourceType.isDynamicDim(i)) {
3115  vecShape.push_back(sourceType.getDimSize(i));
3116  } else if (!resultType.isDynamicDim(i)) {
3117  // Source shape is not statically known, but result shape is.
3118  // Vectorize with size of result shape. This may be larger than the
3119  // source size.
3120  // FIXME: Using rankDiff implies that the source tensor is inserted at
3121  // the end of the destination tensor. However, that's not required.
3122  vecShape.push_back(resultType.getDimSize(rankDiff + i));
3123  } else {
 3124  // Neither source nor result dim of sliceOp is static. Cannot vectorize
3125  // the copy.
3126  return failure();
3127  }
3128  }
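 // Illustrative example (not from the original source): for a source of type
 // tensor<2x?xf32> inserted into a tensor<9x8x2x4xf32> result with no
 // user-provided sizes, dim 0 comes from the source (2) and dim 1 falls back
 // to the result dim at rankDiff + 1 (4), giving vecShape = [2, 4].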
3129  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3130 
3131  // 3. Generate TransferReadOp + TransferWriteOp
3132  auto loc = sliceOp.getLoc();
3133 
3134  // Create read
3135  SmallVector<Value> readIndices(
 3136  vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
 3137  Value read = vector::createReadOrMaskedRead(
 3138  rewriter, loc, source, vecType.getShape(), padValue,
3139  /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
3140  /*inputScalableVecSizes=*/{});
3141 
3142  // Create write
3143  auto writeIndices =
3144  getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3145  Operation *write =
3146  createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3147  writeIndices, inputVectorSizes.empty());
3148 
3149  // 4. Finalize
3150  newResults.push_back(write->getResult(0));
3151 
3152  return success();
3153 }
3154 
3155 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3156 /// ```
3157 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3158 /// %r = tensor.insert_slice %0
3159 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3160 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3161 /// ```
3162 /// is rewritten to:
3163 /// ```
3164 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
3165 /// : tensor<?x?xf32>, vector<17x5xf32>
3166 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3167 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3168 /// ```
3169 ///
3170 /// This rewrite is possible if:
3171 /// - Low padding is static 0.
3172 /// - `padOp` result shape is static.
3173 /// - The entire padded tensor is inserted.
3174 /// (Implies that sizes of `insertOp` are all static.)
3175 /// - Only unit strides in `insertOp`.
3176 /// - Single, scalar padding value.
3177 /// - `padOp` result not used as destination.
 3178 struct PadOpVectorizationWithInsertSlicePattern
 3179  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
 3180  using VectorizePadOpUserPattern<
 3181  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3182 
3183  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3184  tensor::InsertSliceOp insertOp) const override {
3185  // Low padding must be static 0.
3186  if (!padOp.hasZeroLowPad())
3187  return failure();
3188  // Only unit stride supported.
3189  if (!insertOp.hasUnitStride())
3190  return failure();
3191  // Pad value must be a constant.
3192  auto padValue = padOp.getConstantPaddingValue();
3193  if (!padValue)
3194  return failure();
3195  // Dynamic shapes not supported.
3196  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3197  return failure();
3198  // Pad result not used as destination.
3199  if (insertOp.getDest() == padOp.getResult())
3200  return failure();
3201 
3202  auto vecType = VectorType::get(padOp.getType().getShape(),
3203  padOp.getType().getElementType());
3204  unsigned vecRank = vecType.getRank();
3205  unsigned tensorRank = insertOp.getType().getRank();
3206 
3207  // Check if sizes match: Insert the entire tensor into most minor dims.
3208  // (No permutations allowed.)
3209  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3210  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3211  if (!llvm::all_of(
3212  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3213  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3214  }))
3215  return failure();
3216 
3217  // Insert the TransferReadOp and TransferWriteOp at the position of the
3218  // InsertSliceOp.
3219  rewriter.setInsertionPoint(insertOp);
3220 
3221  // Generate TransferReadOp: Read entire source tensor and add high
3222  // padding.
3223  SmallVector<Value> readIndices(
3224  vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3225  auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3226  vecType, padOp.getSource(),
3227  readIndices, padValue);
3228 
3229  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
 3230  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3231  // source must fit into the destination at the specified offsets.
3232  auto writeIndices = getValueOrCreateConstantIndexOp(
3233  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3234  SmallVector<bool> inBounds(vecRank, true);
3235  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3236  insertOp, read, insertOp.getDest(), writeIndices,
3237  ArrayRef<bool>{inBounds});
3238 
3239  return success();
3240  }
3241 };
3242 
 3243 void mlir::linalg::populatePadOpVectorizationPatterns(
 3244  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
 3245  patterns.add<PadOpVectorizationWithTransferReadPattern,
 3246  PadOpVectorizationWithTransferWritePattern,
 3247  PadOpVectorizationWithInsertSlicePattern>(
 3248  patterns.getContext(), baseBenefit.getBenefit() + 1);
3249 }
3250 
3251 //----------------------------------------------------------------------------//
3252 // Forwarding patterns
3253 //----------------------------------------------------------------------------//
3254 
3255 /// Check whether there is any interleaved use of any `values` between
3256 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3257 /// is in a different block.
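///
/// For example (illustrative): with `firstOp; %x = use(%v); secondOp` in one
/// block, the use of %v by %x is interleaved and the function returns true;
/// uses strictly before `firstOp` or strictly after `secondOp` are ignored.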
3258 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3259  ValueRange values) {
3260  if (firstOp->getBlock() != secondOp->getBlock() ||
3261  !firstOp->isBeforeInBlock(secondOp)) {
3262  LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3263  << ", second op: " << *secondOp;
3264  return true;
3265  }
3266  for (auto v : values) {
3267  for (auto &u : v.getUses()) {
3268  Operation *owner = u.getOwner();
3269  if (owner == firstOp || owner == secondOp)
3270  continue;
3271  // TODO: this is too conservative, use dominance info in the future.
3272  if (owner->getBlock() == firstOp->getBlock() &&
3273  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3274  continue;
3275  LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3276  << ", second op: " << *secondOp;
3277  return true;
3278  }
3279  }
3280  return false;
3281 }
3282 
3283 /// Return the unique subview use of `v` if it is indeed unique, null
3284 /// otherwise.
3285 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3286  memref::SubViewOp subViewOp;
3287  for (auto &u : v.getUses()) {
3288  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3289  if (subViewOp)
3290  return memref::SubViewOp();
3291  subViewOp = newSubViewOp;
3292  }
3293  }
3294  return subViewOp;
3295 }
3296 
3297 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3298 /// when available.
 3299 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
 3300  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3301 
3302  // TODO: support mask.
3303  if (xferOp.getMask())
3304  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3305 
3306  // Transfer into `view`.
3307  Value viewOrAlloc = xferOp.getBase();
3308  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3309  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3310  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3311 
3312  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3313  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3314  if (!subViewOp)
3315  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3316  Value subView = subViewOp.getResult();
3317 
3318  // Find the copy into `subView` without interleaved uses.
3319  memref::CopyOp copyOp;
3320  for (auto &u : subView.getUses()) {
3321  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3322  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3323  if (newCopyOp.getTarget() != subView)
3324  continue;
3325  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3326  continue;
3327  copyOp = newCopyOp;
3328  break;
3329  }
3330  }
3331  if (!copyOp)
3332  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3333 
3334  // Find the fill into `viewOrAlloc` without interleaved uses before the
3335  // copy.
3336  FillOp maybeFillOp;
3337  for (auto &u : viewOrAlloc.getUses()) {
3338  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3339  assert(isa<MemRefType>(newFillOp.output().getType()));
3340  if (newFillOp.output() != viewOrAlloc)
3341  continue;
3342  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3343  continue;
3344  maybeFillOp = newFillOp;
3345  break;
3346  }
3347  }
3348  // Ensure padding matches.
3349  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3350  return rewriter.notifyMatchFailure(xferOp,
3351  "padding value does not match fill");
3352 
3353  // `in` is the subview that memref.copy reads. Replace it.
3354  Value in = copyOp.getSource();
3355 
3356  // memref.copy + linalg.fill can be used to create a padded local buffer.
3357  // The `masked` attribute is only valid on this padded buffer.
3358  // When forwarding to vector.transfer_read, the attribute must be reset
3359  // conservatively.
3360  auto vectorType = xferOp.getVectorType();
3361  Value res = vector::TransferReadOp::create(
3362  rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3363  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3364  rewriter.getBoolArrayAttr(
3365  SmallVector<bool>(vectorType.getRank(), false)));
3366 
3367  if (maybeFillOp)
3368  rewriter.eraseOp(maybeFillOp);
3369  rewriter.eraseOp(copyOp);
3370  rewriter.replaceOp(xferOp, res);
3371 
3372  return success();
3373 }
3374 
3375 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3376 /// when available.
 3377 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
 3378  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3379  // TODO: support mask.
3380  if (xferOp.getMask())
3381  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3382 
3383  // Transfer into `viewOrAlloc`.
3384  Value viewOrAlloc = xferOp.getBase();
3385  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3386  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3387  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3388 
3389  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3390  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3391  if (!subViewOp)
3392  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3393  Value subView = subViewOp.getResult();
3394 
3395  // Find the copy from `subView` without interleaved uses.
3396  memref::CopyOp copyOp;
3397  for (auto &u : subViewOp.getResult().getUses()) {
3398  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3399  if (newCopyOp.getSource() != subView)
3400  continue;
3401  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3402  continue;
3403  copyOp = newCopyOp;
3404  break;
3405  }
3406  }
3407  if (!copyOp)
3408  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3409 
3410  // `out` is the subview copied into that we replace.
3411  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3412  Value out = copyOp.getTarget();
3413 
3414  // Forward vector.transfer into copy.
3415  // memref.copy + linalg.fill can be used to create a padded local buffer.
3416  // The `masked` attribute is only valid on this padded buffer.
3417  // When forwarding to vector.transfer_write, the attribute must be reset
3418  // conservatively.
3419  auto vector = xferOp.getVector();
3420  vector::TransferWriteOp::create(
3421  rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
 3422  xferOp.getPermutationMapAttr(), xferOp.getMask(),
 3423  rewriter.getBoolArrayAttr(SmallVector<bool>(
 3424  dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3425 
3426  rewriter.eraseOp(copyOp);
3427  rewriter.eraseOp(xferOp);
3428 
3429  return success();
3430 }
3431 
3432 //===----------------------------------------------------------------------===//
3433 // Convolution vectorization patterns
3434 //===----------------------------------------------------------------------===//
3435 
3436 template <int N>
3437 static void bindShapeDims(ShapedType shapedType) {}
3438 
3439 template <int N, typename IntTy, typename... IntTy2>
3440 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3441  val = shapedType.getShape()[N];
3442  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3443 }
3444 
3445 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
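/// Illustrative usage: for a ShapedType with shape [4, 8, 16],
///   int64_t n, w;
///   bindShapeDims(shapedType, n, w);
/// binds n = 4 and w = 8.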
3446 template <typename... IntTy>
3447 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3448  bindShapeDims<0>(shapedType, vals...);
3449 }
3450 
3451 namespace {
3452 /// Generate a vector implementation for either:
3453 /// ```
3454 /// Op def: ( w, kw )
3455 /// Iters: ({Par(), Red()})
3456 /// Layout: {{w + kw}, {kw}, {w}}
3457 /// ```
3458 /// kw is unrolled.
3459 ///
3460 /// or
3461 ///
3462 /// ```
3463 /// Op def: ( n, w, c, kw, f )
3464 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3465 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3466 /// ```
3467 /// kw is unrolled, w is unrolled iff dilationW > 1.
3468 ///
3469 /// or
3470 ///
3471 /// ```
3472 /// Op def: ( n, c, w, f, kw )
3473 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3474 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3475 /// ```
3476 /// kw is unrolled, w is unrolled iff dilationW > 1.
3477 ///
3478 /// or
3479 ///
3480 /// ```
3481 /// Op def: ( n, w, c, kw )
3482 /// Iters: ({Par(), Par(), Par(), Red()})
3483 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3484 /// ```
3485 /// kw is unrolled, w is unrolled iff dilationW > 1.
3486 struct Conv1DGenerator
3487  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3488  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3489  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3490 
3491  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3492  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3493  resShaped = linalgOp.getDpsInitOperand(0)->get();
3494  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3495  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3496  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3497 
3498  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3499  redOp = reduceOp->getName().getIdentifier();
3500 
3501  setConvOperationKind(reduceOp);
3502 
3503  auto maybeKind = getCombinerOpKind(reduceOp);
3504  reductionKind = maybeKind.value();
3505 
3506  // The ConvolutionOpInterface gives us guarantees of existence for
3507  // strides/dilations. However, we do not need to rely on those, we can
3508  // simply use them if present, otherwise use the default and let the generic
3509  // conv. matcher in the ConvGenerator succeed or fail.
3510  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3511  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3512  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3513  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3514  }
3515 
3516  /// Generate a vector implementation for:
3517  /// ```
3518  /// Op def: ( w, kw )
3519  /// Iters: ({Par(), Red()})
3520  /// Layout: {{w + kw}, {kw}, {w}}
3521  /// ```
3522  /// kw is always unrolled.
3523  ///
3524  /// or
3525  ///
3526  /// ```
3527  /// Op def: ( n, w, c, kw, f )
3528  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3529  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3530  /// ```
3531  /// kw is always unrolled.
3532  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3533  /// > 1.
3534  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3535  int64_t nSize, wSize, cSize, kwSize, fSize;
3536  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3537  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3538  switch (conv1DOpOrder) {
3539  case Conv1DOpOrder::W:
3540  // Initialize unused dimensions
3541  nSize = fSize = cSize = 0;
3542  // out{W}
3543  bindShapeDims(resShapedType, wSize);
3544  // kernel{kw}
3545  bindShapeDims(rhsShapedType, kwSize);
3546  lhsShape = {// iw = ow + kw - 1
3547  // (i.e. 16 convolved with 3 -> 14)
3548  (wSize + kwSize - 1)};
3549  rhsShape = {kwSize};
3550  resShape = {wSize};
3551  break;
3552  case Conv1DOpOrder::Nwc:
3553  // out{n, w, f}
3554  bindShapeDims(resShapedType, nSize, wSize, fSize);
3555  switch (oper) {
3556  case ConvOperationKind::Conv:
3557  // kernel{kw, c, f}
3558  bindShapeDims(rhsShapedType, kwSize, cSize);
3559  break;
3560  case ConvOperationKind::Pool:
3561  // kernel{kw}
3562  bindShapeDims(rhsShapedType, kwSize);
3563  cSize = fSize;
3564  break;
3565  }
3566  lhsShape = {nSize,
3567  // iw = ow * sw + kw * dw - 1
3568  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3569  // Perform the proper inclusive -> exclusive -> inclusive.
3570  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3571  1,
3572  cSize};
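 // Worked example (illustrative numbers): wSize = 16, kwSize = 3,
 // strideW = 2, dilationW = 2 gives iw = (15 * 2 + 1) + (2 * 2 + 1) - 1 = 35.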
3573  switch (oper) {
3574  case ConvOperationKind::Conv:
3575  rhsShape = {kwSize, cSize, fSize};
3576  break;
3577  case ConvOperationKind::Pool:
3578  rhsShape = {kwSize};
3579  break;
3580  }
3581  resShape = {nSize, wSize, fSize};
3582  break;
3583  case Conv1DOpOrder::Ncw:
3584  // out{n, f, w}
3585  bindShapeDims(resShapedType, nSize, fSize, wSize);
3586  switch (oper) {
3587  case ConvOperationKind::Conv:
3588  // kernel{f, c, kw}
3589  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3590  break;
3591  case ConvOperationKind::Pool:
3592  // kernel{kw}
3593  bindShapeDims(rhsShapedType, kwSize);
3594  cSize = fSize;
3595  break;
3596  }
3597  lhsShape = {nSize, cSize,
3598  // iw = ow * sw + kw * dw - 1
3599  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3600  // Perform the proper inclusive -> exclusive -> inclusive.
3601  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3602  1};
3603  switch (oper) {
3604  case ConvOperationKind::Conv:
3605  rhsShape = {fSize, cSize, kwSize};
3606  break;
3607  case ConvOperationKind::Pool:
3608  rhsShape = {kwSize};
3609  break;
3610  }
3611  resShape = {nSize, fSize, wSize};
3612  break;
3613  }
3614 
3615  vector::TransferWriteOp write;
3616  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3617 
3618  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3619  // When strideW == 1, we can batch the contiguous loads and avoid
3620  // unrolling
3621  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3622 
3623  Type lhsEltType = lhsShapedType.getElementType();
3624  Type rhsEltType = rhsShapedType.getElementType();
3625  Type resEltType = resShapedType.getElementType();
3626  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3627  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3628  auto resType = VectorType::get(resShape, resEltType);
3629  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3630  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3631  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3632  SmallVector<Value> resPadding(resShape.size(), zero);
3633 
3634  // Read the whole lhs, rhs and res in one shot (with zero padding).
3635  Value lhs = vector::TransferReadOp::create(
3636  rewriter, loc, lhsType, lhsShaped, lhsPadding,
3637  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3638  // This is needed only for Conv.
3639  Value rhs = nullptr;
3640  if (oper == ConvOperationKind::Conv)
3641  rhs = vector::TransferReadOp::create(
3642  rewriter, loc, rhsType, rhsShaped, rhsPadding,
3643  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3644  Value res = vector::TransferReadOp::create(
3645  rewriter, loc, resType, resShaped, resPadding,
3646  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3647 
3648  // The base vectorization case for channeled convolution is input:
3649  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
 3650  // vectorization case, we pre-transpose the input, weight, and output.
3651  switch (conv1DOpOrder) {
3652  case Conv1DOpOrder::W:
3653  case Conv1DOpOrder::Nwc:
3654  // Base case, so no transposes necessary.
3655  break;
3656  case Conv1DOpOrder::Ncw: {
3657  // To match base vectorization case, we pre-transpose current case.
3658  // ncw -> nwc
3659  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3660  lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3661  // fcw -> wcf
3662  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3663 
3664  // This is needed only for Conv.
3665  if (oper == ConvOperationKind::Conv)
3666  rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3667  // nfw -> nwf
3668  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3669  res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3670  break;
3671  }
3672  }
3673 
3674  //===------------------------------------------------------------------===//
3675  // Begin vector-only rewrite part
3676  //===------------------------------------------------------------------===//
3677  // Unroll along kw and read slices of lhs and rhs.
3678  SmallVector<Value> lhsVals, rhsVals, resVals;
3679  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3680  kwSize, strideW, dilationW, wSizeStep,
3681  isSingleChanneled);
3682  // Do not do for pooling.
3683  if (oper == ConvOperationKind::Conv)
3684  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3685  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3686  wSizeStep, isSingleChanneled);
3687 
3688  auto linearIndex = [&](int64_t kw, int64_t w) {
3689  return kw * (wSize / wSizeStep) + w;
3690  };
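 // Illustrative: with wSize = 4 and wSizeStep = 1 (i.e. strideW > 1),
 // (kw = 1, w = 2) maps to slice index 1 * 4 + 2 = 6; with wSizeStep == wSize
 // (strideW == 1) only w = 0 occurs and the index is simply kw.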
3691 
3692  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3693  // or perform outerproduct for non-channeled convolution or perform simple
3694  // arith operation for pooling
3695  for (int64_t kw = 0; kw < kwSize; ++kw) {
3696  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3697  switch (oper) {
3698  case ConvOperationKind::Conv:
3699  if (isSingleChanneled) {
3700  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3701  lhsVals[linearIndex(kw, w)],
3702  rhsVals[kw], resVals[w]);
3703  } else {
3704  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3705  lhsVals[linearIndex(kw, w)],
3706  rhsVals[kw], resVals[w]);
3707  }
3708  break;
3709  case ConvOperationKind::Pool:
3710  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3711  resVals[w]);
3712  break;
3713  }
3714  }
3715  }
3716 
3717  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3718  isSingleChanneled);
3719  //===------------------------------------------------------------------===//
3720  // End vector-only rewrite part
3721  //===------------------------------------------------------------------===//
3722 
 3723  // The base vectorization case for channeled convolution has output
 3724  // {n,w,f}. To reuse the result from the base pattern vectorization case,
 3725  // we post-transpose the base case result.
3726  switch (conv1DOpOrder) {
3727  case Conv1DOpOrder::W:
3728  case Conv1DOpOrder::Nwc:
3729  // Base case, so no transposes necessary.
3730  break;
3731  case Conv1DOpOrder::Ncw: {
3732  // nwf -> nfw
3733  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3734  res = vector::TransposeOp::create(rewriter, loc, res, perm);
3735  break;
3736  }
3737  }
3738 
3739  return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3740  resPadding)
3741  .getOperation();
3742  }
3743 
3744  // Take a value and widen to have the same element type as `ty`.
3745  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3746  const Type srcElementType = getElementTypeOrSelf(val.getType());
3747  const Type dstElementType = getElementTypeOrSelf(ty);
3748  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3749  if (srcElementType == dstElementType)
3750  return val;
3751 
3752  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3753  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3754  const Type dstType =
3755  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3756 
3757  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3758  return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3759  }
3760 
3761  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3762  srcWidth < dstWidth)
3763  return arith::ExtFOp::create(rewriter, loc, dstType, val);
3764 
3765  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3766  srcWidth < dstWidth)
3767  return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3768 
3769  assert(false && "unhandled promotion case");
3770  return nullptr;
3771  }
3772 
3773  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3774  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3775  Value lhs, Value rhs, Value res) {
3776  vector::IteratorType par = vector::IteratorType::parallel;
3777  vector::IteratorType red = vector::IteratorType::reduction;
3778  AffineExpr n, w, f, c;
3779  bindDims(ctx, n, w, f, c);
3780  lhs = promote(rewriter, loc, lhs, res.getType());
3781  rhs = promote(rewriter, loc, rhs, res.getType());
3782  auto contrationOp = vector::ContractionOp::create(
3783  rewriter, loc, lhs, rhs, res,
3784  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3785  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3786  contrationOp.setKind(reductionKind);
3787  return contrationOp;
3788  }
3789 
3790  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3791  // convolution.
3792  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3793  Value lhs, Value rhs, Value res) {
3794  return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3795  rhs, res, vector::CombiningKind::ADD);
3796  }
3797 
3798  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3799  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3800  Value res) {
3801  if (isPoolExt)
3802  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3803  return rewriter
3804  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3805  ->getResult(0);
3806  }
3807 
3808  /// Generate a vector implementation for:
3809  /// ```
3810  /// Op def: ( n, w, c, kw)
3811  /// Iters: ({Par(), Par(), Par(), Red()})
3812  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3813  /// ```
3814  /// kw is always unrolled.
 3815  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3816  /// > 1.
3817  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3818  bool channelDimScalableFlag,
3819  bool flatten) {
3820  bool scalableChDim = false;
3821  bool useMasking = false;
3822  int64_t nSize, wSize, cSize, kwSize;
3823  // kernel{kw, c}
3824  bindShapeDims(rhsShapedType, kwSize, cSize);
3825  if (ShapedType::isDynamic(cSize)) {
3826  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3827  cSize = channelDimVecSize;
3828  // Scalable vectors are only used when both conditions are met:
3829  // 1. channel dim is dynamic
3830  // 2. channelDimScalableFlag is set
3831  scalableChDim = channelDimScalableFlag;
3832  useMasking = true;
3833  }
3834 
3835  assert(!(useMasking && flatten) &&
3836  "Unsupported flattened conv with dynamic shapes");
3837 
3838  // out{n, w, c}
3839  bindShapeDims(resShapedType, nSize, wSize);
3840 
3841  vector::TransferWriteOp write;
3842  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3843 
3844  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3845  // When strideW == 1, we can batch the contiguous loads and avoid
3846  // unrolling
3847  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3848 
3849  Type lhsEltType = lhsShapedType.getElementType();
3850  Type rhsEltType = rhsShapedType.getElementType();
3851  Type resEltType = resShapedType.getElementType();
3852  VectorType lhsType = VectorType::get(
3853  {nSize,
3854  // iw = ow * sw + kw * dw - 1
3855  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3856  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3857  cSize},
3858  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3859  VectorType rhsType =
3860  VectorType::get({kwSize, cSize}, rhsEltType,
3861  /*scalableDims=*/{false, scalableChDim});
3862  VectorType resType =
3863  VectorType::get({nSize, wSize, cSize}, resEltType,
3864  /*scalableDims=*/{false, false, scalableChDim});
3865 
3866  // Masks the input xfer Op along the channel dim, iff the corresponding
3867  // scalable flag is set.
3868  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3869  ArrayRef<bool> scalableDims,
3870  Operation *opToMask) {
3871  if (!useMasking)
3872  return opToMask;
3873  auto maskType =
3874  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3875 
3876  SmallVector<bool> inBounds(maskShape.size(), true);
3877  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3878  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3879  rewriter.getBoolArrayAttr(inBounds));
3880 
 3881  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
 3882  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3883 
3884  Value maskOp =
3885  vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3886 
3887  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3888  };
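 // Illustrative sketch (assuming a dynamic, scalable channel dim; sizes are
 // made up): the masked lhs read ends up roughly as
 //   %mask = vector.create_mask %c1, %c14, %ch : vector<1x14x[4]xi1>
 //   %lhs  = vector.mask %mask { vector.transfer_read ... } ...
 // where %ch is the runtime channel size queried from the input.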
3889 
3890  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3891  // 0].
3892  Value lhs = vector::TransferReadOp::create(
3893  rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3894  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3895  auto maybeMaskedLhs = maybeMaskXferOp(
3896  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3897 
3898  // Read rhs slice of size {kw, c} @ [0, 0].
3899  Value rhs = vector::TransferReadOp::create(
3900  rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3901  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3902  auto maybeMaskedRhs = maybeMaskXferOp(
3903  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3904 
3905  // Read res slice of size {n, w, c} @ [0, 0, 0].
3906  Value res = vector::TransferReadOp::create(
3907  rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3908  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3909  auto maybeMaskedRes = maybeMaskXferOp(
3910  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3911 
3912  //===------------------------------------------------------------------===//
3913  // Begin vector-only rewrite part
3914  //===------------------------------------------------------------------===//
3915  // Unroll along kw and read slices of lhs and rhs.
3916  SmallVector<Value> lhsVals, rhsVals, resVals;
3917  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3918  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3919 
3920  // Extract lhs slice of size {n, wSizeStep, c}
3921  // @ [0, sw * w + dw * kw, 0].
3922  for (int64_t kw = 0; kw < kwSize; ++kw) {
3923  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3924  lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3925  rewriter, loc, maybeMaskedLhs->getResult(0),
3926  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3927  inOutSliceSizes, inOutStrides));
3928  }
3929  }
3930  // Extract rhs slice of size {c} @ [kw].
3931  for (int64_t kw = 0; kw < kwSize; ++kw) {
3932  rhsVals.push_back(
3933  vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3934  /*offsets=*/ArrayRef<int64_t>{kw}));
3935  }
3936  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3937  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3938  resVals.push_back(vector::ExtractStridedSliceOp::create(
3939  rewriter, loc, maybeMaskedRes->getResult(0),
3940  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3941  inOutStrides));
3942  }
3943 
3944  auto linearIndex = [&](int64_t kw, int64_t w) {
3945  return kw * (wSize / wSizeStep) + w;
3946  };
3947 
3948  // Note - the scalable flags are ignored as flattening combined with
3949  // scalable vectorization is not supported.
3950  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3951  auto lhsTypeAfterFlattening =
3952  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3953  auto resTypeAfterFlattening =
3954  VectorType::get(inOutFlattenSliceSizes, resEltType);
3955 
3956  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3957  for (int64_t kw = 0; kw < kwSize; ++kw) {
3958  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3959  Value lhsVal = lhsVals[linearIndex(kw, w)];
3960  Value resVal = resVals[w];
3961  if (flatten) {
3962  // Flatten the input and output vectors (collapse the channel
3963  // dimension)
3964  lhsVal =
3965  vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3966  lhsVals[linearIndex(kw, w)]);
3967  resVal = vector::ShapeCastOp::create(
3968  rewriter, loc, resTypeAfterFlattening, resVals[w]);
3969  }
3970  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3971  rhsVals[kw], resVal, flatten);
3972  if (flatten) {
3973  // Un-flatten the output vector (restore the channel dimension)
3974  resVals[w] = vector::ShapeCastOp::create(
3975  rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
3976  resVals[w]);
3977  }
3978  }
3979  }
3980 
 3981  // It's possible we failed to create the FMA.
3982  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3983  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3984  for (auto &collection :
3985  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3986  for (Value v : collection)
3987  rewriter.eraseOp(v.getDefiningOp());
3988  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3989  }
3990 
3991  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3992  // This does not depend on kw.
3993  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3994  maybeMaskedRes = vector::InsertStridedSliceOp::create(
3995  rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
3996  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3997  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3998  }
3999  //===------------------------------------------------------------------===//
4000  // End vector-only rewrite part
4001  //===------------------------------------------------------------------===//
4002 
4003  // Write back res slice of size {n, w, c} @ [0, 0, 0].
4004  Operation *resOut = vector::TransferWriteOp::create(
4005  rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4006  ValueRange{zero, zero, zero});
4007  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4008  resOut);
4009  }
4010 
4011  /// Lower:
4012  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4013  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4014  /// to MulAcc.
4015  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4016  Value lhs, Value rhs, Value res,
4017  bool flatten) {
4018  auto rhsTy = cast<ShapedType>(rhs.getType());
4019  auto resTy = cast<ShapedType>(res.getType());
4020 
4021  // TODO(suderman): Change this to use a vector.ima intrinsic.
4022  lhs = promote(rewriter, loc, lhs, resTy);
4023 
4024  if (flatten) {
 4025  // NOTE: The following logic won't work for scalable vectors. For this
4026  // reason, "flattening" is not supported when shapes are dynamic (this
4027  // should be captured by one of the pre-conditions).
4028 
4029  // There are two options for handling the filter:
4030  // * shape_cast(broadcast(filter))
4031  // * broadcast(shuffle(filter))
4032  // Opt for the option without shape_cast to simplify the codegen.
4033  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4034  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4035 
4036  SmallVector<int64_t, 16> indices;
4037  for (int i = 0; i < resSize / rhsSize; ++i) {
4038  for (int j = 0; j < rhsSize; ++j)
4039  indices.push_back(j);
4040  }
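 // Illustrative: with rhsSize = 4 and resSize = 8 this builds the indices
 // [0, 1, 2, 3, 0, 1, 2, 3], i.e. the filter is tiled across the flattened
 // w * c dimension.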
4041 
4042  rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4043  }
4044  // Broadcast the filter to match the output vector
4045  rhs = vector::BroadcastOp::create(rewriter, loc,
4046  resTy.clone(rhsTy.getElementType()), rhs);
4047 
4048  rhs = promote(rewriter, loc, rhs, resTy);
4049 
4050  if (!lhs || !rhs)
4051  return nullptr;
4052 
4053  if (isa<FloatType>(resTy.getElementType()))
4054  return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4055 
4056  auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4057  return arith::AddIOp::create(rewriter, loc, mul, res);
4058  }
4059 
4060  /// Entry point for non-channeled convolution:
4061  /// {{w + kw}, {kw}, {w}}
4062  FailureOr<Operation *> generateNonChanneledConv() {
4063  AffineExpr w, kw;
4064  bindDims(ctx, w, kw);
4065  if (!iters({Par(), Red()}))
4066  return rewriter.notifyMatchFailure(op,
4067  "failed to match conv::W 1-par 1-red");
4068 
4069  // No transposition needed.
4070  if (layout({/*lhsIndex*/ {w + kw},
4071  /*rhsIndex*/ {kw},
4072  /*resIndex*/ {w}}))
4073  return conv(Conv1DOpOrder::W);
4074 
4075  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4076  }
4077 
4078  /// Entry point that transposes into the common form:
4079  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4080  FailureOr<Operation *> generateNwcConv() {
4081  AffineExpr n, w, f, kw, c;
4082  bindDims(ctx, n, w, f, kw, c);
4083  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4084  return rewriter.notifyMatchFailure(
4085  op, "failed to match conv::Nwc 3-par 2-red");
4086 
4087  // No transposition needed.
4088  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4089  /*rhsIndex*/ {kw, c, f},
4090  /*resIndex*/ {n, w, f}}))
4091  return conv(Conv1DOpOrder::Nwc);
4092 
4093  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4094  }
4095 
4096  /// Entry point that transposes into the common form:
4097  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4098  FailureOr<Operation *> generateNcwConv() {
4099  AffineExpr n, w, f, kw, c;
4100  bindDims(ctx, n, f, w, c, kw);
4101  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4102  return rewriter.notifyMatchFailure(
4103  op, "failed to match conv::Ncw 3-par 2-red");
4104 
4105  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4106  /*rhsIndex*/ {f, c, kw},
4107  /*resIndex*/ {n, f, w}}))
4108  return conv(Conv1DOpOrder::Ncw);
4109 
4110  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4111  }
4112 
4113  /// Entry point that transposes into the common form:
4114  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4115  FailureOr<Operation *> generateNwcPooling() {
4116  AffineExpr n, w, c, kw;
4117  bindDims(ctx, n, w, c, kw);
4118  if (!iters({Par(), Par(), Par(), Red()}))
4119  return rewriter.notifyMatchFailure(op,
4120  "failed to match pooling 3-par 1-red");
4121 
4122  // No transposition needed.
4123  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4124  /*rhsIndex*/ {kw},
4125  /*resIndex*/ {n, w, c}}))
4126  return conv(Conv1DOpOrder::Nwc);
4127 
4128  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4129  }
4130 
4131  /// Entry point that transposes into the common form:
4132  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4133  FailureOr<Operation *> generateNcwPooling() {
4134  AffineExpr n, w, c, kw;
4135  bindDims(ctx, n, c, w, kw);
4136  if (!iters({Par(), Par(), Par(), Red()}))
4137  return rewriter.notifyMatchFailure(op,
4138  "failed to match pooling 3-par 1-red");
4139 
4140  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4141  /*rhsIndex*/ {kw},
4142  /*resIndex*/ {n, c, w}}))
4143  return conv(Conv1DOpOrder::Ncw);
4144 
4145  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4146  }
4147 
4148  /// Entry point that transposes into the common form:
4149  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4150  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4151  bool vecChDimScalableFlag = false,
4152  bool flatten = false) {
4153  AffineExpr n, w, c, kw;
4154  bindDims(ctx, n, w, c, kw);
4155  if (!iters({Par(), Par(), Par(), Red()}))
4156  return rewriter.notifyMatchFailure(
4157  op, "failed to match depthwise::Nwc conv 3-par 1-red");
4158 
4159  // No transposition needed.
4160  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4161  /*rhsIndex*/ {kw, c},
4162  /*resIndex*/ {n, w, c}}))
4163  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4164 
4165  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4166  }
4167 
4168 private:
4169  ConvOperationKind oper = ConvOperationKind::Conv;
4170  StringAttr redOp;
4171  StringAttr poolExtOp;
4172  bool isPoolExt = false;
4173  int strideW, dilationW;
4174  Value lhsShaped, rhsShaped, resShaped;
4175  ShapedType lhsShapedType, rhsShapedType, resShapedType;
4176  vector::CombiningKind reductionKind;
4177 
4178  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4179  void setConvOperationKind(Operation *reduceOp) {
4180  int numBlockArguments =
4181  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4182  if (numBlockArguments == 1) {
4183  // Will be convolution if feeder is a MulOp.
4184  // A strength reduced version of MulOp for i1 type is AndOp which is also
4185  // supported. Otherwise, it can be pooling. This strength reduction logic
4186  // is in `buildBinaryFn` helper in the Linalg dialect.
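 // Illustrative: for a convolution the reduction body looks roughly like
 //   %m = arith.mulf %in, %filter ; %r = arith.addf %out, %m
 // so the single non-block-argument feeder is the mul -> Conv. For a pooling
 // op with widening, the feeder is a cast of a block argument -> Pool.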
4187  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4188  llvm::IsaPred<BlockArgument>);
4189  Operation *feedOp = (*feedValIt).getDefiningOp();
4190  if (isCastOfBlockArgument(feedOp)) {
4191  oper = ConvOperationKind::Pool;
4192  isPoolExt = true;
4193  poolExtOp = feedOp->getName().getIdentifier();
4194  return;
4195  }
4196  oper = ConvOperationKind::Conv;
4197  return;
4198  }
 4199  // numBlockArguments == 2 and this is a pooling op.
4200  oper = ConvOperationKind::Pool;
4201  isPoolExt = false;
4202  }
4203 };
4204 } // namespace
4205 
4206 /// Helper function to vectorize a LinalgOp with convolution semantics.
4207 // TODO: extend the generic vectorization to support windows and drop this.
4208 static FailureOr<Operation *> vectorizeConvolution(
4209  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4210  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4211  Conv1DGenerator conv1dGen(rewriter, op);
4212  auto res = conv1dGen.generateNonChanneledConv();
4213  if (succeeded(res))
4214  return res;
4215  res = conv1dGen.generateNwcConv();
4216  if (succeeded(res))
4217  return res;
4218  res = conv1dGen.generateNcwConv();
4219  if (succeeded(res))
4220  return res;
4221  res = conv1dGen.generateNwcPooling();
4222  if (succeeded(res))
4223  return res;
4224  res = conv1dGen.generateNcwPooling();
4225  if (succeeded(res))
4226  return res;
4227 
4228  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4229  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4230  // masked/scalable) is the channel dim (i.e. the trailing dim).
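  // For example (hypothetical sizes, not from the source): for a
  // DepthwiseConv1DNwcWcOp vectorized with inputVecSizes = {1, 8, 4} and
  // inputScalableVecDims = {false, false, true}, the channel dim is index 2,
  // so vecChDimSize = 4 and vecChDimScalableFlag = true below (i.e. a
  // vector<1x8x[4]xf32>-style, scalable-in-C vectorization).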
4231  uint64_t vecChDimSize = ShapedType::kDynamic;
4232  bool vecChDimScalableFlag = false;
4233  if (!inputVecSizes.empty()) {
4234  // Only use the input vector size corresponding to the channel dim. Other
4235  // vector dims will be inferred from the Ops.
4236  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4237  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4238  "Not a 1D depthwise conv!");
4239  size_t chDimIdx =
4240  TypeSwitch<Operation *, size_t>(op)
4241  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4242  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4243 
4244  vecChDimSize = inputVecSizes[chDimIdx];
4245  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4246  }
4247  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4248  flatten1DDepthwiseConv);
4249 }
4250 
4251 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4252  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4253 
4254  LogicalResult matchAndRewrite(LinalgOp op,
4255  PatternRewriter &rewriter) const override {
4256  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4257  if (failed(resultOrFail))
4258  return failure();
4259  Operation *newOp = *resultOrFail;
4260  if (newOp->getNumResults() == 0) {
4261  rewriter.eraseOp(op.getOperation());
4262  return success();
4263  }
4264  assert(newOp->getNumResults() == 1 && "expected single result");
4265  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4266  return success();
4267  }
4268 };
4269 
4270 void mlir::linalg::populateConvolutionVectorizationPatterns(
4271  RewritePatternSet &patterns, PatternBenefit benefit) {
4272  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4273 }
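
// Illustrative usage sketch (not part of this file): populating the
// convolution vectorization patterns and applying them with a greedy pattern
// driver. `funcOp` is a hypothetical func::FuncOp inside a pass; the driver
// call is an assumption about typical usage, not something mandated here.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();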