Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Support/LLVM.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/DebugLog.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 /// Try to vectorize `convOp` as a convolution.
53 static FailureOr<Operation *>
54 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
55  ArrayRef<int64_t> inputVecSizes = {},
56  ArrayRef<bool> inputVecScalableFlags = {},
57  bool flatten1DDepthwiseConv = false);
58 
59 /// Vectorize tensor::InsertSliceOp with:
60 /// * vector::TransferReadOp + vector::TransferWriteOp
61 /// The vector sizes are either:
62 /// * user-provided in `inputVectorSizes`, or
63 /// * inferred from the static dims in the input and output tensors.
64 /// Bails out if:
65 /// * vector sizes are not user-provided, and
66 /// * at least one dim is dynamic (in both the input and output tensors).
67 ///
68 /// Before:
69 /// !t_in_type = tensor<1x2x3xf32>
70 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
71 /// !v_type = vector<1x2x3xf32>
72 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
73 /// into !t_out_type
74 /// After:
75 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
76 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
77 static LogicalResult
78 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
79  ArrayRef<int64_t> inputVectorSizes,
80  SmallVectorImpl<Value> &newResults);
81 
82 /// Returns the effective Pad value for the input op, provided it's a scalar.
83 ///
84 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
85 /// this Op performs padding, retrieve the padding value provided that it's
86 /// a scalar and static/fixed for all the padded values. Returns an empty value
87 /// otherwise.
88 static Value getStaticPadVal(Operation *op);
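//
// For illustration only (a sketch, not part of the original file): for a
// `tensor.pad` whose region yields a static scalar constant, the effective
// pad value is that constant (%cst below):
//
//   %cst = arith.constant 0.0 : f32
//   %0 = tensor.pad %src low[0, 1] high[0, 1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x6xf32>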
89 
90 /// Return the unique instance of OpType in `block` if it is indeed unique.
91 /// Return null if none or more than 1 instances exist.
92 template <typename OpType>
93 static OpType getSingleOpOfType(Block &block) {
94  OpType res;
95  block.walk([&](OpType op) {
96  if (res) {
97  res = nullptr;
98  return WalkResult::interrupt();
99  }
100  res = op;
101  return WalkResult::advance();
102  });
103  return res;
104 }
105 
106 /// Helper function to extract the input slices after filter is unrolled along
107 /// kw.
108 static SmallVector<Value>
109 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
110                        int64_t nSize, int64_t wSize, int64_t cSize,
111  int64_t kwSize, int strideW, int dilationW,
112  int64_t wSizeStep, bool isSingleChanneled) {
113  SmallVector<Value> result;
114  if (isSingleChanneled) {
115  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
116  // convolution.
117  SmallVector<int64_t> sizes = {wSizeStep};
118  SmallVector<int64_t> strides = {1};
119  for (int64_t kw = 0; kw < kwSize; ++kw) {
120  for (int64_t w = 0; w < wSize; w += wSizeStep) {
121  result.push_back(vector::ExtractStridedSliceOp::create(
122  rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
123  strides));
124  }
125  }
126  } else {
127  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128  // for channeled convolution.
129  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130  SmallVector<int64_t> strides = {1, 1, 1};
131  for (int64_t kw = 0; kw < kwSize; ++kw) {
132  for (int64_t w = 0; w < wSize; w += wSizeStep) {
133  result.push_back(vector::ExtractStridedSliceOp::create(
134  rewriter, loc, input,
135  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136  sizes, strides));
137  }
138  }
139  }
140  return result;
141 }
142 
143 /// Helper function to extract the filter slices after filter is unrolled along
144 /// kw.
145 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
146                                                   Location loc, Value filter,
147  int64_t kwSize) {
148  SmallVector<Value> result;
149  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150  // non-chanelled convolution] @ [kw].
151  for (int64_t kw = 0; kw < kwSize; ++kw) {
152  result.push_back(vector::ExtractOp::create(
153  rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154  }
155  return result;
156 }
157 
158 /// Helper function to extract the result slices after filter is unrolled along
159 /// kw.
160 static SmallVector<Value>
161 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
162                         int64_t nSize, int64_t wSize, int64_t fSize,
163  int64_t wSizeStep, bool isSingleChanneled) {
164  SmallVector<Value> result;
165  if (isSingleChanneled) {
166  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167  SmallVector<int64_t> sizes = {wSizeStep};
168  SmallVector<int64_t> strides = {1};
169  for (int64_t w = 0; w < wSize; w += wSizeStep) {
170  result.push_back(vector::ExtractStridedSliceOp::create(
171  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
172  strides));
173  }
174  } else {
175  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
176  // convolution.
177  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
178  SmallVector<int64_t> strides = {1, 1, 1};
179  for (int64_t w = 0; w < wSize; w += wSizeStep) {
180  result.push_back(vector::ExtractStridedSliceOp::create(
181  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
182  strides));
183  }
184  }
185  return result;
186 }
187 
188 /// Helper function to insert the computed result slices.
189 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190                                     Value res, int64_t wSize, int64_t wSizeStep,
191  SmallVectorImpl<Value> &resVals,
192  bool isSingleChanneled) {
193 
194  if (isSingleChanneled) {
195  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196  // This does not depend on kw.
197  SmallVector<int64_t> strides = {1};
198  for (int64_t w = 0; w < wSize; w += wSizeStep) {
199  res = vector::InsertStridedSliceOp::create(
200  rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
201  strides);
202  }
203  } else {
204  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205  // convolution. This does not depend on kw.
206  SmallVector<int64_t> strides = {1, 1, 1};
207  for (int64_t w = 0; w < wSize; w += wSizeStep) {
208  res = vector::InsertStridedSliceOp::create(
209  rewriter, loc, resVals[w], res,
210  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
211  }
212  }
213  return res;
214 }
215 
216 /// Contains the vectorization state and related methods used across the
217 /// vectorization process of a given operation.
218 struct VectorizationState {
219   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220 
221  /// Initializes the vectorization state, including the computation of the
222  /// canonical vector shape for vectorization.
223  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224  ArrayRef<int64_t> inputVectorSizes,
225  ArrayRef<bool> inputScalableVecDims,
226  bool assumeDynamicDimsMatchVecSizes = false);
227 
228  /// Returns the canonical vector shape used to vectorize the iteration space.
229  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
230 
231  /// Returns the vector dimensions that are scalable in the canonical vector
232  /// shape.
233  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
234 
235  /// Returns a vector type of the provided `elementType` with the canonical
236  /// vector shape and the corresponding fixed/scalable dimensions bit. If
237  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
238  /// accordingly.
239  VectorType getCanonicalVecType(
240      Type elementType,
241      std::optional<AffineMap> dimPermutation = std::nullopt) const {
242    SmallVector<int64_t> vectorShape;
243    SmallVector<bool> scalableDims;
244  if (dimPermutation.has_value()) {
245  vectorShape =
246  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
247  scalableDims =
248  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
249  } else {
250  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
252  }
253 
254  return VectorType::get(vectorShape, elementType, scalableDims);
255  }
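
  // Example (illustrative sketch, not from the original file): with a
  // canonical vector shape of [8, 16, 4] and a permutation map
  // (d0, d1, d2) -> (d2, d0), getCanonicalVecType(f32, map) yields
  // vector<4x8xf32>, with the scalable-dim flags permuted the same way.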
256 
257  /// Masks an operation with the canonical vector mask if the operation needs
258  /// masking. Returns the masked operation or the original operation if masking
259  /// is not needed. If provided, the canonical mask for this operation is
260  /// permuted using `maybeIndexingMap`.
261  Operation *
262  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
263  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264 
265 private:
266  /// Initializes the iteration space static sizes using the Linalg op
267  /// information. This may become more complicated in the future.
268  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
270  }
271 
272  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
273  /// all the static and dynamic dimensions of the iteration space to be
274  /// vectorized and store them in `iterSpaceValueSizes`.
275  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276  LinalgOp linalgOp);
277 
278  /// Create or retrieve an existing mask value to mask `opToMask` in the
279  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
280  /// permuted using that permutation map. If a new mask is created, it will be
281  /// cached for future users.
282  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
283  LinalgOp linalgOp,
284  std::optional<AffineMap> maybeMaskingMap);
285 
286  /// Check whether this permutation map can be used for masking. At the
287  /// moment we only make sure that there are no broadcast dimensions, but this
288  /// might change if indexing maps evolve.
289  bool isValidMaskingMap(AffineMap maskingMap) {
290  return maskingMap.getBroadcastDims().size() == 0;
291  }
292 
293  /// Turn the input indexing map into a valid masking map.
294  ///
295  /// The input indexing map may contain "zero" results, e.g.:
296  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
297  /// Applying such maps to canonical vector shapes like this one:
298  /// (1, 16, 16, 4)
299  /// would yield an invalid vector shape like this:
300  /// (16, 16, 1, 0)
301  /// Instead, drop the broadcasting dims that make no sense for masking perm.
302  /// maps:
303  /// (d0, d1, d2, d3) -> (d2, d1, d0)
304  /// This way, the corresponding vector/mask type will be:
305  /// vector<16x16x1xty>
306  /// rather than this invalid Vector type:
307  /// vector<16x16x1x0xty>
308  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
309  return indexingMap.dropZeroResults();
310  }
311 
312  // Holds the compile-time static sizes of the iteration space to vectorize.
313  // Dynamic dimensions are represented using ShapedType::kDynamic.
314  SmallVector<int64_t> iterSpaceStaticSizes;
315 
316  /// Holds the value sizes of the iteration space to vectorize. Static
317  /// dimensions are represented by 'arith.constant' and dynamic
318  /// dimensions by 'tensor/memref.dim'.
319  SmallVector<Value> iterSpaceValueSizes;
320 
321  /// Holds the canonical vector shape used to vectorize the iteration space.
322  SmallVector<int64_t> canonicalVecShape;
323 
324  /// Holds the vector dimensions that are scalable in the canonical vector
325  /// shape.
326  SmallVector<bool> scalableVecDims;
327 
328  /// Holds the active masks for permutations of the canonical vector iteration
329  /// space.
330  DenseMap<AffineMap, Value> activeMaskCache;
331 
332  /// Global vectorization guard for the incoming rewriter. It's initialized
333  /// when the vectorization state is initialized.
334  OpBuilder::InsertionGuard rewriterGuard;
335 
336  /// Do all dynamic dims match the corresponding vector sizes?
337  ///
338  /// When a dynamic tensor/memref dimension matches the corresponding vector
339  /// dimension, masking can be safely skipped, despite the presence of dynamic
340  /// shapes. Use this flag with care and only for cases where you are
341  /// confident the assumption holds.
342  bool assumeDynamicDimsMatchVecSizes = false;
343 };
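
// Typical usage of VectorizationState (a minimal sketch for exposition; the
// actual driver is `vectorizeAsLinalgGeneric` further below, and `elemType`,
// `indexingMap` and `xferOp` are placeholders):
//
//   VectorizationState state(rewriter);
//   if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
//                              inputScalableVecDims)))
//     return failure();
//   VectorType vecType = state.getCanonicalVecType(elemType, indexingMap);
//   Operation *maybeMasked =
//       state.maskOperation(rewriter, xferOp, linalgOp, indexingMap);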
344 
345 LogicalResult
346 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
347  LinalgOp linalgOp) {
348  // TODO: Support 0-d vectors.
349  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350  if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
351  // Create constant index op for static dimensions.
352  iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
353  rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
354  continue;
355  }
356 
357  // Find an operand defined on this dimension of the iteration space to
358  // extract the runtime dimension size.
359  Value operand;
360  unsigned operandDimPos;
361  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
362  operandDimPos)))
363  return failure();
364 
365  Value dynamicDim =
366  linalgOp.hasPureTensorSemantics()
367  ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
368  operandDimPos)
369  : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370  operandDimPos);
371  iterSpaceValueSizes.push_back(dynamicDim);
372  }
373 
374  return success();
375 }
376 
377 /// Initializes the vectorization state, including the computation of the
378 /// canonical vector shape for vectorization.
379 // TODO: Move this to the constructor when we can remove the failure cases.
380 LogicalResult VectorizationState::initState(RewriterBase &rewriter,
381                                             LinalgOp linalgOp,
382  ArrayRef<int64_t> inputVectorSizes,
383  ArrayRef<bool> inputScalableVecDims,
384  bool assumeDimsMatchVec) {
385  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
386  // Initialize the insertion point.
387  rewriter.setInsertionPoint(linalgOp);
388 
389  if (!inputVectorSizes.empty()) {
390  // Get the canonical vector shape from the input vector sizes provided. This
391  // path should be taken to vectorize code with dynamic shapes and when using
392  // vector sizes greater than the iteration space sizes.
393  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394  scalableVecDims.append(inputScalableVecDims.begin(),
395  inputScalableVecDims.end());
396  } else {
397  // Compute the canonical vector shape from the operation shape. If there are
398  // dynamic shapes, the operation won't be vectorized. We assume all the
399  // vector dimensions are fixed.
400  canonicalVecShape = linalgOp.getStaticLoopRanges();
401  scalableVecDims.append(linalgOp.getNumLoops(), false);
402  }
403 
404  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
406 
407  if (ShapedType::isDynamicShape(canonicalVecShape))
408  return failure();
409 
410  // Initialize iteration space static sizes.
411  initIterSpaceStaticSizes(linalgOp);
412 
413  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
414  // all the static and dynamic dimensions of the iteration space, needed to
415  // compute a mask during vectorization.
416  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
417  return failure();
418 
419  return success();
420 }
421 
422 /// Create or retrieve an existing mask value to mask `opToMask` in the
423 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
424 /// using that permutation map. If a new mask is created, it will be cached for
425 /// future users.
426 Value VectorizationState::getOrCreateMaskFor(
427  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
428  std::optional<AffineMap> maybeMaskingMap) {
429 
430  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431  "Ill-formed masking map.");
432 
433  // No mask is needed if the operation is not maskable.
434  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
435  if (!maskableOp)
436  return Value();
437 
438  assert(!maskableOp.isMasked() &&
439  "Masking an operation that is already masked");
440 
441  // If no masking map was provided, use an identity map with the loop dims.
442  assert((!maybeMaskingMap || *maybeMaskingMap) &&
443  "Unexpected null mask permutation map");
444  AffineMap maskingMap =
445      maybeMaskingMap ? *maybeMaskingMap
446                      : AffineMap::getMultiDimIdentityMap(
447                            linalgOp.getNumLoops(), rewriter.getContext());
448 
449  LDBG() << "Masking map: " << maskingMap;
450 
451  // Return the active mask for the masking map of this operation if it was
452  // already created.
453  auto activeMaskIt = activeMaskCache.find(maskingMap);
454  if (activeMaskIt != activeMaskCache.end()) {
455  Value mask = activeMaskIt->second;
456  LDBG() << "Reusing mask: " << mask;
457  return mask;
458  }
459 
460  // Compute permuted projection of the iteration space to be masked and the
461  // corresponding mask shape. If the resulting iteration space dimensions are
462  // static and identical to the mask shape, masking is not needed for this
463  // operation.
464  // TODO: Improve this check. Only projected permutation indexing maps are
465  // supported.
466  SmallVector<int64_t> permutedStaticSizes =
467  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
469  auto maskShape = maskType.getShape();
470 
471  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
472 
473  if (permutedStaticSizes == maskShape) {
474  LDBG() << "Masking is not needed for masking map: " << maskingMap;
475  activeMaskCache[maskingMap] = Value();
476  return Value();
477  }
478 
479  if (assumeDynamicDimsMatchVecSizes) {
480  // While for _dynamic_ dim sizes we can _assume_ that the corresponding
481  // vector sizes match, we still need to check the _static_ dim sizes. Only
482  // then we can be 100% sure that masking is not required.
483  if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
484  [](auto it) {
485  return std::get<0>(it) == ShapedType::kDynamic
486  ? true
487  : std::get<0>(it) == std::get<1>(it);
488  })) {
489  LDBG()
490  << "Dynamic + static dimensions match vector sizes, masking is not "
491  "required.";
492  activeMaskCache[maskingMap] = Value();
493  return Value();
494  }
495  }
496 
497  // Permute the iteration space value sizes to compute the mask upper bounds.
498  SmallVector<Value> upperBounds =
499  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
500  assert(!maskShape.empty() && !upperBounds.empty() &&
501  "Masked 0-d vectors are not supported yet");
502 
503  // Create the mask based on the dimension values.
504  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505  maskType, upperBounds);
506  LDBG() << "Creating new mask: " << mask;
507  activeMaskCache[maskingMap] = mask;
508  return mask;
509 }
510 
511 Operation *
512 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
513                                   LinalgOp linalgOp,
514  std::optional<AffineMap> maybeIndexingMap) {
515  LDBG() << "Trying to mask: " << *opToMask;
516 
517  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518  if (maybeIndexingMap)
519  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
520 
521  // Create or retrieve mask for this operation.
522  Value mask =
523  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
524 
525  if (!mask) {
526  LDBG() << "No mask required";
527  return opToMask;
528  }
529 
530  // Wrap the operation with a new `vector.mask` and update D-U chain.
531  assert(opToMask && "Expected a valid operation to mask");
532  auto maskOp = cast<vector::MaskOp>(
533  mlir::vector::maskOperation(rewriter, opToMask, mask));
534  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
535 
536  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
537  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
538  maskOpTerminator);
539 
540  LDBG() << "Masked operation: " << *maskOp;
541  return maskOp;
542 }
543 
544 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
545 /// projectedPermutation, compress the unused dimensions to serve as a
546 /// permutation_map for a vector transfer operation.
547 /// For example, given a linalg op such as:
548 ///
549 /// ```
550 /// %0 = linalg.generic {
551 ///   indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
552 ///                    affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
553 /// }
554 /// ins(%0 : tensor<2x3x4xf32>)
555 /// outs(%1 : tensor<5x6xf32>)
556 /// ```
557 ///
558 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
559 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
560 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
561 static AffineMap reindexIndexingMap(AffineMap map) {
562   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
563  "expected projected permutation");
564  auto res = compressUnusedDims(map);
565  assert(res.getNumDims() ==
566  (res.getNumResults() - res.getNumOfZeroResults()) &&
567  "expected reindexed map with same number of dims and results");
568  return res;
569 }
570 
571 /// Helper enum to represent conv1d input traversal order.
572 enum class Conv1DOpOrder {
573  W, // Corresponds to non-channeled 1D convolution operation.
574  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
575  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
576 };
577 
578 /// Helper data structure to represent the result of vectorization for a single
579 /// operation. In certain specific cases, like terminators, we do not want to
580 /// propagate.
581 enum VectorizationHookStatus {
582   /// Op failed to vectorize.
583   Failure = 0,
584   /// Op vectorized and custom function took care of replacement logic
585   NoReplace,
586   /// Op vectorized into a new Op whose results will replace original Op's
587  /// results.
588  NewOp
589  // TODO: support values if Op vectorized to Many-Ops whose results we need to
590  // aggregate for replacement.
591 };
592 /// VectorizationHookResult contains the vectorized op returned from a
593 /// CustomVectorizationHook. This is an internal implementation detail of
594 /// linalg vectorization, not to be confused with VectorizationResult.
595 struct VectorizationHookResult {
596   /// Return status from vectorizing the current op.
597   enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
598   /// New vectorized operation to replace the current op.
599   /// Replacement behavior is specified by `status`.
600   Operation *newOp;
601 };
602 
603 std::optional<vector::CombiningKind>
604 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
605   using ::mlir::vector::CombiningKind;
606 
607  if (!combinerOp)
608     return std::nullopt;
609   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
610       .Case<arith::AddIOp, arith::AddFOp>(
611  [&](auto op) { return CombiningKind::ADD; })
612  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
613  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
614  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
615  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
616  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
617  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
618  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
619  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
620  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
621  .Case<arith::MulIOp, arith::MulFOp>(
622  [&](auto op) { return CombiningKind::MUL; })
623  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
624  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
625  .Default([&](auto op) { return std::nullopt; });
626 }
627 
628 /// Check whether `outputOperand` is a reduction with a single combiner
629 /// operation. Return the combiner operation of the reduction. Return
630 /// nullptr otherwise. Multiple reduction operations would impose an
631 /// ordering between reduction dimensions and is currently unsupported in
632 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
633 /// max(min(X))
634 // TODO: use in LinalgOp verification, there is a circular dependency atm.
635 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
636  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
637  unsigned outputPos =
638  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
639  // Only single combiner operations are supported for now.
640  SmallVector<Operation *, 4> combinerOps;
641  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
642  combinerOps.size() != 1)
643  return nullptr;
644 
645  // Return the combiner operation.
646  return combinerOps[0];
647 }
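
// For example (sketch, hypothetical IR), in the body of a sum-reduction
// linalg.generic:
//
//   ^bb0(%in: f32, %acc: f32):
//     %0 = arith.addf %in, %acc : f32
//     linalg.yield %0 : f32
//
// matchLinalgReduction on the corresponding init operand returns the
// arith.addf, which getCombinerOpKind then maps to CombiningKind::ADD.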
648 
649 /// Broadcast `value` to a vector of `shape` if possible. Return value
650 /// otherwise.
651 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
652  auto dstVecType = dyn_cast<VectorType>(dstType);
653  // If no shape to broadcast to, just return `value`.
654  if (dstVecType.getRank() == 0)
655  return value;
656   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
657       vector::BroadcastableToResult::Success)
658     return value;
659  Location loc = b.getInsertionPoint()->getLoc();
660  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
661 }
662 
663 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
664 /// assumes that `reductionOp` has two operands and one of them is the reduction
665 /// initial value.
666 // Note: this is a true builder that notifies the OpBuilder listener.
667 // TODO: Consider moving as a static helper on the ReduceOp.
668 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
669                                       Value valueToReduce, Value acc,
670  ArrayRef<bool> dimsToMask) {
671  auto maybeKind = getCombinerOpKind(reduceOp);
672  assert(maybeKind && "Failed precondition: could not get reduction kind");
673  return vector::MultiDimReductionOp::create(
674  b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
675 }
676 
677 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
678  return llvm::to_vector(
679  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
680 }
681 
682 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
683 /// reduction iterator.
684 static bool hasReductionIterator(LinalgOp &op) {
685  return isa<linalg::ReduceOp>(op) ||
686  (isa<linalg::GenericOp>(op) &&
687  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
688 }
689 
690 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
691 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
692 /// currently being vectorized. If the destination has rank 0, a 0-d transfer write is built.
693 /// Return the produced value or null if no value is produced.
694 // Note: this is a true builder that notifies the OpBuilder listener.
695 // TODO: Consider moving as a static helper on the ReduceOp.
696 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
697  OpOperand *outputOperand,
698  VectorizationState &state) {
699  Location loc = value.getLoc();
700  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
701  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
702 
703  // Compute the vector type of the value to store. This type should be an
704  // identity or projection of the canonical vector type without any permutation
705  // applied, given that any permutation in a transfer write happens as part of
706   // the write itself.
707   auto vectorTypeMap = AffineMap::getFilteredIdentityMap(
708       opOperandMap.getContext(), opOperandMap.getNumInputs(),
709  [&](AffineDimExpr dimExpr) -> bool {
710  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
711  });
712  auto vectorType = state.getCanonicalVecType(
713  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
714 
715  Operation *write;
716  if (vectorType.getRank() > 0) {
717  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
718  SmallVector<Value> indices(
719  linalgOp.getRank(outputOperand),
720  arith::ConstantIndexOp::create(rewriter, loc, 0));
721  value = broadcastIfNeeded(rewriter, value, vectorType);
722  assert(value.getType() == vectorType && "Incorrect type");
723  write = vector::TransferWriteOp::create(
724  rewriter, loc, value, outputOperand->get(), indices, writeMap);
725  } else {
726  // 0-d case is still special: do not invert the reindexing writeMap.
727  if (!isa<VectorType>(value.getType()))
728  value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
729  assert(value.getType() == vectorType && "Incorrect type");
730  write = vector::TransferWriteOp::create(rewriter, loc, value,
731  outputOperand->get(), ValueRange{});
732  }
733 
734  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
735 
736  // If masked, set in-bounds to true. Masking guarantees that the access will
737  // be in-bounds.
738  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
739  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
740  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
741  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
742  }
743 
744  LDBG() << "vectorized op: " << *write;
745  if (!write->getResults().empty())
746  return write->getResult(0);
747  return Value();
748 }
749 
750 // Custom vectorization precondition function type. This is intended to be used
751 // with CustomVectorizationHook. Returns success if the corresponding custom
752 // hook can vectorize the op.
753 using CustomVectorizationPrecondition =
754     std::function<LogicalResult(Operation *, bool)>;
755 
756 // Custom vectorization function type. Produce a vector form of Operation*
757 // assuming all its vectorized operands are already in the IRMapping.
758 // Return nullptr if the Operation cannot be vectorized.
759 using CustomVectorizationHook =
760     std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
761 
762 /// Helper function to vectorize the terminator of a `linalgOp`. New result
763 /// vector values are appended to `newResults`. Return
764 /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
765 /// that it should not try to map produced operations and instead return the
766 /// results using the `newResults` vector making them available to the
767 /// vectorization algorithm for RAUW. This function is meant to be used as a
768 /// CustomVectorizationHook.
769 static VectorizationHookResult
770 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
771                      const IRMapping &bvm, VectorizationState &state,
772  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
773  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
774   if (!yieldOp)
775     return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
776   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
777  // TODO: Scan for an opportunity for reuse.
778  // TODO: use a map.
779  Value vectorValue = bvm.lookup(output.value());
780  Value newResult =
781  buildVectorWrite(rewriter, vectorValue,
782  linalgOp.getDpsInitOperand(output.index()), state);
783  if (newResult)
784  newResults.push_back(newResult);
785  }
786 
787   return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
788 }
789 
790 /// Helper function to vectorize the index operations of a `linalgOp`. Return
791 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
792 /// should map the produced operations. This function is meant to be used as a
793 /// CustomVectorizationHook.
794 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
795                                                     VectorizationState &state,
796  Operation *op,
797  LinalgOp linalgOp) {
798  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
799   if (!indexOp)
800     return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
801   auto loc = indexOp.getLoc();
802  // Compute the static loop sizes of the index op.
803  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
804  auto dim = indexOp.getDim();
805  // Compute a one-dimensional index vector for the index op dimension.
806  auto indexVectorType =
807  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
808  state.getScalableVecDims()[dim]);
809  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
810  // Return the one-dimensional index vector if it lives in the trailing
811  // dimension of the iteration space since the vectorization algorithm in this
812  // case can handle the broadcast.
813   if (dim == targetShape.size() - 1)
814     return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
815   // Otherwise permute the targetShape to move the index dimension last,
816  // broadcast the one-dimensional index vector to the permuted shape, and
817  // finally transpose the broadcasted index vector to undo the permutation.
818  auto permPattern =
819  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
820  std::swap(permPattern[dim], permPattern.back());
821  auto permMap =
822  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
823 
824  auto broadCastOp = vector::BroadcastOp::create(
825  rewriter, loc,
826  state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
827  SmallVector<int64_t> transposition =
828  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
829  std::swap(transposition.back(), transposition[dim]);
830  auto transposeOp =
831       vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
832   return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
833 }
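
// Illustrative sketch (hypothetical shapes, not from the original file): for a
// canonical vector shape [2, 3, 4] and `linalg.index 0`, the helper above
// produces roughly:
//
//   %step  = vector.step : vector<2xindex>
//   %bcast = vector.broadcast %step : vector<2xindex> to vector<4x3x2xindex>
//   %tpose = vector.transpose %bcast, [2, 1, 0]
//              : vector<4x3x2xindex> to vector<2x3x4xindex>
//
// so that dimension 0 of the result carries the index values.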
834 
835 /// Helper function to check if the tensor.extract can be vectorized by the
836 /// custom hook vectorizeTensorExtract.
837 static LogicalResult
838 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
839  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
840  if (!extractOp)
841  return failure();
842 
843  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
844  return failure();
845 
846  // Check the index type, but only for non 0-d tensors (for which we do need
847  // access indices).
848  if (not extractOp.getIndices().empty()) {
849  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
850  return failure();
851  }
852 
853  if (!llvm::all_of(extractOp->getResultTypes(),
854  VectorType::isValidElementType)) {
855  return failure();
856  }
857 
858  return success();
859 }
860 
861 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
862 /// generated from `tensor.extract`. The offset is calculated as follows
863 /// (example using scalar values):
864 ///
865 /// offset = extractOp.indices[0]
866 /// for (i = 1; i < numIndices; i++)
867 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
868 ///
869 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
870 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
871 static Value calculateGatherOffset(RewriterBase &rewriter,
872                                    VectorizationState &state,
873  tensor::ExtractOp extractOp,
874  const IRMapping &bvm) {
875  // The vector of indices for GatherOp should be shaped as the output vector.
876  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
877  auto loc = extractOp.getLoc();
878 
879  Value offset = broadcastIfNeeded(
880  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
881 
882  const size_t numIndices = extractOp.getIndices().size();
883  for (size_t i = 1; i < numIndices; i++) {
884  Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
885 
886  auto dimSize = broadcastIfNeeded(
887  rewriter,
888  tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
889  indexVecType);
890 
891  offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
892 
893  auto extractOpIndex = broadcastIfNeeded(
894  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
895 
896  offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
897  }
898 
899  return offset;
900 }
901 
902 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
903 
904 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
905 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
906 /// represents a contiguous load operation.
907 ///
908 /// Note that when calling this hook, it is assumed that the output vector is
909 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
910 /// labelled as a gather load before entering this method.
911 ///
912 /// Following on from the above, it is assumed that:
913 /// * for statically shaped loops, when no masks are used, only one dim is !=
914 /// 1 (that's what the shape of the output vector is based on).
915 /// * for dynamically shaped loops, there might be more non-unit dims
916 /// as the output vector type is user-specified.
917 ///
918 /// TODO: Statically shaped loops + vector masking
919 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
920  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
921  assert(
922  (linalgOp.hasDynamicShape() ||
923  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
924  "For statically shaped Linalg Ops, only one "
925  "non-unit loop dim is expected");
926  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
927 
928  size_t idx = loopRanges.size() - 1;
929  for (; idx != 0; idx--)
930  if (loopRanges[idx] != 1)
931  break;
932 
933  return idx;
934 }
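
// For example (sketch): for static loop ranges [1, 16, 1] the trailing
// non-unit loop dim index is 1; for all-unit ranges the scan stops at index 0.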
935 
936 /// Checks whether `val` can be used for calculating a loop invariant index.
937 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
938  VectorType resType) {
939 
940  assert(((llvm::count_if(resType.getShape(),
941  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942  "n-D vectors are not yet supported");
943 
944  // Blocks outside _this_ linalg.generic are effectively loop invariant.
945  // However, analysing block arguments for _this_ linalg.generic Op is a bit
946  // tricky. Just bail out in the latter case.
947  // TODO: We could try analysing the corresponding affine map here.
948  auto *block = linalgOp.getBlock();
949  if (isa<BlockArgument>(val))
950  return !llvm::is_contained(block->getArguments(), val);
951 
952  Operation *defOp = val.getDefiningOp();
953  assert(defOp && "This is neither a block argument nor an operation result");
954 
955  // IndexOp is loop invariant as long as its result remains constant across
956  // iterations. Note that for dynamic shapes, the corresponding dim will also
957  // be conservatively treated as != 1.
958  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
959  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
960  }
961 
962  auto *ancestor = block->findAncestorOpInBlock(*defOp);
963 
964   // Values defined outside `linalgOp` are loop invariant.
965  if (!ancestor)
966  return true;
967 
968  // Values defined inside `linalgOp`, which are constant, are loop invariant.
969  if (isa<arith::ConstantOp>(ancestor))
970  return true;
971 
972  bool result = true;
973  for (auto op : ancestor->getOperands())
974  result &= isLoopInvariantIdx(linalgOp, op, resType);
975 
976  return result;
977 }
978 
979 /// Check whether `val` could be used for calculating the trailing index for a
980 /// contiguous load operation.
981 ///
982 /// There are currently 3 types of values that are allowed here:
983 /// 1. loop-invariant values,
984 /// 2. values that increment by 1 with every loop iteration,
985 /// 3. results of basic arithmetic operations (linear and continuous)
986 /// involving 1., 2. and 3.
987 /// This method returns True if indeed only such values are used in calculating
988 /// `val.`
989 ///
990 /// Additionally, the trailing index for a contiguous load operation should
991 /// increment by 1 with every loop iteration, i.e. be based on:
992 /// * `linalg.index <dim>` ,
993 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
994 /// `linalg.index <dim>` increments by 1 with every loop iteration).
995 /// `foundIndexOp` is updated to `true` when such Op is found.
996 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
997  bool &foundIndexOp, VectorType resType) {
998 
999  assert(((llvm::count_if(resType.getShape(),
1000  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1001  "n-D vectors are not yet supported");
1002 
1003  // Blocks outside _this_ linalg.generic are effectively loop invariant.
1004  // However, analysing block arguments for _this_ linalg.generic Op is a bit
1005  // tricky. Just bail out in the latter case.
1006  // TODO: We could try analysing the corresponding affine map here.
1007  auto *block = linalgOp.getBlock();
1008  if (isa<BlockArgument>(val))
1009  return !llvm::is_contained(block->getArguments(), val);
1010 
1011  Operation *defOp = val.getDefiningOp();
1012  assert(defOp && "This is neither a block argument nor an operation result");
1013 
1014  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1015  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1016 
1017  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1018  return true;
1019  }
1020 
1021  auto *ancestor = block->findAncestorOpInBlock(*defOp);
1022 
1023  if (!ancestor)
1024  return false;
1025 
1026  // Conservatively reject Ops that could lead to indices with stride other
1027  // than 1.
1028  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1029  return false;
1030 
1031  bool result = false;
1032  for (auto op : ancestor->getOperands())
1033  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1034 
1035  return result;
1036 }
1037 
1038 /// Infer the memory access pattern for the input ExtractOp
1039 ///
1040 /// Based on the ExtractOp result shape and the access indices, decides whether
1041 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1042 /// or a gather load. When analysing the ExtractOp indices (to identify
1043 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1044 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1045 /// Op).
1046 ///
1047 /// Note that it is always safe to use gather load operations for contiguous
1048 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1049 /// that `extractOp` is a gather load.
1050 static VectorMemoryAccessKind
1051 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1052  LinalgOp &linalgOp, VectorType resType) {
1053 
1054  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1055 
1056  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1057   if (inputShape.getShape().empty())
1058     return VectorMemoryAccessKind::ScalarBroadcast;
1059 
1060  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1061  // otherwise.
1062  bool isOutput1DVector =
1063  (llvm::count_if(resType.getShape(),
1064  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1065  // 1. Assume that it's a gather load when reading non-1D vector.
1066   if (!isOutput1DVector)
1067     return VectorMemoryAccessKind::Gather;
1068 
1069  bool leadingIdxsLoopInvariant = true;
1070 
1071  // 2. Analyze the leading indices of `extractOp`.
1072  // Look at the way each index is calculated and decide whether it is suitable
1073  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1074  // gather load.
1075  auto indices = extractOp.getIndices();
1076  auto leadIndices = indices.drop_back(1);
1077 
1078  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1079  if (inputShape.getShape()[i] == 1)
1080  continue;
1081 
1082  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1083  }
1084 
1085  if (!leadingIdxsLoopInvariant) {
1086     LDBG() << "Found gather load: " << extractOp;
1087     return VectorMemoryAccessKind::Gather;
1088   }
1089 
1090  // 3. Analyze the trailing index for `extractOp`.
1091  // At this point we know that the leading indices are loop invariant. This
1092  // means that this is potentially a scalar or a contiguous load. We can decide
1093  // based on the trailing idx.
1094  auto extractOpTrailingIdx = indices.back();
1095 
1096  // 3a. Scalar broadcast load
1097  // If the trailing index is loop invariant then this is a scalar load.
1098  if (leadingIdxsLoopInvariant &&
1099  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1100  LDBG() << "Found scalar broadcast load: " << extractOp;
1101 
1102     return VectorMemoryAccessKind::ScalarBroadcast;
1103   }
1104 
1105  // 3b. Contiguous loads
1106  // The trailing `extractOp` index should increment with every loop iteration.
1107  // This effectively means that it must be based on the trailing loop index.
1108  // This is what the following bool captures.
1109  bool foundIndexOp = false;
1110  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1111  foundIndexOp, resType);
1112  // TODO: Support generating contiguous loads for column vectors - that will
1113  // require adding a permutation map to transfer_read Ops.
1114  bool isRowVector = resType.getShape().back() != 1;
1115  isContiguousLoad &= (foundIndexOp && isRowVector);
1116 
1117   if (isContiguousLoad) {
1118     LDBG() << "Found contiguous load: " << extractOp;
1119     return VectorMemoryAccessKind::Contiguous;
1120   }
1121 
1122  // 4. Fallback case - gather load.
1123   LDBG() << "Found gather load: " << extractOp;
1124   return VectorMemoryAccessKind::Gather;
1125 }
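
// Illustrative sketch (hypothetical IR, for exposition only): inside a
// linalg.generic with iteration space [1, 8] (so the output vector is
// effectively 1D), an extract such as
//
//   %i = linalg.index 1 : index
//   %e = tensor.extract %src[%c0, %i] : tensor<3x8xf32>
//
// is classified as a contiguous load: the leading index is loop invariant and
// the trailing index advances by 1 with the trailing non-unit loop dim. If the
// trailing index were instead scaled by a stride (e.g. %s = arith.muli %i, %c2),
// the analysis would fall back to a gather load.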
1126 
1127 /// Helper function to vectorize the tensor.extract operations. Returns
1128 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1129 /// should map the produced operations. This function is meant to be used as a
1130 /// CustomVectorizationHook.
1131 static VectorizationHookResult
1132 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1133                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1134  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1135   if (!extractOp)
1136     return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1137   auto loc = extractOp.getLoc();
1138 
1139  // Compute the static loop sizes of the extract op.
1140  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1141  auto maskConstantOp = arith::ConstantOp::create(
1142  rewriter, loc,
1143  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1144  /*value=*/true));
1145  auto passThruConstantOp = arith::ConstantOp::create(
1146  rewriter, loc, rewriter.getZeroAttr(resultType));
1147 
1148  // Base indices are currently set to 0. We will need to re-visit if more
1149  // generic scenarios are to be supported.
1150  SmallVector<Value> baseIndices(
1151  extractOp.getIndices().size(),
1152  arith::ConstantIndexOp::create(rewriter, loc, 0));
1153 
1154  VectorMemoryAccessKind memAccessKind =
1155  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1156 
1157  // 1. Handle gather access
1158  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1159  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1160 
1161  // Generate the gather load
1162  Operation *gatherOp = vector::GatherOp::create(
1163  rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1164  maskConstantOp, passThruConstantOp);
1165  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1166 
1167     LDBG() << "Vectorised as gather load: " << extractOp;
1168     return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1169   }
1170 
1171  // 2. Handle:
1172  // a. scalar loads + broadcast,
1173  // b. contiguous loads.
1174  // Both cases use vector.transfer_read.
1175 
1176  // Collect indices for `vector.transfer_read`. At this point, the indices will
1177  // either be scalars or would have been broadcast to vectors matching the
1178  // result type. For indices that are vectors, there are two options:
1179  // * for non-trailing indices, all elements are identical (contiguous
1180  // loads are identified by looking for non-trailing indices that are
1181  // invariant with respect to the corresponding linalg.generic), or
1182  // * for trailing indices, the index vector will contain values with stride
1183  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1184  // needed.
1185  // This means that
1186  // * for scalar indices - just re-use it,
1187  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1188  // (0th) element and use that.
1189  SmallVector<Value> transferReadIdxs;
1190  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1191  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1192  if (idx.getType().isIndex()) {
1193  transferReadIdxs.push_back(idx);
1194  continue;
1195  }
1196 
1197  auto indexAs1dVector = vector::ShapeCastOp::create(
1198  rewriter, loc,
1199  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1200  resultType.getScalableDims().back()),
1201  idx);
1202  transferReadIdxs.push_back(
1203  vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1204  }
1205 
1206  // `tensor.extract` is always in-bounds, hence the following holds.
1207  auto dstRank = resultType.getRank();
1208  auto srcRank = extractOp.getTensor().getType().getRank();
1209  SmallVector<bool> inBounds(dstRank, true);
1210 
1211  // 2a. Handle scalar broadcast access.
1212  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1213  MLIRContext *ctx = rewriter.getContext();
1214  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1215  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1216 
1217  auto transferReadOp = vector::TransferReadOp::create(
1218  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1219  /*padding=*/std::nullopt, permutationMap, inBounds);
1220 
1221  // Mask this broadcasting xfer_read here rather than relying on the generic
1222  // path (the generic path assumes identity masking map, which wouldn't be
1223  // valid here).
1224  SmallVector<int64_t> readMaskShape = {1};
1225  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1226  auto allTrue = vector::ConstantMaskOp::create(
1227  rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1228  auto *maskedReadOp =
1229  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1230 
1231   LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1232   return VectorizationHookResult{VectorizationHookStatus::NewOp,
1233                                  maskedReadOp};
1234  }
1235 
1236  // 2b. Handle contiguous access.
1237  auto permutationMap = AffineMap::getMinorIdentityMap(
1238  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1239 
1240  int32_t rankDiff = dstRank - srcRank;
1241  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1242  // dims so that the ranks match. This is done by extending the map with 0s.
1243  // For example, for dstRank = 3, srcRank = 2, the following map created
1244  // above:
1245  // (d0, d1) --> (d0, d1)
1246  // is extended as:
1247  // (d0, d1) --> (0, d0, d1)
1248  while (rankDiff > 0) {
1249  permutationMap = permutationMap.insertResult(
1250  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1251  rankDiff--;
1252  }
1253 
1254  auto transferReadOp = vector::TransferReadOp::create(
1255  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1256  /*padding=*/std::nullopt, permutationMap, inBounds);
1257 
1258   LDBG() << "Vectorised as contiguous load: " << extractOp;
1259   return VectorizationHookResult{VectorizationHookStatus::NewOp,
1260                                  transferReadOp};
1261 }
1262 
1263 /// Emit reduction operations if the shape of the value to reduce is different
1264 /// from the result shape.
1265 // Note: this is a true builder that notifies the OpBuilder listener.
1266 // TODO: Consider moving as a static helper on the ReduceOp.
1267 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1268  Value reduceValue, Value initialValue,
1269  const IRMapping &bvm) {
1270  Value reduceVec = bvm.lookup(reduceValue);
1271  Value outputVec = bvm.lookup(initialValue);
1272  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1273  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1274  // Reduce only if needed as the value may already have been reduced for
1275  // contraction vectorization.
1276  if (!reduceType ||
1277  (outputType && reduceType.getShape() == outputType.getShape()))
1278  return nullptr;
1279  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1280  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1281 }
1282 
1283 /// Generic vectorization for a single operation `op`, given already vectorized
1284 /// operands carried by `bvm`. Vectorization occurs as follows:
1285 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1286 /// result on success.
1287 /// 2. Clone any constant in the current scope without vectorization: each
1288 /// consumer of the constant will later determine the shape to which the
1289 /// constant needs to be broadcast to.
1290 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1291 /// of the `customVectorizationHooks` to cover such cases.
1292 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1293 /// operand of maximal rank. Other operands have smaller rank and are
1294 /// broadcast accordingly. It is assumed this broadcast is always legal,
1295 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1296 ///
1297 /// This function assumes all operands of `op` have been vectorized and are in
1298 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1299 /// a topologically-sorted list of ops.
1300 /// This function does not update `bvm` but returns a VectorizationHookStatus
1301 /// that instructs the caller what `bvm` update needs to occur.
1302 static VectorizationHookResult
1303 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1304                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1305  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1306  LDBG() << "vectorize op " << *op;
1307 
1308  // 1. Try to apply any CustomVectorizationHook.
1309  if (!customVectorizationHooks.empty()) {
1310  for (auto &customFunc : customVectorizationHooks) {
1311       VectorizationHookResult result = customFunc(op, bvm);
1312       if (result.status == VectorizationHookStatus::Failure)
1313         continue;
1314  return result;
1315  }
1316  }
1317 
1318  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1319   // Clone so that the constant is not confined to the linalgOp block.
1320   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1321     return VectorizationHookResult{VectorizationHookStatus::NewOp,
1322                                    rewriter.clone(*op)};
1323 
1324   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1325   if (!OpTrait::hasElementwiseMappableTraits(op))
1326     return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1327 
1328   // 4. Check if the operation is a reduction.
1329  SmallVector<std::pair<Value, Value>> reductionOperands;
1330  for (Value operand : op->getOperands()) {
1331  auto blockArg = dyn_cast<BlockArgument>(operand);
1332  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1333  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1334  continue;
1335  SmallVector<Operation *> reductionOps;
1336  Value reduceValue = matchReduction(
1337  linalgOp.getRegionOutputArgs(),
1338  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1339  if (!reduceValue)
1340  continue;
1341  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1342  }
1343  if (!reductionOperands.empty()) {
1344  assert(reductionOperands.size() == 1);
1345  Operation *reduceOp =
1346  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1347  reductionOperands[0].second, bvm);
1348     if (reduceOp)
1349       return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1350   }
1351 
1352  // 5. Generic vectorization path for ElementwiseMappable ops.
1353  // a. Get the first max ranked shape.
1354  VectorType firstMaxRankedType;
1355  for (Value operand : op->getOperands()) {
1356  auto vecOperand = bvm.lookup(operand);
1357  assert(vecOperand && "Vector operand couldn't be found");
1358 
1359  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1360  if (vecType && (!firstMaxRankedType ||
1361  firstMaxRankedType.getRank() < vecType.getRank()))
1362  firstMaxRankedType = vecType;
1363  }
1364  // b. Broadcast each op if needed.
1365  SmallVector<Value> vecOperands;
1366  for (Value scalarOperand : op->getOperands()) {
1367  Value vecOperand = bvm.lookup(scalarOperand);
1368  assert(vecOperand && "Vector operand couldn't be found");
1369 
1370  if (firstMaxRankedType) {
1371  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1372  getElementTypeOrSelf(vecOperand.getType()),
1373  firstMaxRankedType.getScalableDims());
1374  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1375  } else {
1376  vecOperands.push_back(vecOperand);
1377  }
1378  }
1379  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1380  SmallVector<Type> resultTypes;
1381  for (Type resultType : op->getResultTypes()) {
1382  resultTypes.push_back(
1383  firstMaxRankedType
1384  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1385  firstMaxRankedType.getScalableDims())
1386  : resultType);
1387  }
1388  // d. Build and return the new op.
1389  return VectorizationHookResult{
1390  VectorizationHookStatus::NewOp,
1391  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1392  resultTypes, op->getAttrs())};
1393 }
1394 
1395 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1396 /// vector form. Generic vectorization proceeds as follows:
1397 /// 1. Verify the `linalgOp` has one non-empty region.
1398 /// 2. Values defined above the region are mapped to themselves and will be
1399 /// broadcasted on a per-need basis by their consumers.
1400 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1401 /// load).
1402 /// TODO: Reuse opportunities for RAR dependencies.
1403 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1404 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1405 /// iteration indices.
1406 /// 5. Iteratively call vectorizeOneOp on the region operations.
1407 ///
1408 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1409 /// performed to the maximal common vector size implied by the `linalgOp`
1410 /// iteration space. This eager broadcasting is introduced in the
1411 /// permutation_map of the vector.transfer_read operations. The eager
1412 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1413 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1414 /// the absence of good canonicalizations, the amount of work increases.
1415 /// This is not deemed a problem as we expect canonicalizations and foldings to
1416 /// aggressively clean up the useless work.
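///
/// Minimal sketch (editorial addition, assuming a static elementwise
/// linalg.generic that adds two 8x16 tensors): the rewrite produces roughly
///   %a = vector.transfer_read %lhs[%c0, %c0], %pad : tensor<8x16xf32>, vector<8x16xf32>
///   %b = vector.transfer_read %rhs[%c0, %c0], %pad : tensor<8x16xf32>, vector<8x16xf32>
///   %s = arith.addf %a, %b : vector<8x16xf32>
///   %r = vector.transfer_write %s, %init[%c0, %c0] : vector<8x16xf32>, tensor<8x16xf32>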
1417 static LogicalResult
1418 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1419  LinalgOp linalgOp,
1420  SmallVectorImpl<Value> &newResults) {
1421  LDBG() << "Vectorizing operation as linalg generic/n";
1422  Block *block = linalgOp.getBlock();
1423 
1424  // 2. Values defined above the region can only be broadcast for now. Make them
1425  // map to themselves.
1426  IRMapping bvm;
1427  SetVector<Value> valuesSet;
1428  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1429  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1430 
1431  if (linalgOp.getNumDpsInits() == 0)
1432  return failure();
1433 
1434  // 3. Turn all BBArgs into vector.transfer_read / load.
1435  Location loc = linalgOp.getLoc();
1436  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1437  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1438  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1439  if (linalgOp.isScalar(opOperand)) {
1440  bvm.map(bbarg, opOperand->get());
1441  continue;
1442  }
1443 
1444  // 3.a. Convert the indexing map for this input/output to a transfer read
1445  // permutation map and masking map.
1446  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1447 
1448  AffineMap readMap;
1449  VectorType readType;
1450  Type elemType = getElementTypeOrSelf(opOperand->get());
1451  if (linalgOp.isDpsInput(opOperand)) {
1452  // 3.a.i. For input reads we use the canonical vector shape.
1453  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1454  readType = state.getCanonicalVecType(elemType);
1455  } else {
1456  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1457  // reductions), the vector shape is computed by mapping the canonical
1458  // vector shape to the output domain and back to the canonical domain.
1459  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1460  readType =
1461  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1462  }
1463 
1464  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1465 
1466  Operation *read = vector::TransferReadOp::create(
1467  rewriter, loc, readType, opOperand->get(), indices,
1468  /*padding=*/std::nullopt, readMap);
1469  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1470  Value readValue = read->getResult(0);
1471 
1472  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1473  // will be in-bounds.
1474  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1475  SmallVector<bool> inBounds(readType.getRank(), true);
1476  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1477  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1478  }
1479 
1480  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1481  // TODO: remove this.
1482  if (readType.getRank() == 0)
1483  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1484  ArrayRef<int64_t>());
1485 
1486  LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1487  << "): " << readValue;
1488  bvm.map(bbarg, readValue);
1489  bvm.map(opOperand->get(), readValue);
1490  }
1491 
1492  SmallVector<CustomVectorizationHook> hooks;
1493  // 4a. Register CustomVectorizationHook for yieldOp.
1494  CustomVectorizationHook vectorizeYield =
1495  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1496  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1497  };
1498  hooks.push_back(vectorizeYield);
1499 
1500  // 4b. Register CustomVectorizationHook for indexOp.
1501  CustomVectorizationHook vectorizeIndex =
1502  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1503  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1504  };
1505  hooks.push_back(vectorizeIndex);
1506 
1507  // 4c. Register CustomVectorizationHook for extractOp.
1508  CustomVectorizationHook vectorizeExtract =
1509  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1510  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1511  };
1512  hooks.push_back(vectorizeExtract);
1513 
1514  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1515  for (Operation &op : block->getOperations()) {
1516  VectorizationHookResult result =
1517  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1518  if (result.status == VectorizationHookStatus::Failure) {
1519  LDBG() << "failed to vectorize: " << op;
1520  return failure();
1521  }
1522  if (result.status == VectorizationHookStatus::NewOp) {
1523  Operation *maybeMaskedOp =
1524  state.maskOperation(rewriter, result.newOp, linalgOp);
1525  LDBG() << "New vector op: " << *maybeMaskedOp;
1526  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1527  }
1528  }
1529 
1530  return success();
1531 }
1532 
1533 /// Given a linalg::PackOp, return the `dest` shape before any packing
1534 /// permutations.
1535 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1536  ArrayRef<int64_t> destShape) {
1537  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1538 }
1539 
1540 /// Determines whether a mask for xfer_write is trivially "all true"
1541 ///
1542 /// Given all the inputs required to generate a mask (mask sizes and shapes),
1543 /// and an xfer_write operation (write indices and the destination tensor
1544 /// shape), determines whether the corresponding mask would be trivially
1545 /// foldable (i.e., trivially "all true").
1546 ///
1547 /// Use this method to avoid generating spurious masks and relying on
1548 /// vectorization post-processing to remove them.
1549 ///
1550 /// Pre-conditions for a mask to be trivially foldable:
1551 /// * All involved shapes (mask + destination tensor) are static.
1552 /// * All write indices are constant.
1553 /// * All mask sizes are constant (including `arith.constant`).
1554 ///
1555 /// If the pre-conditions are met, the method checks for each destination
1556 /// dimension `d`:
1557 /// (1) destDimSize[rankDiff + d] >= maskShape[d]
1558 /// (2) destDimSize[rankDiff + d] >= writeIndex[d] + maskSize[d]
1559 ///
1560 /// rankDiff = rank(dest) - rank(mask).
1561 ///
1562 /// This method takes a conservative view: it may return false even if the mask
1563 /// is technically foldable.
1564 ///
1565 /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1566 /// of the dest tensor):
1567 /// %c0 = arith.constant 0 : index
1568 /// %mask = vector.create_mask 5, 1
1569 /// vector.mask %mask {
1570 /// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1571 /// {in_bounds = [true, true]}
1572 /// : vector<5x1xi32>, tensor<5x1xi32>
1573 /// }
1574 ///
1575 /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1576 /// mask is required to avoid out-of-bounds write):
1577 /// %c0 = arith.constant 0 : index
1578 /// %mask = vector.create_mask 5, 1
1579 /// vector.mask %mask {
1580 /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1581 /// {in_bounds = [true, true]}
1582 /// : vector<8x1xi32>, tensor<5x1xi32>
1583 /// }
1584 ///
1585 /// TODO: Re-use in createReadOrMaskedRead
1586 static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1587  SmallVector<Value> &writeIdxs,
1588  ArrayRef<int64_t> destShape,
1589  ArrayRef<int64_t> maskShape) {
1590  // Masking is unavoidable in the case of dynamic tensors.
1591  if (ShapedType::isDynamicShape(destShape))
1592  return false;
1593 
1594  // Collect all constant mask sizes.
1595  SmallVector<int64_t, 4> cstMaskSizes;
1596  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1597  if (auto intSize = getConstantIntValue(dimSize)) {
1598  cstMaskSizes.push_back(*intSize);
1599  }
1600  }
1601 
1602  // If any of the mask sizes is non-constant, bail out.
1603  if (cstMaskSizes.size() != maskShape.size())
1604  return false;
1605 
1606  // Collect all constant write indices.
1607  SmallVector<int64_t, 4> cstWriteIdxs;
1608  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1609  APSInt intVal;
1610  if (matchPattern(idx, m_ConstantInt(&intVal))) {
1611  cstWriteIdxs.push_back(intVal.getSExtValue());
1612  }
1613  }
1614 
1615  // If any of the write indices is non-constant, bail out.
1616  if (cstWriteIdxs.size() != destShape.size())
1617  return false;
1618 
1619  // Go over all destination dims and check (1) and (2). Take into account that:
1620  // * The number of mask sizes will match the rank of the vector to store.
1621  // This could be lower than the rank of the destination tensor.
1622  // * Mask sizes could be larger than the corresponding mask shape (hence
1623  // `clamp`).
1624  // TODO: The 2nd item should be rejected by the verifier.
1625  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1626  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1627  if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1628  /*(2)*/ destShape[rankDiff + i] <
1629  (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1630  cstWriteIdxs[i]))
1631  return false;
1632  }
1633 
1634  return true;
1635 }
1636 
1637 /// Creates an optionally masked TransferWriteOp
1638 ///
1639 /// Generates the following operation:
1640 /// %res = vector.transfer_write %vecToStore into %dest
1641 ///
1642 /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1643 ///
1644 /// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1645 /// %res = vector.mask %mask {
1646 /// vector.transfer_write %vecToStore into %dest
1647 /// }
1648 ///
1649 /// The mask shape is identical to `vecToStore` (with the element type ==
1650 /// i1), and the mask values are based on the shape of the `dest` tensor.
1651 ///
1652 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1653 /// is used instead of masking:
1654 ///
1655 /// in_bounds_flags = (...)
1656 /// %res = vector.transfer_write %vecToStore into %dest
1657 /// {in_bounds = in_bounds_flags}
1659 ///
1660 /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1661 /// are set to 0.
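///
/// Usage sketch (editorial addition; variable names are hypothetical):
///   // Write `vec` at offset 0 in every dimension of `dest`; a mask is
///   // generated only if `vec` does not match the trailing shape of `dest`.
///   Operation *write = createWriteOrMaskedWrite(rewriter, loc, vec, dest);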
1662 static Operation *
1663 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1664  Value dest, SmallVector<Value> writeIndices = {},
1665  bool useInBoundsInsteadOfMasking = false) {
1666 
1667  ShapedType destType = cast<ShapedType>(dest.getType());
1668  int64_t destRank = destType.getRank();
1669  auto destShape = destType.getShape();
1670 
1671  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1672  int64_t vecToStoreRank = vecToStoreType.getRank();
1673  auto vecToStoreShape = vecToStoreType.getShape();
1674 
1675  // Compute the in_bounds attribute
1676  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1677  if (useInBoundsInsteadOfMasking) {
1678  // Update the inBounds attribute.
1679  // FIXME: This computation is too weak - it ignores the write indices.
1680  for (unsigned i = 0; i < vecToStoreRank; i++)
1681  inBoundsVal[i] =
1682  (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1683  ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1684  }
1685 
1686  // If missing, initialize the write indices to 0.
1687  assert((writeIndices.empty() ||
1688  writeIndices.size() == static_cast<size_t>(destRank)) &&
1689  "Invalid number of write indices!");
1690  if (writeIndices.empty()) {
1691  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1692  writeIndices.assign(destRank, zero);
1693  }
1694 
1695  // Generate the xfer_write Op
1696  Operation *write = vector::TransferWriteOp::create(builder, loc,
1697  /*vector=*/vecToStore,
1698  /*source=*/dest,
1699  /*indices=*/writeIndices,
1700  /*inBounds=*/inBoundsVal);
1701 
1702  // If masking is disabled, exit.
1703  if (useInBoundsInsteadOfMasking)
1704  return write;
1705 
1706  // Check if masking is needed. If not, exit.
1707  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1708  return write;
1709 
1710  // Compute the mask and mask the write Op.
1711  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1712  vecToStoreType.getScalableDims());
1713 
1714  SmallVector<OpFoldResult> destSizes =
1715  isa<MemRefType>(dest.getType())
1716  ? memref::getMixedSizes(builder, loc, dest)
1717  : tensor::getMixedSizes(builder, loc, dest);
1718  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1719  destSizes.end());
1720 
1721  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1722  vecToStoreShape))
1723  return write;
1724 
1725  Value maskForWrite =
1726  builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1727  return mlir::vector::maskOperation(builder, write, maskForWrite);
1728 }
1729 
1730 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1731 /// padding value and (3) input vector sizes into:
1732 ///
1733 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1734 ///
1735 /// As in the following example:
1736 /// %pack = linalg.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1737 /// into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1738 ///
1739 /// This pack would be vectorized to:
1740 ///
1741 /// %load = vector.mask %mask {
1742 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1743 /// {in_bounds = [true, true, true]} :
1744 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1745 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1746 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1747 /// to vector<32x4x2x1x16xf32>
1748 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1749 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1750 /// %write = vector.transfer_write %transpose,
1751 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1752 /// {in_bounds = [true, true, true, true, true]}
1753 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1754 ///
1755 /// If the (3) input vector sizes are not provided, the vector sizes are
1756 /// determined by the result tensor shape and the `in_bounds`
1757 /// attribute is used instead of masking to mark out-of-bounds accesses.
1758 ///
1759 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1760 /// outer dimensions of the output tensor. The remaining dimensions are
1761 /// computed based on, e.g., the static inner tiles.
1762 /// Supporting dynamic inner tiles will require the user to specify the
1763 /// missing vector sizes. This is left as a TODO.
1764 static LogicalResult
1765 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1766  ArrayRef<int64_t> inputVectorSizes,
1767  SmallVectorImpl<Value> &newResults) {
1768  // TODO: Introduce a parent class that will handle the insertion point update.
1769  OpBuilder::InsertionGuard g(rewriter);
1770  rewriter.setInsertionPoint(packOp);
1771 
1772  Location loc = packOp.getLoc();
1773  auto padValue = packOp.getPaddingValue();
1774  if (!padValue) {
1775  padValue = arith::ConstantOp::create(
1776  rewriter, loc,
1777  rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1778  }
1779  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1780  LogicalResult status =
1781  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1782  .reifyResultShapes(rewriter, reifiedReturnShapes);
1783  (void)status; // prevent unused variable warning on non-assert builds.
1784  assert(succeeded(status) && "failed to reify result shapes");
1785 
1786  // If the input vector sizes are not provided, then the vector sizes are
1787  // determined by the result tensor shape. In that case, we also update the
1788  // inBounds attribute instead of masking.
1789  bool useInBoundsInsteadOfMasking = false;
1790  if (inputVectorSizes.empty()) {
1791  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1792  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1793  useInBoundsInsteadOfMasking = true;
1794  }
1795 
1796  // Create masked TransferReadOp.
1797  SmallVector<int64_t> inputShape(inputVectorSizes);
1798  auto innerTiles = packOp.getStaticInnerTiles();
1799  auto innerDimsPos = packOp.getInnerDimsPos();
1800  auto outerDimsPerm = packOp.getOuterDimsPerm();
1801  if (!outerDimsPerm.empty())
1802  applyPermutationToVector(inputShape,
1803  invertPermutationVector(outerDimsPerm));
1804  for (auto [idx, size] : enumerate(innerTiles))
1805  inputShape[innerDimsPos[idx]] *= size;
1806  auto maskedRead = vector::createReadOrMaskedRead(
1807  rewriter, loc, packOp.getSource(), inputShape, padValue,
1808  useInBoundsInsteadOfMasking,
1809  /*inputScalableVecSizes=*/{});
1810 
1811  // Create ShapeCastOp.
1812  SmallVector<int64_t> destShape(inputVectorSizes);
1813  destShape.append(innerTiles.begin(), innerTiles.end());
1814  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1815  packOp.getDestType().getElementType());
1816  auto shapeCastOp =
1817  vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
1818 
1819  // Create TransposeOp.
1820  auto destPermutation =
1821  invertPermutationVector(getPackInverseDestPerm(packOp));
1822  auto transposeOp = vector::TransposeOp::create(
1823  rewriter, loc, shapeCastOp.getResult(), destPermutation);
1824 
1825  // Create TransferWriteOp.
1826  Value dest = tensor::EmptyOp::create(
1827  rewriter, loc, reifiedReturnShapes[0],
1828  transposeOp.getResult().getType().getElementType());
1829  Operation *write =
1830  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
1831  newResults.push_back(write->getResult(0));
1832  return success();
1833 }
1834 
1835 /// Given the re-associations, "collapses" the input Vector type
1836 ///
1837 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1838 /// differences:
1839 /// * We can safely assume that there are no dynamic sizes.
1840 /// * Scalable flags are updated alongside regular dims.
1841 ///
1842 /// When collapsing scalable flags, this conservatively rejects cases with
1843 /// more than one scalable dim. We could re-visit this in the future.
1844 ///
1845 /// EXAMPLE:
1846 /// type = vector<4x16x[8]x16xf32>
1847 /// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1848 /// (d0, d1, d2, d3) -> (d2, d3)]
1849 /// Result:
1850 /// vector<64x[128]xf32>
1851 static VectorType getCollapsedVecType(VectorType type,
1852  ArrayRef<AffineMap> reassociation) {
1853  assert(type.getNumScalableDims() < 2 &&
1854  "Collapsing more than 1 scalable dim is not supported ATM");
1855 
1856  // Use the fact that reassociation is valid to simplify the logic: only use
1857  // each map's rank.
1858  assert(isReassociationValid(reassociation) && "invalid reassociation");
1859 
1860  auto shape = type.getShape();
1861  auto scalableFlags = type.getScalableDims();
1862  SmallVector<int64_t> newShape;
1863  SmallVector<bool> newScalableFlags;
1864 
1865  unsigned currentDim = 0;
1866  for (AffineMap m : reassociation) {
1867  unsigned dim = m.getNumResults();
1868  int64_t size = 1;
1869  bool flag = false;
1870  for (unsigned d = 0; d < dim; ++d) {
1871  size *= shape[currentDim + d];
1872  flag |= scalableFlags[currentDim + d];
1873  }
1874  newShape.push_back(size);
1875  newScalableFlags.push_back(flag);
1876  currentDim += dim;
1877  }
1878 
1879  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1880 }
1881 
1882 /// Vectorize `linalg.unpack` as:
1883 /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884 ///
1885 /// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886 /// for the xfer_read operation). This is sufficient to infer the other vector
1887 /// sizes required here.
1888 ///
1889 /// If the vector sizes are not provided:
1890 /// * the vector sizes are determined from the input tensor static shape.
1891 /// * the inBounds attribute is used instead of masking.
1892 ///
1893 /// EXAMPLE (no vector sizes):
1894 /// ```
1895 /// %unpack = linalg.unpack %src
1896 /// inner_dims_pos = [0, 1]
1897 /// inner_tiles = [8, 8]
1898 /// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1899 /// ```
1900 /// is vectorized as:
1901 /// ```
1902 /// %read = vector.transfer_read %src
1903 /// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904 /// %tr = vector.transpose %read, [0, 2, 1, 3]
1905 /// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906 /// %sc = vector.shape_cast %tr
1907 /// : vector<1x8x1x8xf32> to vector<8x8xf32>
1908 /// %vector = vector.transfer_write %sc into %dest
1909 /// : vector<8x8xf32>, tensor<8x8xf32>
1910 /// ```
1911 static LogicalResult
1912 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1913  ArrayRef<int64_t> inputVectorSizes,
1914  ArrayRef<bool> inputScalableVecDims,
1915  SmallVectorImpl<Value> &newResults) {
1916  if (!inputVectorSizes.empty()) {
1917  assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1918  "Invalid number of input vector sizes!");
1919  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1920  "Incompatible number of vector sizes and vector scalable flags!");
1921  }
1922 
1923  // TODO: Introduce a parent class that will handle the insertion point update.
1924  OpBuilder::InsertionGuard g(rewriter);
1925  rewriter.setInsertionPoint(unpackOp);
1926 
1927  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1928 
1929  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1930  bool useInBoundsInsteadOfMasking = false;
1931 
1932  Location loc = unpackOp->getLoc();
1933 
1934  // Obtain vector sizes for the read operation.
1935  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1936  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1937 
1938  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939  if (inputVectorSizes.empty()) {
1940  if (ShapedType::isDynamicShape(sourceShape))
1941  return failure();
1942 
1943  readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1944  useInBoundsInsteadOfMasking = true;
1945  }
1946 
1947  // -- Generate the read operation --
1948  auto padValue = arith::ConstantOp::create(
1949  rewriter, loc,
1950  rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1951  Value readResult = vector::createReadOrMaskedRead(
1952  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1953  useInBoundsInsteadOfMasking, readScalableVectorFlags);
1954 
1955  // -- Generate the transpose operation --
1956  PackingMetadata packMetadata;
1957  SmallVector<int64_t> lastDimToInsertPosPerm =
1958  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1959  vector::TransposeOp transposeOp = vector::TransposeOp::create(
1960  rewriter, loc, readResult, lastDimToInsertPosPerm);
1961 
1962  // -- Generate the shape_cast operation --
1963  VectorType collapsedVecType = getCollapsedVecType(
1964  transposeOp.getType(),
1965  getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1966  rewriter.getContext(), packMetadata.reassociations)));
1967  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1968  rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1969 
1970  // -- Generate the write operation --
1971  Operation *write = createWriteOrMaskedWrite(
1972  rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1973  /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1974 
1975  newResults.push_back(write->getResult(0));
1976  return success();
1977 }
1978 
1979 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1980 /// and (3) all-zero lowPad to
1981 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
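///
/// Rough sketch (editorial addition): a pad such as
///   %p = tensor.pad %src low[0, 0] high[2, 3] { ... } : tensor<6x5xf32> to tensor<8x8xf32>
/// becomes a masked transfer_read of %src that yields the constant pad value
/// out-of-bounds, written in-bounds into a tensor.empty of the static result
/// shape.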
1982 static LogicalResult
1983 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1984  ArrayRef<int64_t> inputVectorSizes,
1985  SmallVectorImpl<Value> &newResults) {
1986  auto padValue = padOp.getConstantPaddingValue();
1987  Location loc = padOp.getLoc();
1988 
1989  // TODO: Introduce a parent class that will handle the insertion point update.
1990  OpBuilder::InsertionGuard g(rewriter);
1991  rewriter.setInsertionPoint(padOp);
1992 
1993  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1994  LogicalResult status =
1995  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1996  .reifyResultShapes(rewriter, reifiedReturnShapes);
1997  (void)status; // prevent unused variable warning on non-assert builds
1998  assert(succeeded(status) && "failed to reify result shapes");
1999  auto maskedRead = vector::createReadOrMaskedRead(
2000  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2001  /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
2002 
2003  // Create Xfer write Op
2004  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2005  padOp.getResultType().getElementType());
2006  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2007  newResults.push_back(write->getResult(0));
2008  return success();
2009 }
2010 
2011 // TODO: probably need some extra checks for reduction followed by consumer
2012 // ops that may not commute (e.g. linear reduction + non-linear instructions).
2013 static LogicalResult reductionPreconditions(LinalgOp op) {
2014  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2015  LDBG() << "reduction precondition failed: no reduction iterator";
2016  return failure();
2017  }
2018  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2019  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2020  if (indexingMap.isPermutation())
2021  continue;
2022 
2023  Operation *reduceOp = matchLinalgReduction(&opOperand);
2024  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2025  LDBG() << "reduction precondition failed: reduction detection failed";
2026  return failure();
2027  }
2028  }
2029  return success();
2030 }
2031 
2032 static LogicalResult
2033 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2034  bool flatten1DDepthwiseConv) {
2035  if (flatten1DDepthwiseConv) {
2036  LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2037  "supported";
2038  return failure();
2039  }
2040 
2041  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2042  LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2043  return failure();
2044  }
2045 
2046  // Support dynamic shapes in 1D depthwise convolution, but only in the
2047  // _channel_ dimension.
2048  Value lhs = conv.getDpsInputOperand(0)->get();
2049  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2050  auto shapeWithoutCh = lhsShape.drop_back(1);
2051  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2052  LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2053  "channel dim can be dynamic";
2054  return failure();
2055  }
2056 
2057  return success();
2058 }
2059 
2060 static LogicalResult
2061 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2062  bool flatten1DDepthwiseConv) {
2063  if (isa<ConvolutionOpInterface>(op.getOperation()))
2064  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2065 
2066  if (hasReductionIterator(op))
2067  return reductionPreconditions(op);
2068 
2069  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2070  // linalg.copy ops and ops that implement ContractionOpInterface for now.
2071  if (!isElementwise(op) &&
2072  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2073  op.getOperation()))
2074  return failure();
2075 
2076  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2077  return success();
2078 }
2079 
2080 /// This hook considers two cases:
2081 /// (1) If the input-vector-sizes are empty, then the vector sizes will be
2082 /// inferred. This is only possible when all shapes are static.
2083 /// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2084 /// carry out basic sanity-checking.
2085 static LogicalResult
2086 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2087  ArrayRef<int64_t> inputVectorSizes) {
2088  // If there are no input vector sizes and all shapes are static, there is
2089  // nothing left to check.
2090  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2091  unpackOp.getSourceType().hasStaticShape())
2092  return success();
2093 
2094  // The number of input vector sizes must be equal to:
2095  // * read-vector-rank
2096  if (!inputVectorSizes.empty() &&
2097  (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2098  LDBG() << "Incorrect number of input vector sizes";
2099  return failure();
2100  }
2101 
2102  // Check the vector sizes for the read operation.
2103  if (failed(vector::isValidMaskedInputVector(
2104  unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2105  LDBG() << "Invalid vector sizes for the read operation";
2106  return failure();
2107  }
2108 
2109  return success();
2110 }
2111 
2112 static LogicalResult
2113 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2114  ArrayRef<int64_t> inputVectorSizes) {
2115 
2116  TypedValue<RankedTensorType> source = sliceOp.getSource();
2117  auto sourceType = source.getType();
2118  if (!VectorType::isValidElementType(sourceType.getElementType()))
2119  return failure();
2120 
2121  // Get the pad value.
2122  // TransferReadOp (which is used to vectorize InsertSliceOp) requires a
2123  // scalar padding value. Note that:
2124  // * for in-bounds accesses,
2125  // the value is actually irrelevant. There are 2 cases in which xfer.read
2126  // accesses are known to be in-bounds:
2127  // 1. The source shape is static (output vector sizes would be based on
2128  // the source shape and hence all memory accesses would be in-bounds),
2129  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2130  // this case it is safe to assume that all memory accesses are in-bounds.
2131  //
2132  // When the value is not known and not needed, use 0. Otherwise, bail out.
2133  Value padValue = getStaticPadVal(sliceOp);
2134  bool isOutOfBoundsRead =
2135  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2136 
2137  if (!padValue && isOutOfBoundsRead) {
2138  LDBG() << "Failed to get a pad value for out-of-bounds read access";
2139  return failure();
2140  }
2141  return success();
2142 }
2143 
2144 /// Vectorize a named linalg contraction op into:
2145 /// vector::TransferReadOp - Reads vectors from the operands
2146 /// vector::ContractionOp - Performs contraction
2147 /// vector::TransferWriteOp - Writes the result vector back to the
2148 /// destination
2149 /// The operands' shapes are preserved and loaded directly into vectors.
2150 /// Any further permutations or numerical casts remain within the contraction op.
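///
/// Sketch (editorial addition; shapes are hypothetical): a static
/// linalg.matmul is rewritten roughly as
///   %A = vector.transfer_read %lhs[...] : tensor<4x8xf32>, vector<4x8xf32>
///   %B = vector.transfer_read %rhs[...] : tensor<8x16xf32>, vector<8x16xf32>
///   %C = vector.transfer_read %acc[...] : tensor<4x16xf32>, vector<4x16xf32>
///   %D = vector.contract {iterator_types = ["parallel", "parallel", "reduction"], ...}
///          %A, %B, %C : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   vector.transfer_write %D, %acc[...] : vector<4x16xf32>, tensor<4x16xf32>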
2151 static LogicalResult
2152 vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2153  LinalgOp linalgOp,
2154  SmallVectorImpl<Value> &newResults) {
2155  Location loc = linalgOp.getLoc();
2156  MLIRContext *ctx = linalgOp.getContext();
2157 
2158  // For simplicity, contraction vectorization is limited to linalg named ops.
2159  // Generic op is ignored as not every arbitrary contraction body can be
2160  // expressed by a vector.contract.
2161  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2162  return failure();
2163 
2164  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2165  Operation *reduceOp = matchLinalgReduction(outOperand);
2166  auto maybeKind = getCombinerOpKind(reduceOp);
2167  if (!maybeKind) {
2168  LDBG() << "Failed to determine contraction combining kind.";
2169  return failure();
2170  }
2171 
2172  // Check that all dimensions are present in the input operands.
2173  // Arbitrary broadcasts are not supported by the vector contraction.
2174  // Broadcasts are expected to be decomposed before vectorization.
2175  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2176  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2177  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2178  LDBG() << "Contractions with broadcasts are not supported.";
2179  return failure();
2180  }
2181 
2182  // Load operands.
2183  SmallVector<Value> vecOperands;
2184  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2185  // The operand vector shape is computed by mapping the canonical vector
2186  // shape to the operand's domain. Further permutations are left as a part of
2187  // the contraction.
2188  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2189  AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2190  indexingMap.getNumResults(), rewriter.getContext());
2191  Type elemType = getElementTypeOrSelf(opOperand.get());
2192  VectorType readType =
2193  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2194 
2195  Value read = vector::createReadOrMaskedRead(
2196  rewriter, loc, opOperand.get(), readType.getShape(),
2197  /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2198  /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
2199  vecOperands.push_back(read);
2200  }
2201 
2202  // Remap iterators from linalg to vector.
2203  SmallVector<Attribute> iterAttrs;
2204  auto iterators = linalgOp.getIteratorTypesArray();
2205  for (utils::IteratorType iter : iterators) {
2206  auto vecIter = iter == utils::IteratorType::parallel
2207  ? vector::IteratorType::parallel
2208  : vector::IteratorType::reduction;
2209  iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2210  }
2211 
2212  // Create contraction.
2213  Operation *contractOp = vector::ContractionOp::create(
2214  rewriter, loc, /*lhs=*/vecOperands[0],
2215  /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2216  linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2217  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2218 
2219  // Store result.
2220  Operation *write = createWriteOrMaskedWrite(
2221  rewriter, loc, contractOp->getResult(0), outOperand->get());
2222 
2223  // Finalize.
2224  if (!write->getResults().empty())
2225  newResults.push_back(write->getResult(0));
2226 
2227  return success();
2228 }
2229 
2230 namespace {
2231 enum class ConvOperationKind { Conv, Pool };
2232 } // namespace
2233 
2234 static bool isCastOfBlockArgument(Operation *op) {
2235  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2236  isa<BlockArgument>(op->getOperand(0));
2237 }
2238 
2239 // Returns the ConvOperationKind of the op using reduceOp of the generic
2240 // payload. If it is neither a convolution nor a pooling, it returns
2241 // std::nullopt.
2242 //
2243 // If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2244 // + yield) and the rhs is not used, then it is the body of a pooling.
2245 // If it is a convolution, check for a single `mul` predecessor. The `mul`
2246 // operands must be block arguments or extensions of block arguments.
2247 // Otherwise (pooling), check for one or zero `ext` predecessors. The `ext`
2248 // operands must be block arguments or extensions of block arguments.
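// Illustrative bodies (editorial addition, simplified):
//   convolution:  %m = arith.mulf %in, %filter : f32
//                 %a = arith.addf %out, %m : f32
//                 linalg.yield %a : f32
//   pooling:      %a = arith.maximumf %out, %in : f32
//                 linalg.yield %a : f32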
2249 static std::optional<ConvOperationKind>
2250 getConvOperationKind(Operation *reduceOp) {
2251  int numBlockArguments =
2252  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2253 
2254  switch (numBlockArguments) {
2255  case 1: {
2256  // Will be convolution if feeder is a MulOp.
2257  // A strength reduced version of MulOp for i1 type is AndOp which is also
2258  // supported. Otherwise, it can be pooling. This strength reduction logic
2259  // is in `buildBinaryFn` helper in the Linalg dialect.
2260  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2261  llvm::IsaPred<BlockArgument>);
2262  assert(feedValIt != reduceOp->operand_end() &&
2263  "Expected a non-block argument operand");
2264  Operation *feedOp = (*feedValIt).getDefiningOp();
2265  if (isCastOfBlockArgument(feedOp)) {
2266  return ConvOperationKind::Pool;
2267  }
2268 
2269  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2270  (isa<arith::AndIOp>(feedOp) &&
2271  feedOp->getResultTypes()[0].isInteger(1))) &&
2272  llvm::all_of(feedOp->getOperands(), [](Value v) {
2273  if (isa<BlockArgument>(v))
2274  return true;
2275  if (Operation *op = v.getDefiningOp())
2276  return isCastOfBlockArgument(op);
2277  return false;
2278  }))) {
2279  return std::nullopt;
2280  }
2281 
2282  return ConvOperationKind::Conv;
2283  }
2284  case 2:
2285  // Must be pooling
2286  return ConvOperationKind::Pool;
2287  default:
2288  return std::nullopt;
2289  }
2290 }
2291 
2292 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2293  switch (kind) {
2294  case vector::CombiningKind::ADD:
2295  case vector::CombiningKind::MAXNUMF:
2296  case vector::CombiningKind::MAXIMUMF:
2297  case vector::CombiningKind::MAXSI:
2298  case vector::CombiningKind::MAXUI:
2299  case vector::CombiningKind::MINNUMF:
2300  case vector::CombiningKind::MINIMUMF:
2301  case vector::CombiningKind::MINSI:
2302  case vector::CombiningKind::MINUI:
2303  return true;
2304  default:
2305  return false;
2306  }
2307 }
2308 
2309 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2310  auto getOperandType = [&](auto operand) {
2311  return dyn_cast<ShapedType>((operand->get()).getType());
2312  };
2313  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2314  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2315  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2316  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2317  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2318  // Note that this also ensures 2D and 3D convolutions are rejected.
2319  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2320  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2321  return failure();
2322 
2323  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2324  if (!reduceOp)
2325  return failure();
2326 
2327  auto maybeOper = getConvOperationKind(reduceOp);
2328  if (!maybeOper.has_value())
2329  return failure();
2330 
2331  auto maybeKind = getCombinerOpKind(reduceOp);
2332  // Typically a convolution will have an `Add` CombiningKind, but for i1 type it
2333  // can get strength reduced to `OR` which is also supported. This strength
2334  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2335  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2336  *maybeKind != vector::CombiningKind::OR) &&
2337  (*maybeOper != ConvOperationKind::Pool ||
2338  !isSupportedPoolKind(*maybeKind)))) {
2339  return failure();
2340  }
2341 
2342  auto rhsRank = rhsShapedType.getRank();
2343  if (*maybeOper == ConvOperationKind::Pool) {
2344  if (rhsRank != 1)
2345  return failure();
2346  } else {
2347  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2348  return failure();
2349  }
2350 
2351  return success();
2352 }
2353 
2354 static LogicalResult vectorizeLinalgOpPrecondition(
2355  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2356  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2357  // Tensors with a dimension of size 0 cannot be vectorized.
2358  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2359  return llvm::is_contained(linalgOp.getShape(&operand), 0);
2360  }))
2361  return failure();
2362  // Check API contract for input vector sizes.
2363  if (!inputVectorSizes.empty() &&
2364  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2365  inputVectorSizes)))
2366  return failure();
2367 
2368  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2369  linalgOp, flatten1DDepthwiseConv))) {
2370  LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2371  return failure();
2372  }
2373 
2374  SmallVector<CustomVectorizationPrecondition> customPreconditions;
2375 
2376  // Register CustomVectorizationPrecondition for extractOp.
2377  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2378 
2379  // All types in the body should be a supported element type for VectorType.
2380  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2381  // Check if any custom hook can vectorize the inner op.
2382  if (llvm::any_of(
2383  customPreconditions,
2384  [&](const CustomVectorizationPrecondition &customPrecondition) {
2385  return succeeded(
2386  customPrecondition(&innerOp, vectorizeNDExtract));
2387  })) {
2388  continue;
2389  }
2390  if (!llvm::all_of(innerOp.getOperandTypes(),
2391  VectorType::isValidElementType)) {
2392  return failure();
2393  }
2394  if (!llvm::all_of(innerOp.getResultTypes(),
2395  VectorType::isValidElementType)) {
2396  return failure();
2397  }
2398  }
2399  if (isElementwise(linalgOp))
2400  return success();
2401 
2402  // TODO: isaConvolutionOpInterface that can also infer from generic
2403  // features. But we will still need stride/dilation attributes that will be
2404  // annoying to reverse-engineer...
2405  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2406  return vectorizeConvOpPrecondition(linalgOp);
2407 
2408  // TODO: the common vector shape is equal to the static loop sizes only when
2409  // all indexing maps are projected permutations. For convs and stencils the
2410  // logic will need to evolve.
2411  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2412  LDBG() << "precondition failed: not projected permutations";
2413  return failure();
2414  }
2415  if (failed(reductionPreconditions(linalgOp))) {
2416  LDBG() << "precondition failed: reduction preconditions";
2417  return failure();
2418  }
2419  return success();
2420 }
2421 
2422 static LogicalResult
2423 vectorizePackOpPrecondition(linalg::PackOp packOp,
2424  ArrayRef<int64_t> inputVectorSizes) {
2425  auto padValue = packOp.getPaddingValue();
2426  Attribute cstAttr;
2427  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2428  LDBG() << "pad value is not constant: " << packOp;
2429  return failure();
2430  }
2431 
2432  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2433  bool satisfyEmptyCond = true;
2434  if (inputVectorSizes.empty()) {
2435  if (!packOp.getDestType().hasStaticShape() ||
2436  !packOp.getSourceType().hasStaticShape())
2437  satisfyEmptyCond = false;
2438  }
2439 
2440  if (!satisfyEmptyCond &&
2441  failed(vector::isValidMaskedInputVector(
2442  resultTensorShape.take_front(packOp.getSourceRank()),
2443  inputVectorSizes)))
2444  return failure();
2445 
2446  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2447  return !getConstantIntValue(v).has_value();
2448  })) {
2449  LDBG() << "inner_tiles must be constant: " << packOp;
2450  return failure();
2451  }
2452 
2453  return success();
2454 }
2455 
2456 static LogicalResult
2457 vectorizePadOpPrecondition(tensor::PadOp padOp,
2458  ArrayRef<int64_t> inputVectorSizes) {
2459  auto padValue = padOp.getConstantPaddingValue();
2460  if (!padValue) {
2461  LDBG() << "pad value is not constant: " << padOp;
2462  return failure();
2463  }
2464 
2465  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2466  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2467  inputVectorSizes)))
2468  return failure();
2469 
2470  // Padding with non-zero low pad values is not supported, unless the
2471  // corresponding result dim is 1, as supporting it in general would require
2472  // shifting the results to the right by the amount of low padding.
2473  // Low padding on a unit result dim is supported for the following reason:
2474  // the input size of that dim is then dynamically zero (the low pad and the
2475  // input size must sum to one), so the generated mask is all-false for that
2476  // dim (the lowering sets the mask size to the input dim size, which is zero
2477  // here) and the pad value is loaded instead - which is exactly what we want
2478  // in this case.
2479  // If the low pad is dynamically zero, the lowering is correct as well,
2480  // since no shifts are necessary.
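  // Example (editorial addition): low = [1, 0] is accepted when the first
  // result dim is 1, e.g.
  //   tensor.pad %src low[1, 0] high[0, 4] ... : tensor<?x4xf32> to tensor<1x8xf32>
  // whereas low = [1, 0] with a first result dim of, say, 8 would be rejected.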
2481  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2482  Value padValue = en.value();
2483  unsigned pos = en.index();
2484  std::optional<int64_t> pad = getConstantIntValue(padValue);
2485  return (!pad.has_value() || pad.value() != 0) &&
2486  resultTensorShape[pos] != 1;
2487  })) {
2488  LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2489  return failure();
2490  }
2491 
2492  return success();
2493 }
2494 
2495 /// Preconditions for scalable vectors.
2496 ///
2497 /// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2498 /// models the fact that in practice we would only make selected dimensions
2499 /// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2500 /// unconditionally - we are yet to identify meaningful conditions.
2501 static LogicalResult
2502 vectorizeScalableVectorPrecondition(Operation *op,
2503  ArrayRef<int64_t> inputVectorSizes,
2504  ArrayRef<bool> inputScalableVecDims) {
2505  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2506  "Number of input vector sizes and scalable dims doesn't match");
2507 
2508  size_t numOfScalableDims =
2509  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2510 
2511  if (numOfScalableDims == 0)
2512  return success();
2513 
2514  auto linalgOp = dyn_cast<LinalgOp>(op);
2515 
2516  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2517  // exception of UnpackOp for which there is a dedicated hook.
2518  if (!linalgOp) {
2519  return success(isa<linalg::UnPackOp>(op));
2520  }
2521 
2522  // Cond 2: There's been no need for more than 2 scalable dims so far
2523  if (numOfScalableDims > 2)
2524  return failure();
2525 
2526  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2527  // it matches one of the supported cases:
2528  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2529  // (*).
2530  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2531  // parallel dims.
2532  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2533  // dim.
2534  // The 2nd restriction above means that only Matmul-like Ops are supported
2535  // when 2 dims are scalable, e.g. :
2536  // * iterators = [parallel, parallel, reduction]
2537  // * scalable flags = [true, true, false]
2538  //
2539  // (*) Unit dims get folded away in practice.
2540  // TODO: Relax these conditions as good motivating examples are identified.
2541 
2542  // Find the first scalable flag.
2543  bool seenNonUnitParallel = false;
2544  auto iterators = linalgOp.getIteratorTypesArray();
2545  SmallVector<bool> scalableFlags(inputScalableVecDims);
2546  int64_t idx = scalableFlags.size() - 1;
2547  while (!scalableFlags[idx]) {
2548  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2549  seenNonUnitParallel |=
2550  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2551 
2552  iterators.pop_back();
2553  scalableFlags.pop_back();
2554  --idx;
2555  }
2556 
2557  // Analyze the iterator corresponding to the first scalable dim.
2558  switch (iterators.back()) {
2559  case utils::IteratorType::reduction: {
2560  // Check 3. above is met.
2561  if (iterators.size() != inputVectorSizes.size()) {
2562  LDBG() << "Non-trailing reduction dim requested for scalable "
2563  "vectorization";
2564  return failure();
2565  }
2566  if (isa<linalg::MatmulOp>(op)) {
2567  LDBG()
2568  << "Scalable vectorization of the reduction dim in Matmul-like ops "
2569  "is not supported";
2570  return failure();
2571  }
2572  break;
2573  }
2574  case utils::IteratorType::parallel: {
2575  // Check 1. and 2. above are met.
2576  if (seenNonUnitParallel) {
2577  LDBG() << "Inner parallel dim not requested for scalable "
2578  "vectorization";
2579  return failure();
2580  }
2581  break;
2582  }
2583  }
2584 
2585  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2586  // supported, for which we expect the following config:
2587  // * iterators = [parallel, parallel, reduction]
2588  // * scalable flags = [true, true, false]
2589  if (numOfScalableDims == 2) {
2590  // Disallow below case which breaks 3. above:
2591  // * iterators = [..., parallel, reduction]
2592  // * scalable flags = [..., true, true]
2593  if (iterators.back() == utils::IteratorType::reduction) {
2594  LDBG() << "Higher dim than the trailing reduction dim requested for "
2595  "scalable "
2596  "vectorizatio";
2597  return failure();
2598  }
2599  scalableFlags.pop_back();
2600  iterators.pop_back();
2601 
2602  if (!scalableFlags.back() ||
2603  (iterators.back() != utils::IteratorType::parallel))
2604  return failure();
2605  }
2606 
2607  // Cond 4: Only the following ops are supported in the
2608  // presence of scalable vectors
2609  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2610  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2611  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2612  hasReductionIterator(linalgOp));
2613 }
2614 
2615 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2616  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2617  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2618  bool flatten1DDepthwiseConv) {
2619 
2620  if (!hasVectorizationImpl(op))
2621  return failure();
2622 
2623  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2624  inputScalableVecDims)))
2625  return failure();
2626 
2627  return TypeSwitch<Operation *, LogicalResult>(op)
2628  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2629  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2630  vectorizeNDExtract,
2631  flatten1DDepthwiseConv);
2632  })
2633  .Case<tensor::PadOp>([&](auto padOp) {
2634  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2635  })
2636  .Case<linalg::PackOp>([&](auto packOp) {
2637  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2638  })
2639  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2640  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2641  })
2642  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2643  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2644  })
2645  .Default([](auto) { return failure(); });
2646 }
2647 
2648 /// Converts affine.apply Ops to arithmetic operations.
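/// For instance (editorial sketch), inside the payload body
///   %0 = affine.apply affine_map<(d0) -> (d0 + 7)>(%i)
/// is expanded to roughly
///   %c7 = arith.constant 7 : index
///   %0 = arith.addi %i, %c7 : index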
2649 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2650  OpBuilder::InsertionGuard g(rewriter);
2651  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2652 
2653  for (auto op : make_early_inc_range(toReplace)) {
2654  rewriter.setInsertionPoint(op);
2655  auto expanded = affine::expandAffineExpr(
2656  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2657  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2658  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2659  rewriter.replaceOp(op, expanded);
2660  }
2661 }
2662 
2663 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2664  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2665  tensor::InsertSliceOp>(op);
2666 }
2667 
2668 FailureOr<VectorizationResult> mlir::linalg::vectorize(
2669  RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2670  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2671  bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2672  bool createNamedContraction) {
2673  LDBG() << "Attempting to vectorize: " << *op;
2674  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2675  LDBG() << "Input scalable vector dims: "
2676  << llvm::interleaved(inputScalableVecDims);
2677 
2678  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2679  vectorizeNDExtract,
2680  flatten1DDepthwiseConv))) {
2681  LDBG() << "Vectorization pre-conditions failed";
2682  return failure();
2683  }
2684 
2685  // Initialize vectorization state.
2686  VectorizationState state(rewriter);
2687  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2688  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2689  inputScalableVecDims,
2690  assumeDynamicDimsMatchVecSizes))) {
2691  LDBG() << "Vectorization state couldn't be initialized";
2692  return failure();
2693  }
2694  }
2695 
2696  SmallVector<Value> results;
2697  auto vectorizeResult =
2698  TypeSwitch<Operation *, LogicalResult>(op)
2699  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2700  // TODO: isaConvolutionOpInterface that can also infer from
2701  // generic features. Will require stride/dilation attributes
2702  // inference.
2703  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2704  FailureOr<Operation *> convOr = vectorizeConvolution(
2705  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2706  flatten1DDepthwiseConv);
2707  if (succeeded(convOr)) {
2708  llvm::append_range(results, (*convOr)->getResults());
2709  return success();
2710  }
2711 
2712  LDBG() << "Unsupported convolution can't be vectorized.";
2713  return failure();
2714  }
2715 
2716  if (createNamedContraction &&
2717  isa<ContractionOpInterface>(linalgOp.getOperation()))
2718  return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2719  results);
2720 
2721  LDBG()
2722  << "Vectorize generic by broadcasting to the canonical vector "
2723  "shape";
2724 
2725  // Pre-process before proceeding.
2726  convertAffineApply(rewriter, linalgOp);
2727 
2728  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2729  // to 'OpBuilder' when it is passed over to some methods like
2730  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2731  // erase an op within these methods, the actual rewriter won't be
2732  // notified and we will end up with read-after-free issues!
2733  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2734  })
2735  .Case<tensor::PadOp>([&](auto padOp) {
2736  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2737  results);
2738  })
2739  .Case<linalg::PackOp>([&](auto packOp) {
2740  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2741  results);
2742  })
2743  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2744  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2745  inputVectorSizes,
2746  inputScalableVecDims, results);
2747  })
2748  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2749  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2750  results);
2751  })
2752  .Default([](auto) { return failure(); });
2753 
2754  if (failed(vectorizeResult)) {
2755  LDBG() << "Vectorization failed";
2756  return failure();
2757  }
2758 
2759  return VectorizationResult{results};
2760 }
2761 
2762 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2763  memref::CopyOp copyOp) {
2764  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2765  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2766  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2767  return failure();
2768 
2769  auto srcElementType = getElementTypeOrSelf(srcType);
2770  auto dstElementType = getElementTypeOrSelf(dstType);
2771  if (!VectorType::isValidElementType(srcElementType) ||
2772  !VectorType::isValidElementType(dstElementType))
2773  return failure();
2774 
2775  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2776  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2777 
2778  Location loc = copyOp->getLoc();
2779  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2780  SmallVector<Value> indices(srcType.getRank(), zero);
2781 
2782  Value readValue = vector::TransferReadOp::create(
2783  rewriter, loc, readType, copyOp.getSource(), indices,
2784  /*padding=*/std::nullopt,
2785  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2786  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2787  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2788  ArrayRef<int64_t>());
2789  readValue =
2790  vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2791  }
2792  Operation *writeValue = vector::TransferWriteOp::create(
2793  rewriter, loc, readValue, copyOp.getTarget(), indices,
2794  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2795  rewriter.replaceOp(copyOp, writeValue->getResults());
2796  return success();
2797 }
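For orientation, a minimal sketch of the rewrite performed by the copy vectorization above, assuming a fully static 4x8xf32 copy (value names and the padding operand are illustrative; attributes such as in_bounds are elided):

```mlir
// memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
// roughly becomes:
%c0 = arith.constant 0 : index
%v = vector.transfer_read %src[%c0, %c0], %pad : memref<4x8xf32>, vector<4x8xf32>
vector.transfer_write %v, %dst[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>
```

For rank-0 memrefs the read result is additionally extracted and re-broadcast to the write type, as handled in the code above.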
2798 
2799 //----------------------------------------------------------------------------//
2800 // Misc. vectorization patterns.
2801 //----------------------------------------------------------------------------//
2802 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2803 /// given operation type OpTy.
2804 template <typename OpTy>
2805 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2806  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2807 
2808  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2809  PatternRewriter &rewriter) const final {
2810  bool changed = false;
2811  // Collect users in a vector first, because some may be replaced/removed while iterating.
2812  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2813  if (auto op = dyn_cast<OpTy>(user))
2814  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2815  return success(changed);
2816  }
2817 
2818 protected:
2819  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2820  tensor::PadOp padOp, OpTy op) const = 0;
2821 };
2822 
2823 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2824 /// ```
2825 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2826 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2827 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2828 /// ```
2829 /// is rewritten to:
2830 /// ```
2831 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2832 /// {in_bounds = [true, true]}
2833 /// : tensor<?x?xf32>, vector<17x5xf32>
2834 /// ```
2835 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2836 /// sure that the original padding value %cst was never used.
2837 ///
2838 /// This rewrite is possible if:
2839 /// - `xferOp` has no out-of-bounds dims or mask.
2840 /// - Low padding is static 0.
2841 /// - Single, scalar padding value.
2843  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2845  vector::TransferReadOp>::VectorizePadOpUserPattern;
2846 
2847  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2848  vector::TransferReadOp xferOp) const override {
2849  // Low padding must be static 0.
2850  if (!padOp.hasZeroLowPad())
2851  return failure();
2852  // Pad value must be a constant.
2853  auto padValue = padOp.getConstantPaddingValue();
2854  if (!padValue)
2855  return failure();
2856  // Padding value of existing `xferOp` is unused.
2857  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2858  return failure();
2859 
2860  rewriter.modifyOpInPlace(xferOp, [&]() {
2861  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2862  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2863  rewriter.getBoolArrayAttr(inBounds));
2864  xferOp.getBaseMutable().assign(padOp.getSource());
2865  xferOp.getPaddingMutable().assign(padValue);
2866  });
2867 
2868  return success();
2869  }
2870 };
2871 
2872 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2873 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2874 /// value, where the same amount of padding is immediately removed again after
2875 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2876 /// tensor value and apply out-of-bounds masking. E.g.:
2877 /// ```
2878 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2879 /// : tensor<...> to tensor<?x?xf32>
2880 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2881 /// %2 = vector.transfer_write %vec, %1[...]
2882 /// : vector<17x5xf32>, tensor<17x5xf32>
2883 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2884 /// : tensor<17x5xf32> to tensor<?x?xf32>
2885 /// ```
2886 /// is rewritten to:
2887 /// ```
2888 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2889 /// : tensor<...> to tensor<?x?xf32>
2890 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2891 /// tensor<?x?xf32>
2892 /// ```
2893 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2894 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2895 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2896 /// from %r's old dimensions.
2897 ///
2898 /// This rewrite is possible if:
2899 /// - Low padding is static 0.
2900 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2901 /// ExtractSliceOp trims the same amount of padding that was added
2902 /// beforehand.
2903 /// - Single, scalar padding value.
2904 struct PadOpVectorizationWithTransferWritePattern
2905  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2906  using VectorizePadOpUserPattern<
2907  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2908 
2909  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2910  vector::TransferWriteOp xferOp) const override {
2911  // TODO: support 0-d corner case.
2912  if (xferOp.getTransferRank() == 0)
2913  return failure();
2914 
2915  // Low padding must be static 0.
2916  if (!padOp.hasZeroLowPad())
2917  return failure();
2918  // Pad value must be a constant.
2919  auto padValue = padOp.getConstantPaddingValue();
2920  if (!padValue)
2921  return failure();
2922  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2923  if (!xferOp->hasOneUse())
2924  return failure();
2925  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2926  if (!trimPadding)
2927  return failure();
2928  // Only static zero offsets supported when trimming padding.
2929  if (!trimPadding.hasZeroOffset())
2930  return failure();
2931  // trimPadding must remove the amount of padding that was added earlier.
2932  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2933  return failure();
2934 
2935  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2936  rewriter.setInsertionPoint(xferOp);
2937 
2938  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2939  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2940  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2941  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2942  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2943  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2944 
2945  return success();
2946  }
2947 
2948  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2949  /// i.e., same dimensions.
2950  ///
2951  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2952  /// dimensions, this function tries to infer the (static) tensor size by
2953  /// looking at the defining op and utilizing op-specific knowledge.
2954  ///
2955  /// This is a conservative analysis. In case equal tensor sizes cannot be
2956  /// proven statically, this analysis returns `false` even though the tensor
2957  /// sizes may turn out to be equal at runtime.
2958  bool hasSameTensorSize(Value beforePadding,
2959  tensor::ExtractSliceOp afterTrimming) const {
2960  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2961  // result and CastOp operand.
2962  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2963  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2964  return true;
2965 
2966  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2967  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2968  // Only RankedTensorType supported.
2969  if (!t1 || !t2)
2970  return false;
2971  // Rank of both values must be the same.
2972  if (t1.getRank() != t2.getRank())
2973  return false;
2974 
2975  // All static dimensions must be the same. Mixed cases (e.g., dimension
2976  // static in `t1` but dynamic in `t2`) are not supported.
2977  for (unsigned i = 0; i < t1.getRank(); ++i) {
2978  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2979  return false;
2980  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2981  return false;
2982  }
2983 
2984  // Nothing more to check if all dimensions are static.
2985  if (t1.getNumDynamicDims() == 0)
2986  return true;
2987 
2988  // All dynamic sizes must be the same. The only supported case at the
2989  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2990  // thereof).
2991 
2992  // Apart from CastOp, only ExtractSliceOp is supported.
2993  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2994  if (!beforeSlice)
2995  return false;
2996 
2997  assert(static_cast<size_t>(t1.getRank()) ==
2998  beforeSlice.getMixedSizes().size());
2999  assert(static_cast<size_t>(t2.getRank()) ==
3000  afterTrimming.getMixedSizes().size());
3001 
3002  for (unsigned i = 0; i < t1.getRank(); ++i) {
3003  // Skip static dimensions.
3004  if (!t1.isDynamicDim(i))
3005  continue;
3006  auto size1 = beforeSlice.getMixedSizes()[i];
3007  auto size2 = afterTrimming.getMixedSizes()[i];
3008 
3009  // Case 1: Same value or same constant int.
3010  if (isEqualConstantIntOrValue(size1, size2))
3011  continue;
3012 
3013  // Other cases: Take a deeper look at defining ops of values.
3014  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3015  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3016  if (!v1 || !v2)
3017  return false;
3018 
3019  // Case 2: Both values are identical AffineMinOps. (Should not happen if
3020  // CSE is run.)
3021  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3022  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3023  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3024  minOp1.getOperands() == minOp2.getOperands())
3025  continue;
3026 
3027  // Add additional cases as needed.
3028  }
3029 
3030  // All tests passed.
3031  return true;
3032  }
3033 };
3034 
3035 /// Returns the effective Pad value for the input op, provided it's a scalar.
3036 ///
3037 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3038 /// this Op performs padding, retrieve the padding value provided that it's
3039 /// a scalar and static/fixed for all the padded values. Returns an empty value
3040 /// otherwise.
3041 ///
3042 /// TODO: This is used twice (when checking vectorization pre-conditions and
3043 /// when vectorizing). Cache results instead of re-running.
3044 static Value getStaticPadVal(Operation *op) {
3045  if (!op)
3046  return {};
3047 
3048  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
3049  // being broadcast, provided that it's a scalar.
3050  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3051  auto source = bcast.getSource();
3052  if (llvm::dyn_cast<VectorType>(source.getType()))
3053  return {};
3054 
3055  return source;
3056  }
3057 
3058  // 2. linalg.fill - use the scalar input value that used to fill the output
3059  // tensor.
3060  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3061  return fill.getInputs()[0];
3062  }
3063 
3064  // 3. tensor.generate - we can't guarantee that the value is fixed
3065  // without further analysis, so bail out.
3066  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3067  return {};
3068  }
3069 
3070  // 4. vector.transfer_write - inspect the vector that is written. If it
3071  // contains a single value that has been broadcast (e.g. via
3072  // vector.broadcast), extract that value; fail otherwise.
3073  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3074  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3075 
3076  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3077  // than the input tensor, then, provided it's constant, we'll extract the
3078  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3079  // TODO: Clarify the semantics when the input tensor is larger than the
3080  // destination.
3081  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3082  return getStaticPadVal(slice.getDest().getDefiningOp());
3083 
3084  return {};
3085 }
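As a concrete illustration of cases 2 and 5 above (the IR and names are assumed, not taken from a test): when the insert_slice destination is produced by a linalg.fill, the fill scalar is the effective pad value.

```mlir
%filled   = linalg.fill ins(%cst : f32) outs(%dest : tensor<9x8x7xf32>)
            -> tensor<9x8x7xf32>
%inserted = tensor.insert_slice %src into %filled[0, 0, 0] [1, 2, 3] [1, 1, 1]
            : tensor<1x2x3xf32> into tensor<9x8x7xf32>
```

Calling getStaticPadVal on the insert_slice walks to the defining op of its destination (case 5) and returns %cst from the fill (case 2).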
3086 
3087 static LogicalResult
3088 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3089  ArrayRef<int64_t> inputVectorSizes,
3090  SmallVectorImpl<Value> &newResults) {
3091  // TODO: Introduce a parent class that will handle the insertion point update.
3092  OpBuilder::InsertionGuard g(rewriter);
3093  rewriter.setInsertionPoint(sliceOp);
3094 
3095  TypedValue<RankedTensorType> source = sliceOp.getSource();
3096  auto sourceType = source.getType();
3097  auto resultType = sliceOp.getResultType();
3098 
3099  Value padValue = getStaticPadVal(sliceOp);
3100 
3101  if (!padValue) {
3102  auto elemType = sourceType.getElementType();
3103  padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3104  rewriter.getZeroAttr(elemType));
3105  }
3106 
3107  // 2. Get the vector shape
3108  SmallVector<int64_t> vecShape;
3109  size_t rankDiff = resultType.getRank() - sourceType.getRank();
3110  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3111  if (!inputVectorSizes.empty()) {
3112  vecShape.push_back(inputVectorSizes[i]);
3113  } else if (!sourceType.isDynamicDim(i)) {
3114  vecShape.push_back(sourceType.getDimSize(i));
3115  } else if (!resultType.isDynamicDim(i)) {
3116  // Source shape is not statically known, but result shape is.
3117  // Vectorize with size of result shape. This may be larger than the
3118  // source size.
3119  // FIXME: Using rankDiff implies that the source tensor is inserted at
3120  // the end of the destination tensor. However, that's not required.
3121  vecShape.push_back(resultType.getDimSize(rankDiff + i));
3122  } else {
3123  // Neither the source nor the result dim is static. Cannot vectorize
3124  // the copy.
3125  return failure();
3126  }
3127  }
3128  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3129 
3130  // 3. Generate TransferReadOp + TransferWriteOp
3131  auto loc = sliceOp.getLoc();
3132 
3133  // Create read
3134  SmallVector<Value> readIndices(
3135  vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3136  Value read = vector::createReadOrMaskedRead(
3137  rewriter, loc, source, vecType.getShape(), padValue,
3138  /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
3139  /*inputScalableVecSizes=*/{});
3140 
3141  // Create write
3142  auto writeIndices =
3143  getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3144  Operation *write =
3145  createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3146  writeIndices, inputVectorSizes.empty());
3147 
3148  // 4. Finalize
3149  newResults.push_back(write->getResult(0));
3150 
3151  return success();
3152 }
3153 
3154 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3155 /// ```
3156 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3157 /// %r = tensor.insert_slice %0
3158 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3159 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3160 /// ```
3161 /// is rewritten to:
3162 /// ```
3163 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
3164 /// : tensor<?x?xf32>, vector<17x5xf32>
3165 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3166 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3167 /// ```
3168 ///
3169 /// This rewrite is possible if:
3170 /// - Low padding is static 0.
3171 /// - `padOp` result shape is static.
3172 /// - The entire padded tensor is inserted.
3173 /// (Implies that sizes of `insertOp` are all static.)
3174 /// - Only unit strides in `insertOp`.
3175 /// - Single, scalar padding value.
3176 /// - `padOp` result not used as destination.
3177 struct PadOpVectorizationWithInsertSlicePattern
3178  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3179  using VectorizePadOpUserPattern<
3180  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3181 
3182  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3183  tensor::InsertSliceOp insertOp) const override {
3184  // Low padding must be static 0.
3185  if (!padOp.hasZeroLowPad())
3186  return failure();
3187  // Only unit stride supported.
3188  if (!insertOp.hasUnitStride())
3189  return failure();
3190  // Pad value must be a constant.
3191  auto padValue = padOp.getConstantPaddingValue();
3192  if (!padValue)
3193  return failure();
3194  // Dynamic shapes not supported.
3195  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3196  return failure();
3197  // Pad result not used as destination.
3198  if (insertOp.getDest() == padOp.getResult())
3199  return failure();
3200 
3201  auto vecType = VectorType::get(padOp.getType().getShape(),
3202  padOp.getType().getElementType());
3203  unsigned vecRank = vecType.getRank();
3204  unsigned tensorRank = insertOp.getType().getRank();
3205 
3206  // Check if sizes match: Insert the entire tensor into most minor dims.
3207  // (No permutations allowed.)
3208  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3209  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3210  if (!llvm::all_of(
3211  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3212  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3213  }))
3214  return failure();
3215 
3216  // Insert the TransferReadOp and TransferWriteOp at the position of the
3217  // InsertSliceOp.
3218  rewriter.setInsertionPoint(insertOp);
3219 
3220  // Generate TransferReadOp: Read entire source tensor and add high
3221  // padding.
3222  SmallVector<Value> readIndices(
3223  vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3224  auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3225  vecType, padOp.getSource(),
3226  readIndices, padValue);
3227 
3228  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3229  // specified offsets. Write is fully in-bounds because a InsertSliceOp's
3230  // source must fit into the destination at the specified offsets.
3231  auto writeIndices = getValueOrCreateConstantIndexOp(
3232  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3233  SmallVector<bool> inBounds(vecRank, true);
3234  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3235  insertOp, read, insertOp.getDest(), writeIndices,
3236  ArrayRef<bool>{inBounds});
3237 
3238  return success();
3239  }
3240 };
3241 
3242 void mlir::linalg::populatePadOpVectorizationPatterns(
3243  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3244  patterns.add<PadOpVectorizationWithTransferReadPattern,
3245  PadOpVectorizationWithTransferWritePattern,
3246  PadOpVectorizationWithInsertSlicePattern>(
3247  patterns.getContext(), baseBenefit.getBenefit() + 1);
3248 }
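A hedged sketch of how these pad-op patterns are typically driven (the surrounding pass boilerplate is assumed, and the greedy driver entry point has been renamed across MLIR releases):

```cpp
// Sketch only; `funcOp` and `ctx` are assumed to exist in the caller.
RewritePatternSet patterns(ctx);
linalg::populatePadOpVectorizationPatterns(patterns);
// Spelled applyPatternsAndFoldGreedily in older releases.
if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
  return failure();
```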
3249 
3250 //----------------------------------------------------------------------------//
3251 // Forwarding patterns
3252 //----------------------------------------------------------------------------//
3253 
3254 /// Check whether there is any interleaved use of any `values` between
3255 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3256 /// is in a different block.
3257 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3258  ValueRange values) {
3259  if (firstOp->getBlock() != secondOp->getBlock() ||
3260  !firstOp->isBeforeInBlock(secondOp)) {
3261  LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3262  << ", second op: " << *secondOp;
3263  return true;
3264  }
3265  for (auto v : values) {
3266  for (auto &u : v.getUses()) {
3267  Operation *owner = u.getOwner();
3268  if (owner == firstOp || owner == secondOp)
3269  continue;
3270  // TODO: this is too conservative, use dominance info in the future.
3271  if (owner->getBlock() == firstOp->getBlock() &&
3272  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3273  continue;
3274  LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3275  << ", second op: " << *secondOp;
3276  return true;
3277  }
3278  }
3279  return false;
3280 }
3281 
3282 /// Return the unique subview use of `v` if it is indeed unique, null
3283 /// otherwise.
3284 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3285  memref::SubViewOp subViewOp;
3286  for (auto &u : v.getUses()) {
3287  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3288  if (subViewOp)
3289  return memref::SubViewOp();
3290  subViewOp = newSubViewOp;
3291  }
3292  }
3293  return subViewOp;
3294 }
3295 
3296 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3297 /// when available.
3298 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3299  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3300 
3301  // TODO: support mask.
3302  if (xferOp.getMask())
3303  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3304 
3305  // Transfer into `view`.
3306  Value viewOrAlloc = xferOp.getBase();
3307  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3308  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3309  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3310 
3311  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3312  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3313  if (!subViewOp)
3314  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3315  Value subView = subViewOp.getResult();
3316 
3317  // Find the copy into `subView` without interleaved uses.
3318  memref::CopyOp copyOp;
3319  for (auto &u : subView.getUses()) {
3320  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3321  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3322  if (newCopyOp.getTarget() != subView)
3323  continue;
3324  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3325  continue;
3326  copyOp = newCopyOp;
3327  break;
3328  }
3329  }
3330  if (!copyOp)
3331  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3332 
3333  // Find the fill into `viewOrAlloc` without interleaved uses before the
3334  // copy.
3335  FillOp maybeFillOp;
3336  for (auto &u : viewOrAlloc.getUses()) {
3337  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3338  assert(isa<MemRefType>(newFillOp.output().getType()));
3339  if (newFillOp.output() != viewOrAlloc)
3340  continue;
3341  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3342  continue;
3343  maybeFillOp = newFillOp;
3344  break;
3345  }
3346  }
3347  // Ensure padding matches.
3348  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3349  return rewriter.notifyMatchFailure(xferOp,
3350  "padding value does not match fill");
3351 
3352  // `in` is the source that memref.copy reads from; forward the read to it.
3353  Value in = copyOp.getSource();
3354 
3355  // memref.copy + linalg.fill can be used to create a padded local buffer.
3356  // The `masked` attribute is only valid on this padded buffer.
3357  // When forwarding to vector.transfer_read, the attribute must be reset
3358  // conservatively.
3359  auto vectorType = xferOp.getVectorType();
3360  Value res = vector::TransferReadOp::create(
3361  rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3362  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3363  rewriter.getBoolArrayAttr(
3364  SmallVector<bool>(vectorType.getRank(), false)));
3365 
3366  if (maybeFillOp)
3367  rewriter.eraseOp(maybeFillOp);
3368  rewriter.eraseOp(copyOp);
3369  rewriter.replaceOp(xferOp, res);
3370 
3371  return success();
3372 }
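To make the forwarding concrete, a schematic example of the IR this pattern matches and produces (types are abbreviated and names assumed):

```mlir
// Matched (schematically):
//   linalg.fill ins(%pad : f32) outs(%buf : memref<32x32xf32>)
//   memref.copy %in, %sv    // %sv is the unique subview of %buf
//   %v = vector.transfer_read %buf[%c0, %c0], %pad
//          : memref<32x32xf32>, vector<32x32xf32>
// Rewritten to a single read from the copy source, conservatively out-of-bounds:
%v = vector.transfer_read %in[%c0, %c0], %pad {in_bounds = [false, false]}
       : memref<?x?xf32>, vector<32x32xf32>
```

The fill and the copy are erased once the read has been forwarded.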
3373 
3374 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3375 /// when available.
3376 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3377  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3378  // TODO: support mask.
3379  if (xferOp.getMask())
3380  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3381 
3382  // Transfer into `viewOrAlloc`.
3383  Value viewOrAlloc = xferOp.getBase();
3384  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3385  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3386  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3387 
3388  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3389  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3390  if (!subViewOp)
3391  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3392  Value subView = subViewOp.getResult();
3393 
3394  // Find the copy from `subView` without interleaved uses.
3395  memref::CopyOp copyOp;
3396  for (auto &u : subViewOp.getResult().getUses()) {
3397  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3398  if (newCopyOp.getSource() != subView)
3399  continue;
3400  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3401  continue;
3402  copyOp = newCopyOp;
3403  break;
3404  }
3405  }
3406  if (!copyOp)
3407  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3408 
3409  // `out` is the buffer that memref.copy writes into; forward the write to it.
3410  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3411  Value out = copyOp.getTarget();
3412 
3413  // Forward vector.transfer into copy.
3414  // memref.copy + linalg.fill can be used to create a padded local buffer.
3415  // The `masked` attribute is only valid on this padded buffer.
3416  // When forwarding to vector.transfer_write, the attribute must be reset
3417  // conservatively.
3418  auto vector = xferOp.getVector();
3419  vector::TransferWriteOp::create(
3420  rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3421  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3422  rewriter.getBoolArrayAttr(SmallVector<bool>(
3423  dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3424 
3425  rewriter.eraseOp(copyOp);
3426  rewriter.eraseOp(xferOp);
3427 
3428  return success();
3429 }
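The write-side forwarding is symmetric (schematic, names assumed): a transfer_write into the padded buffer followed by a copy out of its subview collapses into a single write.

```mlir
// Matched (schematically):
//   vector.transfer_write %v, %buf[%c0, %c0] : vector<32x32xf32>, memref<32x32xf32>
//   memref.copy %sv, %out   // %sv is the unique subview of %buf
// Rewritten to:
vector.transfer_write %v, %out[%c0, %c0] {in_bounds = [false, false]}
    : vector<32x32xf32>, memref<?x?xf32>
```

Both the original write and the copy are erased.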
3430 
3431 //===----------------------------------------------------------------------===//
3432 // Convolution vectorization patterns
3433 //===----------------------------------------------------------------------===//
3434 
3435 template <int N>
3436 static void bindShapeDims(ShapedType shapedType) {}
3437 
3438 template <int N, typename IntTy, typename... IntTy2>
3439 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3440  val = shapedType.getShape()[N];
3441  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3442 }
3443 
3444 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3445 template <typename... IntTy>
3446 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3447  bindShapeDims<0>(shapedType, vals...);
3448 }
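A short usage sketch of the helper above (`resShapedType` stands for any ShapedType value in scope):

```cpp
// Binds the leading dimensions of the shape to the given integer lvalues.
int64_t n, w, c;
bindShapeDims(resShapedType, n, w, c); // n = dim 0, w = dim 1, c = dim 2
bindShapeDims(resShapedType, n, w);    // binds only the first two dimensions
```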
3449 
3450 namespace {
3451 /// Generate a vector implementation for either:
3452 /// ```
3453 /// Op def: ( w, kw )
3454 /// Iters: ({Par(), Red()})
3455 /// Layout: {{w + kw}, {kw}, {w}}
3456 /// ```
3457 /// kw is unrolled.
3458 ///
3459 /// or
3460 ///
3461 /// ```
3462 /// Op def: ( n, w, c, kw, f )
3463 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3464 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3465 /// ```
3466 /// kw is unrolled, w is unrolled iff dilationW > 1.
3467 ///
3468 /// or
3469 ///
3470 /// ```
3471 /// Op def: ( n, c, w, f, kw )
3472 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3473 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3474 /// ```
3475 /// kw is unrolled, w is unrolled iff dilationW > 1.
3476 ///
3477 /// or
3478 ///
3479 /// ```
3480 /// Op def: ( n, w, c, kw )
3481 /// Iters: ({Par(), Par(), Par(), Red()})
3482 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3483 /// ```
3484 /// kw is unrolled, w is unrolled iff dilationW > 1.
3485 struct Conv1DGenerator
3486  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3487  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3488  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3489 
3490  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3491  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3492  resShaped = linalgOp.getDpsInitOperand(0)->get();
3493  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3494  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3495  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3496 
3497  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3498  redOp = reduceOp->getName().getIdentifier();
3499 
3500  setConvOperationKind(reduceOp);
3501 
3502  auto maybeKind = getCombinerOpKind(reduceOp);
3503  reductionKind = maybeKind.value();
3504 
3505  // The ConvolutionOpInterface gives us guarantees of existence for
3506  // strides/dilations. However, we do not need to rely on those, we can
3507  // simply use them if present, otherwise use the default and let the generic
3508  // conv. matcher in the ConvGenerator succeed or fail.
3509  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3510  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3511  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3512  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3513  }
3514 
3515  /// Generate a vector implementation for:
3516  /// ```
3517  /// Op def: ( w, kw )
3518  /// Iters: ({Par(), Red()})
3519  /// Layout: {{w + kw}, {kw}, {w}}
3520  /// ```
3521  /// kw is always unrolled.
3522  ///
3523  /// or
3524  ///
3525  /// ```
3526  /// Op def: ( n, w, c, kw, f )
3527  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3528  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3529  /// ```
3530  /// kw is always unrolled.
3531  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3532  /// > 1.
3533  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3534  int64_t nSize, wSize, cSize, kwSize, fSize;
3535  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3536  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3537  switch (conv1DOpOrder) {
3538  case Conv1DOpOrder::W:
3539  // Initialize unused dimensions
3540  nSize = fSize = cSize = 0;
3541  // out{W}
3542  bindShapeDims(resShapedType, wSize);
3543  // kernel{kw}
3544  bindShapeDims(rhsShapedType, kwSize);
3545  lhsShape = {// iw = ow + kw - 1
3546  // (i.e. 16 convolved with 3 -> 14)
3547  (wSize + kwSize - 1)};
3548  rhsShape = {kwSize};
3549  resShape = {wSize};
3550  break;
3551  case Conv1DOpOrder::Nwc:
3552  // out{n, w, f}
3553  bindShapeDims(resShapedType, nSize, wSize, fSize);
3554  switch (oper) {
3555  case ConvOperationKind::Conv:
3556  // kernel{kw, c, f}
3557  bindShapeDims(rhsShapedType, kwSize, cSize);
3558  break;
3559  case ConvOperationKind::Pool:
3560  // kernel{kw}
3561  bindShapeDims(rhsShapedType, kwSize);
3562  cSize = fSize;
3563  break;
3564  }
3565  lhsShape = {nSize,
3566  // iw = ow * sw + kw * dw - 1
3567  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3568  // Perform the proper inclusive -> exclusive -> inclusive.
3569  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3570  1,
3571  cSize};
3572  switch (oper) {
3573  case ConvOperationKind::Conv:
3574  rhsShape = {kwSize, cSize, fSize};
3575  break;
3576  case ConvOperationKind::Pool:
3577  rhsShape = {kwSize};
3578  break;
3579  }
3580  resShape = {nSize, wSize, fSize};
3581  break;
3582  case Conv1DOpOrder::Ncw:
3583  // out{n, f, w}
3584  bindShapeDims(resShapedType, nSize, fSize, wSize);
3585  switch (oper) {
3586  case ConvOperationKind::Conv:
3587  // kernel{f, c, kw}
3588  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3589  break;
3590  case ConvOperationKind::Pool:
3591  // kernel{kw}
3592  bindShapeDims(rhsShapedType, kwSize);
3593  cSize = fSize;
3594  break;
3595  }
3596  lhsShape = {nSize, cSize,
3597  // iw = ow * sw + kw * dw - 1
3598  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3599  // Perform the proper inclusive -> exclusive -> inclusive.
3600  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3601  1};
3602  switch (oper) {
3603  case ConvOperationKind::Conv:
3604  rhsShape = {fSize, cSize, kwSize};
3605  break;
3606  case ConvOperationKind::Pool:
3607  rhsShape = {kwSize};
3608  break;
3609  }
3610  resShape = {nSize, fSize, wSize};
3611  break;
3612  }
3613 
3614  vector::TransferWriteOp write;
3615  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3616 
3617  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3618  // When strideW == 1, we can batch the contiguous loads and avoid
3619  // unrolling
3620  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3621 
3622  Type lhsEltType = lhsShapedType.getElementType();
3623  Type rhsEltType = rhsShapedType.getElementType();
3624  Type resEltType = resShapedType.getElementType();
3625  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3626  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3627  auto resType = VectorType::get(resShape, resEltType);
3628  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3629  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3630  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3631  SmallVector<Value> resPadding(resShape.size(), zero);
3632 
3633  // Read the whole lhs, rhs and res in one shot (with zero padding).
3634  Value lhs = vector::TransferReadOp::create(
3635  rewriter, loc, lhsType, lhsShaped, lhsPadding,
3636  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3637  // This is needed only for Conv.
3638  Value rhs = nullptr;
3639  if (oper == ConvOperationKind::Conv)
3640  rhs = vector::TransferReadOp::create(
3641  rewriter, loc, rhsType, rhsShaped, rhsPadding,
3642  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3643  Value res = vector::TransferReadOp::create(
3644  rewriter, loc, resType, resShaped, resPadding,
3645  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3646 
3647  // The base vectorization case for channeled convolution is input:
3648  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3649  // vectorization case, we do pre transpose on input, weight, and output.
3650  switch (conv1DOpOrder) {
3651  case Conv1DOpOrder::W:
3652  case Conv1DOpOrder::Nwc:
3653  // Base case, so no transposes necessary.
3654  break;
3655  case Conv1DOpOrder::Ncw: {
3656  // To match base vectorization case, we pre-transpose current case.
3657  // ncw -> nwc
3658  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3659  lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3660  // fcw -> wcf
3661  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3662 
3663  // This is needed only for Conv.
3664  if (oper == ConvOperationKind::Conv)
3665  rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3666  // nfw -> nwf
3667  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3668  res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3669  break;
3670  }
3671  }
3672 
3673  //===------------------------------------------------------------------===//
3674  // Begin vector-only rewrite part
3675  //===------------------------------------------------------------------===//
3676  // Unroll along kw and read slices of lhs and rhs.
3677  SmallVector<Value> lhsVals, rhsVals, resVals;
3678  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3679  kwSize, strideW, dilationW, wSizeStep,
3680  isSingleChanneled);
3681  // Do not do for pooling.
3682  if (oper == ConvOperationKind::Conv)
3683  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3684  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3685  wSizeStep, isSingleChanneled);
3686 
3687  auto linearIndex = [&](int64_t kw, int64_t w) {
3688  return kw * (wSize / wSizeStep) + w;
3689  };
3690 
3691  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3692  // or perform outerproduct for non-channeled convolution or perform simple
3693  // arith operation for pooling
3694  for (int64_t kw = 0; kw < kwSize; ++kw) {
3695  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3696  switch (oper) {
3697  case ConvOperationKind::Conv:
3698  if (isSingleChanneled) {
3699  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3700  lhsVals[linearIndex(kw, w)],
3701  rhsVals[kw], resVals[w]);
3702  } else {
3703  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3704  lhsVals[linearIndex(kw, w)],
3705  rhsVals[kw], resVals[w]);
3706  }
3707  break;
3708  case ConvOperationKind::Pool:
3709  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3710  resVals[w]);
3711  break;
3712  }
3713  }
3714  }
3715 
3716  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3717  isSingleChanneled);
3718  //===------------------------------------------------------------------===//
3719  // End vector-only rewrite part
3720  //===------------------------------------------------------------------===//
3721 
3722  // The base vectorization case for channeled convolution is output:
3723  // {n,w,f} To reuse the result from base pattern vectorization case, we
3724  // post transpose the base case result.
3725  switch (conv1DOpOrder) {
3726  case Conv1DOpOrder::W:
3727  case Conv1DOpOrder::Nwc:
3728  // Base case, so no transposes necessary.
3729  break;
3730  case Conv1DOpOrder::Ncw: {
3731  // nwf -> nfw
3732  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3733  res = vector::TransposeOp::create(rewriter, loc, res, perm);
3734  break;
3735  }
3736  }
3737 
3738  return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3739  resPadding)
3740  .getOperation();
3741  }
3742 
3743  // Take a value and widen it to have the same element type as `ty`.
3744  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3745  const Type srcElementType = getElementTypeOrSelf(val.getType());
3746  const Type dstElementType = getElementTypeOrSelf(ty);
3747  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3748  if (srcElementType == dstElementType)
3749  return val;
3750 
3751  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3752  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3753  const Type dstType =
3754  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3755 
3756  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3757  return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3758  }
3759 
3760  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3761  srcWidth < dstWidth)
3762  return arith::ExtFOp::create(rewriter, loc, dstType, val);
3763 
3764  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3765  srcWidth < dstWidth)
3766  return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3767 
3768  assert(false && "unhandled promotion case");
3769  return nullptr;
3770  }
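Summarizing the promotion rules above with concrete element types (illustrative shapes):

```cpp
// val : vector<8xi8>,  target elem f32  ->  arith.sitofp to vector<8xf32>
// val : vector<8xf16>, target elem f32  ->  arith.extf   to vector<8xf32>
// val : vector<8xi8>,  target elem i32  ->  arith.extsi  to vector<8xi32>
// Narrowing conversions and float-to-int promotions are not handled and assert.
```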
3771 
3772  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3773  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3774  Value lhs, Value rhs, Value res) {
3775  vector::IteratorType par = vector::IteratorType::parallel;
3776  vector::IteratorType red = vector::IteratorType::reduction;
3777  AffineExpr n, w, f, c;
3778  bindDims(ctx, n, w, f, c);
3779  lhs = promote(rewriter, loc, lhs, res.getType());
3780  rhs = promote(rewriter, loc, rhs, res.getType());
3781  auto contractionOp = vector::ContractionOp::create(
3782  rewriter, loc, lhs, rhs, res,
3783  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3784  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3785  contractionOp.setKind(reductionKind);
3786  return contractionOp;
3787  }
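For reference, the contraction emitted above has roughly the following form (shapes, the f32 element type, and the add combining kind are assumptions for illustration; the kind actually follows `reductionKind`):

```mlir
%acc1 = vector.contract {
    indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
                     affine_map<(n, w, f, c) -> (c, f)>,
                     affine_map<(n, w, f, c) -> (n, w, f)>],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs, %rhs, %acc : vector<1x4x3xf32>, vector<3x8xf32> into vector<1x4x8xf32>
```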
3788 
3789  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3790  // convolution.
3791  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3792  Value lhs, Value rhs, Value res) {
3793  return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3794  rhs, res, vector::CombiningKind::ADD);
3795  }
3796 
3797  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3798  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3799  Value res) {
3800  if (isPoolExt)
3801  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3802  return rewriter
3803  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3804  ->getResult(0);
3805  }
3806 
3807  /// Generate a vector implementation for:
3808  /// ```
3809  /// Op def: ( n, w, c, kw)
3810  /// Iters: ({Par(), Par(), Par(), Red()})
3811  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3812  /// ```
3813  /// kw is always unrolled.
3814  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3815  /// > 1.
3816  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3817  bool channelDimScalableFlag,
3818  bool flatten) {
3819  bool scalableChDim = false;
3820  bool useMasking = false;
3821  int64_t nSize, wSize, cSize, kwSize;
3822  // kernel{kw, c}
3823  bindShapeDims(rhsShapedType, kwSize, cSize);
3824  if (ShapedType::isDynamic(cSize)) {
3825  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3826  cSize = channelDimVecSize;
3827  // Scalable vectors are only used when both conditions are met:
3828  // 1. channel dim is dynamic
3829  // 2. channelDimScalableFlag is set
3830  scalableChDim = channelDimScalableFlag;
3831  useMasking = true;
3832  }
3833 
3834  assert(!(useMasking && flatten) &&
3835  "Unsupported flattened conv with dynamic shapes");
3836 
3837  // out{n, w, c}
3838  bindShapeDims(resShapedType, nSize, wSize);
3839 
3840  vector::TransferWriteOp write;
3841  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3842 
3843  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3844  // When strideW == 1, we can batch the contiguous loads and avoid
3845  // unrolling
3846  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3847 
3848  Type lhsEltType = lhsShapedType.getElementType();
3849  Type rhsEltType = rhsShapedType.getElementType();
3850  Type resEltType = resShapedType.getElementType();
3851  VectorType lhsType = VectorType::get(
3852  {nSize,
3853  // iw = ow * sw + kw * dw - 1
3854  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3855  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3856  cSize},
3857  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3858  VectorType rhsType =
3859  VectorType::get({kwSize, cSize}, rhsEltType,
3860  /*scalableDims=*/{false, scalableChDim});
3861  VectorType resType =
3862  VectorType::get({nSize, wSize, cSize}, resEltType,
3863  /*scalableDims=*/{false, false, scalableChDim});
3864 
3865  // Masks the input xfer Op along the channel dim, iff the corresponding
3866  // scalable flag is set.
3867  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3868  ArrayRef<bool> scalableDims,
3869  Operation *opToMask) {
3870  if (!useMasking)
3871  return opToMask;
3872  auto maskType =
3873  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3874 
3875  SmallVector<bool> inBounds(maskShape.size(), true);
3876  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3877  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3878  rewriter.getBoolArrayAttr(inBounds));
3879 
3880  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3881  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3882 
3883  Value maskOp =
3884  vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3885 
3886  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3887  };
3888 
3889  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3890  // 0].
3891  Value lhs = vector::TransferReadOp::create(
3892  rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3893  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3894  auto maybeMaskedLhs = maybeMaskXferOp(
3895  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3896 
3897  // Read rhs slice of size {kw, c} @ [0, 0].
3898  Value rhs = vector::TransferReadOp::create(
3899  rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3900  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3901  auto maybeMaskedRhs = maybeMaskXferOp(
3902  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3903 
3904  // Read res slice of size {n, w, c} @ [0, 0, 0].
3905  Value res = vector::TransferReadOp::create(
3906  rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3907  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3908  auto maybeMaskedRes = maybeMaskXferOp(
3909  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3910 
3911  //===------------------------------------------------------------------===//
3912  // Begin vector-only rewrite part
3913  //===------------------------------------------------------------------===//
3914  // Unroll along kw and read slices of lhs and rhs.
3915  SmallVector<Value> lhsVals, rhsVals, resVals;
3916  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3917  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3918 
3919  // Extract lhs slice of size {n, wSizeStep, c}
3920  // @ [0, sw * w + dw * kw, 0].
3921  for (int64_t kw = 0; kw < kwSize; ++kw) {
3922  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3923  lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3924  rewriter, loc, maybeMaskedLhs->getResult(0),
3925  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3926  inOutSliceSizes, inOutStrides));
3927  }
3928  }
3929  // Extract rhs slice of size {c} @ [kw].
3930  for (int64_t kw = 0; kw < kwSize; ++kw) {
3931  rhsVals.push_back(
3932  vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3933  /*offsets=*/ArrayRef<int64_t>{kw}));
3934  }
3935  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3936  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3937  resVals.push_back(vector::ExtractStridedSliceOp::create(
3938  rewriter, loc, maybeMaskedRes->getResult(0),
3939  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3940  inOutStrides));
3941  }
3942 
3943  auto linearIndex = [&](int64_t kw, int64_t w) {
3944  return kw * (wSize / wSizeStep) + w;
3945  };
3946 
3947  // Note - the scalable flags are ignored as flattening combined with
3948  // scalable vectorization is not supported.
3949  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3950  auto lhsTypeAfterFlattening =
3951  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3952  auto resTypeAfterFlattening =
3953  VectorType::get(inOutFlattenSliceSizes, resEltType);
3954 
3955  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3956  for (int64_t kw = 0; kw < kwSize; ++kw) {
3957  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3958  Value lhsVal = lhsVals[linearIndex(kw, w)];
3959  Value resVal = resVals[w];
3960  if (flatten) {
3961  // Flatten the input and output vectors (collapse the channel
3962  // dimension)
3963  lhsVal =
3964  vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3965  lhsVals[linearIndex(kw, w)]);
3966  resVal = vector::ShapeCastOp::create(
3967  rewriter, loc, resTypeAfterFlattening, resVals[w]);
3968  }
3969  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3970  rhsVals[kw], resVal, flatten);
3971  if (flatten) {
3972  // Un-flatten the output vector (restore the channel dimension)
3973  resVals[w] = vector::ShapeCastOp::create(
3974  rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
3975  resVals[w]);
3976  }
3977  }
3978  }
3979 
3980  // It's possible we failed to create the FMA.
3981  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3982  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3983  for (auto &collection :
3984  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3985  for (Value v : collection)
3986  rewriter.eraseOp(v.getDefiningOp());
3987  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3988  }
3989 
3990  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3991  // This does not depend on kw.
3992  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3993  maybeMaskedRes = vector::InsertStridedSliceOp::create(
3994  rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
3995  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3996  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3997  }
3998  //===------------------------------------------------------------------===//
3999  // End vector-only rewrite part
4000  //===------------------------------------------------------------------===//
4001 
4002  // Write back res slice of size {n, w, c} @ [0, 0, 0].
4003  Operation *resOut = vector::TransferWriteOp::create(
4004  rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4005  ValueRange{zero, zero, zero});
4006  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4007  resOut);
4008  }
4009 
4010  /// Lower:
4011  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4012  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4013  /// to MulAcc.
4014  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4015  Value lhs, Value rhs, Value res,
4016  bool flatten) {
4017  auto rhsTy = cast<ShapedType>(rhs.getType());
4018  auto resTy = cast<ShapedType>(res.getType());
4019 
4020  // TODO(suderman): Change this to use a vector.ima intrinsic.
4021  lhs = promote(rewriter, loc, lhs, resTy);
4022 
4023  if (flatten) {
4024  // NOTE: The following logic won't work for scalable vectors. For this
4025  // reason, "flattening" is not supported when shapes are dynamic (this
4026  // should be captured by one of the pre-conditions).
4027 
4028  // There are two options for handling the filter:
4029  // * shape_cast(broadcast(filter))
4030  // * broadcast(shuffle(filter))
4031  // Opt for the option without shape_cast to simplify the codegen.
4032  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4033  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4034 
4035  SmallVector<int64_t, 16> indices;
4036  for (int i = 0; i < resSize / rhsSize; ++i) {
4037  for (int j = 0; j < rhsSize; ++j)
4038  indices.push_back(j);
4039  }
4040 
4041  rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4042  }
4043  // Broadcast the filter to match the output vector
4044  rhs = vector::BroadcastOp::create(rewriter, loc,
4045  resTy.clone(rhsTy.getElementType()), rhs);
4046 
4047  rhs = promote(rewriter, loc, rhs, resTy);
4048 
4049  if (!lhs || !rhs)
4050  return nullptr;
4051 
4052  if (isa<FloatType>(resTy.getElementType()))
4053  return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4054 
4055  auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4056  return arith::AddIOp::create(rewriter, loc, mul, res);
4057  }
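Illustrative lowering for the non-flattened floating-point case (shapes assumed): the filter slice is broadcast to the output shape and fused into an FMA.

```mlir
// lhs : vector<1x8x4xf32>, rhs : vector<4xf32>, res : vector<1x8x4xf32>
%b = vector.broadcast %rhs : vector<4xf32> to vector<1x8x4xf32>
%r = vector.fma %lhs, %b, %res : vector<1x8x4xf32>
```

For integer element types the same step is expressed as arith.muli followed by arith.addi.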
4058 
4059  /// Entry point for non-channeled convolution:
4060  /// {{w + kw}, {kw}, {w}}
4061  FailureOr<Operation *> generateNonChanneledConv() {
4062  AffineExpr w, kw;
4063  bindDims(ctx, w, kw);
4064  if (!iters({Par(), Red()}))
4065  return rewriter.notifyMatchFailure(op,
4066  "failed to match conv::W 1-par 1-red");
4067 
4068  // No transposition needed.
4069  if (layout({/*lhsIndex*/ {w + kw},
4070  /*rhsIndex*/ {kw},
4071  /*resIndex*/ {w}}))
4072  return conv(Conv1DOpOrder::W);
4073 
4074  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4075  }
4076 
4077  /// Entry point that transposes into the common form:
4078  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4079  FailureOr<Operation *> generateNwcConv() {
4080  AffineExpr n, w, f, kw, c;
4081  bindDims(ctx, n, w, f, kw, c);
4082  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4083  return rewriter.notifyMatchFailure(
4084  op, "failed to match conv::Nwc 3-par 2-red");
4085 
4086  // No transposition needed.
4087  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4088  /*rhsIndex*/ {kw, c, f},
4089  /*resIndex*/ {n, w, f}}))
4090  return conv(Conv1DOpOrder::Nwc);
4091 
4092  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4093  }
4094 
4095  /// Entry point that transposes into the common form:
4096  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4097  FailureOr<Operation *> generateNcwConv() {
4098  AffineExpr n, w, f, kw, c;
4099  bindDims(ctx, n, f, w, c, kw);
4100  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4101  return rewriter.notifyMatchFailure(
4102  op, "failed to match conv::Ncw 3-par 2-red");
4103 
4104  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4105  /*rhsIndex*/ {f, c, kw},
4106  /*resIndex*/ {n, f, w}}))
4107  return conv(Conv1DOpOrder::Ncw);
4108 
4109  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4110  }
4111 
4112  /// Entry point that transposes into the common form:
4113  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4114  FailureOr<Operation *> generateNwcPooling() {
4115  AffineExpr n, w, c, kw;
4116  bindDims(ctx, n, w, c, kw);
4117  if (!iters({Par(), Par(), Par(), Red()}))
4118  return rewriter.notifyMatchFailure(op,
4119  "failed to match pooling 3-par 1-red");
4120 
4121  // No transposition needed.
4122  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4123  /*rhsIndex*/ {kw},
4124  /*resIndex*/ {n, w, c}}))
4125  return conv(Conv1DOpOrder::Nwc);
4126 
4127  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4128  }
4129 
4130  /// Entry point that transposes into the common form:
4131  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4132  FailureOr<Operation *> generateNcwPooling() {
4133  AffineExpr n, w, c, kw;
4134  bindDims(ctx, n, c, w, kw);
4135  if (!iters({Par(), Par(), Par(), Red()}))
4136  return rewriter.notifyMatchFailure(op,
4137  "failed to match pooling 3-par 1-red");
4138 
4139  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4140  /*rhsIndex*/ {kw},
4141  /*resIndex*/ {n, c, w}}))
4142  return conv(Conv1DOpOrder::Ncw);
4143 
4144  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4145  }
4146 
4147  /// Entry point that transposes into the common form:
4148  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4149  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4150  bool vecChDimScalableFlag = false,
4151  bool flatten = false) {
4152  AffineExpr n, w, c, kw;
4153  bindDims(ctx, n, w, c, kw);
4154  if (!iters({Par(), Par(), Par(), Red()}))
4155  return rewriter.notifyMatchFailure(
4156  op, "failed to match depthwise::Nwc conv 3-par 1-red");
4157 
4158  // No transposition needed.
4159  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4160  /*rhsIndex*/ {kw, c},
4161  /*resIndex*/ {n, w, c}}))
4162  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4163 
4164  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4165  }
4166 
4167 private:
4168  ConvOperationKind oper = ConvOperationKind::Conv;
4169  StringAttr redOp;
4170  StringAttr poolExtOp;
4171  bool isPoolExt = false;
4172  int strideW, dilationW;
4173  Value lhsShaped, rhsShaped, resShaped;
4174  ShapedType lhsShapedType, rhsShapedType, resShapedType;
4175  vector::CombiningKind reductionKind;
4176 
4177  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4178  void setConvOperationKind(Operation *reduceOp) {
4179  int numBlockArguments =
4180  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4181  if (numBlockArguments == 1) {
4182  // Will be convolution if feeder is a MulOp.
4183  // A strength reduced version of MulOp for i1 type is AndOp which is also
4184  // supported. Otherwise, it can be pooling. This strength reduction logic
4185  // is in `buildBinaryFn` helper in the Linalg dialect.
4186  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4187  llvm::IsaPred<BlockArgument>);
4188  Operation *feedOp = (*feedValIt).getDefiningOp();
4189  if (isCastOfBlockArgument(feedOp)) {
4190  oper = ConvOperationKind::Pool;
4191  isPoolExt = true;
4192  poolExtOp = feedOp->getName().getIdentifier();
4193  return;
4194  }
4195  oper = ConvOperationKind::Conv;
4196  return;
4197  }
4198  // numBlockArguments == 2, so this is a pooling op.
4199  oper = ConvOperationKind::Pool;
4200  isPoolExt = false;
4201  }
4202 };
4203 } // namespace
4204 
4205 /// Helper function to vectorize a LinalgOp with convolution semantics.
4206 // TODO: extend the generic vectorization to support windows and drop this.
4207 static FailureOr<Operation *> vectorizeConvolution(
4208  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4209  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4210  Conv1DGenerator conv1dGen(rewriter, op);
4211  auto res = conv1dGen.generateNonChanneledConv();
4212  if (succeeded(res))
4213  return res;
4214  res = conv1dGen.generateNwcConv();
4215  if (succeeded(res))
4216  return res;
4217  res = conv1dGen.generateNcwConv();
4218  if (succeeded(res))
4219  return res;
4220  res = conv1dGen.generateNwcPooling();
4221  if (succeeded(res))
4222  return res;
4223  res = conv1dGen.generateNcwPooling();
4224  if (succeeded(res))
4225  return res;
4226 
4227  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4228  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4229  // masked/scalable) is the channel dim (i.e. the trailing dim).
4230  uint64_t vecChDimSize = ShapedType::kDynamic;
4231  bool vecChDimScalableFlag = false;
4232  if (!inputVecSizes.empty()) {
4233  // Only use the input vector size corresponding to the channel dim. Other
4234  // vector dims will be inferred from the Ops.
4235  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4236  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4237  "Not a 1D depthwise conv!");
4238  size_t chDimIdx =
4239  TypeSwitch<Operation *, size_t>(op)
4240  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4241  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4242 
4243  vecChDimSize = inputVecSizes[chDimIdx];
4244  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4245  }
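  // Worked example (hypothetical sizes): for a DepthwiseConv1DNwcWcOp with
  //   inputVecSizes        = {1, 8, 4}            // {n, w, c}
  //   inputScalableVecDims = {false, false, true}
  // chDimIdx == 2, so vecChDimSize == 4 and vecChDimScalableFlag == true; the
  // n and w vector sizes are re-derived from the op by the generator.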
4246  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4247  flatten1DDepthwiseConv);
4248 }
4249 
4250 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4251  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4252 
4253  LogicalResult matchAndRewrite(LinalgOp op,
4254  PatternRewriter &rewriter) const override {
4255  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4256  if (failed(resultOrFail))
4257  return failure();
4258  Operation *newOp = *resultOrFail;
4259  if (newOp->getNumResults() == 0) {
4260  rewriter.eraseOp(op.getOperation());
4261  return success();
4262  }
4263  assert(newOp->getNumResults() == 1 && "expected single result");
4264  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4265  return success();
4266  }
4267 };
4268 
4269 void mlir::linalg::populateConvolutionVectorizationPatterns(
4270  RewritePatternSet &patterns, PatternBenefit benefit) {
4271  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4272 }
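 
 // Illustrative use (a minimal sketch; `funcOp` and the greedy-driver setup are
 // assumptions, not part of this file):
 //   RewritePatternSet patterns(funcOp.getContext());
 //   linalg::populateConvolutionVectorizationPatterns(patterns);
 //   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
 //     return signalPassFailure();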