Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Support/LLVM.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/DebugLog.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 /// Try to vectorize `convOp` as a convolution.
53 static FailureOr<Operation *>
54 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
55  ArrayRef<int64_t> inputVecSizes = {},
56  ArrayRef<bool> inputVecScalableFlags = {},
57  bool flatten1DDepthwiseConv = false);
58 
59 /// Vectorize tensor::InsertSliceOp with:
60 /// * vector::TransferReadOp + vector::TransferWriteOp
61 /// The vector sizes are either:
62 /// * user-provided in `inputVectorSizes`, or
63 /// * inferred from the static dims in the input and output tensors.
64 /// Bails out if:
65 /// * vector sizes are not user-provided, and
66 /// * at least one dim is dynamic (in both the input and output tensors).
67 ///
68 /// Before:
69 /// !t_in_type = tensor<1x2x3xf32>
70 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
71 /// !v_type = vector<1x2x3xf32>
72 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
73 /// into !t_out_type
74 /// After:
75 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
76 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
77 static LogicalResult
78 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
79  ArrayRef<int64_t> inputVectorSizes,
80  SmallVectorImpl<Value> &newResults);
81 
82 /// Returns the effective Pad value for the input op, provided it's a scalar.
83 ///
84 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
85 /// this Op performs padding, retrieve the padding value provided that it's
86 /// a scalar and static/fixed for all the padded values. Returns an empty value
87 /// otherwise.
88 static Value getStaticPadVal(Operation *op);
89 
90 /// Return the unique instance of OpType in `block` if it is indeed unique.
91 /// Return null if none exists or if more than one instance exists.
92 template <typename OpType>
93 static OpType getSingleOpOfType(Block &block) {
94  OpType res;
95  block.walk([&](OpType op) {
96  if (res) {
97  res = nullptr;
98  return WalkResult::interrupt();
99  }
100  res = op;
101  return WalkResult::advance();
102  });
103  return res;
104 }
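// Editorial usage sketch (illustrative, not part of the original source): the
// helper above is handy for fetching a unique op from a region body, e.g. the
// single combiner of a reduction. `body` is a placeholder Block reference.
//
//   if (arith::AddFOp add = getSingleOpOfType<arith::AddFOp>(body)) {
//     // Exactly one arith.addf exists in `body`; `add` is that op.
//   }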
105 
106 /// Helper function to extract the input slices after filter is unrolled along
107 /// kw.
108 static SmallVector<Value>
109 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
110  int64_t nSize, int64_t wSize, int64_t cSize,
111  int64_t kwSize, int strideW, int dilationW,
112  int64_t wSizeStep, bool isSingleChanneled) {
113  SmallVector<Value> result;
114  if (isSingleChanneled) {
115  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
116  // convolution.
117  SmallVector<int64_t> sizes = {wSizeStep};
118  SmallVector<int64_t> strides = {1};
119  for (int64_t kw = 0; kw < kwSize; ++kw) {
120  for (int64_t w = 0; w < wSize; w += wSizeStep) {
121  result.push_back(vector::ExtractStridedSliceOp::create(
122  rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
123  strides));
124  }
125  }
126  } else {
127  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128  // for channeled convolution.
129  SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130  SmallVector<int64_t> strides = {1, 1, 1};
131  for (int64_t kw = 0; kw < kwSize; ++kw) {
132  for (int64_t w = 0; w < wSize; w += wSizeStep) {
133  result.push_back(vector::ExtractStridedSliceOp::create(
134  rewriter, loc, input,
135  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136  sizes, strides));
137  }
138  }
139  }
140  return result;
141 }
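// Editorial worked example (values chosen for illustration only): for a
// channeled convolution with nSize = 1, wSize = 4, cSize = 3, kwSize = 2,
// strideW = 1, dilationW = 1 and wSizeStep = 2, the loops above extract
// slices of size {1, 2, 3} at offsets:
//   kw = 0: [0, 0, 0], [0, 2, 0]
//   kw = 1: [0, 1, 0], [0, 3, 0]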
142 
143 /// Helper function to extract the filter slices after filter is unrolled along
144 /// kw.
145 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
146  Location loc, Value filter,
147  int64_t kwSize) {
148  SmallVector<Value> result;
149  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150  // non-channeled convolution] @ [kw].
151  for (int64_t kw = 0; kw < kwSize; ++kw) {
152  result.push_back(vector::ExtractOp::create(
153  rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154  }
155  return result;
156 }
157 
158 /// Helper function to extract the result slices after filter is unrolled along
159 /// kw.
160 static SmallVector<Value>
161 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
162  int64_t nSize, int64_t wSize, int64_t fSize,
163  int64_t wSizeStep, bool isSingleChanneled) {
164  SmallVector<Value> result;
165  if (isSingleChanneled) {
166  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167  SmallVector<int64_t> sizes = {wSizeStep};
168  SmallVector<int64_t> strides = {1};
169  for (int64_t w = 0; w < wSize; w += wSizeStep) {
170  result.push_back(vector::ExtractStridedSliceOp::create(
171  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
172  strides));
173  }
174  } else {
175  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
176  // convolution.
177  SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
178  SmallVector<int64_t> strides = {1, 1, 1};
179  for (int64_t w = 0; w < wSize; w += wSizeStep) {
180  result.push_back(vector::ExtractStridedSliceOp::create(
181  rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
182  strides));
183  }
184  }
185  return result;
186 }
187 
188 /// Helper function to insert the computed result slices.
189 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190  Value res, int64_t wSize, int64_t wSizeStep,
191  SmallVectorImpl<Value> &resVals,
192  bool isSingleChanneled) {
193 
194  if (isSingleChanneled) {
195  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196  // This does not depend on kw.
197  SmallVector<int64_t> strides = {1};
198  for (int64_t w = 0; w < wSize; w += wSizeStep) {
199  res = vector::InsertStridedSliceOp::create(
200  rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
201  strides);
202  }
203  } else {
204  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205  // convolution. This does not depend on kw.
206  SmallVector<int64_t> strides = {1, 1, 1};
207  for (int64_t w = 0; w < wSize; w += wSizeStep) {
208  res = vector::InsertStridedSliceOp::create(
209  rewriter, loc, resVals[w], res,
210  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
211  }
212  }
213  return res;
214 }
215 
216 /// Contains the vectorization state and related methods used across the
217 /// vectorization process of a given operation.
218 struct VectorizationState {
219  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220 
221  /// Initializes the vectorization state, including the computation of the
222  /// canonical vector shape for vectorization.
223  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224  ArrayRef<int64_t> inputVectorSizes,
225  ArrayRef<bool> inputScalableVecDims,
226  bool assumeDynamicDimsMatchVecSizes = false);
227 
228  /// Returns the canonical vector shape used to vectorize the iteration space.
229  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
230 
231  /// Returns the vector dimensions that are scalable in the canonical vector
232  /// shape.
233  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
234 
235  /// Returns a vector type of the provided `elementType` with the canonical
236  /// vector shape and the corresponding fixed/scalable dimensions bit. If
237  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
238  /// accordingly.
239  VectorType getCanonicalVecType(
240  Type elementType,
241  std::optional<AffineMap> dimPermutation = std::nullopt) const {
242  SmallVector<int64_t> vectorShape;
243  SmallVector<bool> scalableDims;
244  if (dimPermutation.has_value()) {
245  vectorShape =
246  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
247  scalableDims =
248  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
249  } else {
250  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
252  }
253 
254  return VectorType::get(vectorShape, elementType, scalableDims);
255  }
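  // Editorial example (illustrative values): with canonicalVecShape = {8, 16, 4}
  // and scalableVecDims = {false, true, false}, getCanonicalVecType(f32) yields
  // vector<8x[16]x4xf32>; with dimPermutation = (d0, d1, d2) -> (d2, d0) it
  // yields vector<4x8xf32>.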
256 
257  /// Masks an operation with the canonical vector mask if the operation needs
258  /// masking. Returns the masked operation or the original operation if masking
259  /// is not needed. If provided, the canonical mask for this operation is
260  /// permuted using `maybeIndexingMap`.
261  Operation *
262  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
263  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264 
265 private:
266  /// Initializes the iteration space static sizes using the Linalg op
267  /// information. This may become more complicated in the future.
268  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
270  }
271 
272  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
273  /// all the static and dynamic dimensions of the iteration space to be
274  /// vectorized and store them in `iterSpaceValueSizes`.
275  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276  LinalgOp linalgOp);
277 
278  /// Create or retrieve an existing mask value to mask `opToMask` in the
279  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
280  /// mask is permuted using that permutation map. If a new mask is created, it
281  /// will be cached for future users.
282  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
283  LinalgOp linalgOp,
284  std::optional<AffineMap> maybeMaskingMap);
285 
286  /// Check whether this permutation map can be used for masking. At the
287  /// moment we only make sure that there are no broadcast dimensions, but this
288  /// might change if indexing maps evolve.
289  bool isValidMaskingMap(AffineMap maskingMap) {
290  return maskingMap.getBroadcastDims().empty();
291  }
292 
293  /// Turn the input indexing map into a valid masking map.
294  ///
295  /// The input indexing map may contain "zero" results, e.g.:
296  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
297  /// Applying such maps to canonical vector shapes like this one:
298  /// (1, 16, 16, 4)
299  /// would yield an invalid vector shape like this:
300  /// (16, 16, 1, 0)
301  /// Instead, drop the broadcasting dims that make no sense for masking perm.
302  /// maps:
303  /// (d0, d1, d2, d3) -> (d2, d1, d0)
304  /// This way, the corresponding vector/mask type will be:
305  /// vector<16x16x1xty>
306  /// rather than this invalid Vector type:
307  /// vector<16x16x1x0xty>
308  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
309  return indexingMap.dropZeroResults();
310  }
311 
312  // Holds the compile-time static sizes of the iteration space to vectorize.
313  // Dynamic dimensions are represented using ShapedType::kDynamic.
314  SmallVector<int64_t> iterSpaceStaticSizes;
315 
316  /// Holds the value sizes of the iteration space to vectorize. Static
317  /// dimensions are represented by 'arith.constant' and dynamic
318  /// dimensions by 'tensor/memref.dim'.
319  SmallVector<Value> iterSpaceValueSizes;
320 
321  /// Holds the canonical vector shape used to vectorize the iteration space.
322  SmallVector<int64_t> canonicalVecShape;
323 
324  /// Holds the vector dimensions that are scalable in the canonical vector
325  /// shape.
326  SmallVector<bool> scalableVecDims;
327 
328  /// Holds the active masks for permutations of the canonical vector iteration
329  /// space.
330  DenseMap<AffineMap, Value> activeMaskCache;
331 
332  /// Global vectorization guard for the incoming rewriter. It's initialized
333  /// when the vectorization state is initialized.
334  OpBuilder::InsertionGuard rewriterGuard;
335 
336  /// Do all dynamic dims match the corresponding vector sizes?
337  ///
338  /// When a dynamic tensor/memref dimension matches the corresponding vector
339  /// dimension, masking can be safely skipped, despite the presence of dynamic
340  /// shapes. Use this flag with care and only for cases where you are
341  /// confident the assumption holds.
342  bool assumeDynamicDimsMatchVecSizes = false;
343 };
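// Editorial usage sketch (not from the original source; `someLinalgOp`,
// `vecSizes`, `scalableFlags`, `someXferOp` and `someIndexingMap` are
// placeholders). A typical client initializes the state once per op, queries
// the canonical vector type, and masks the xfer ops it creates:
//
//   VectorizationState state(rewriter);
//   if (failed(state.initState(rewriter, someLinalgOp, vecSizes, scalableFlags)))
//     return failure();
//   VectorType vecTy = state.getCanonicalVecType(rewriter.getF32Type());
//   Operation *maybeMasked =
//       state.maskOperation(rewriter, someXferOp, someLinalgOp, someIndexingMap);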
344 
345 LogicalResult
346 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
347  LinalgOp linalgOp) {
348  // TODO: Support 0-d vectors.
349  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350  if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
351  // Create constant index op for static dimensions.
352  iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
353  rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
354  continue;
355  }
356 
357  // Find an operand defined on this dimension of the iteration space to
358  // extract the runtime dimension size.
359  Value operand;
360  unsigned operandDimPos;
361  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
362  operandDimPos)))
363  return failure();
364 
365  Value dynamicDim =
366  linalgOp.hasPureTensorSemantics()
367  ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
368  operandDimPos)
369  : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370  operandDimPos);
371  iterSpaceValueSizes.push_back(dynamicDim);
372  }
373 
374  return success();
375 }
376 
377 /// Initializes the vectorization state, including the computation of the
378 /// canonical vector shape for vectorization.
379 // TODO: Move this to the constructor when we can remove the failure cases.
380 LogicalResult VectorizationState::initState(RewriterBase &rewriter,
381  LinalgOp linalgOp,
382  ArrayRef<int64_t> inputVectorSizes,
383  ArrayRef<bool> inputScalableVecDims,
384  bool assumeDimsMatchVec) {
385  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
386  // Initialize the insertion point.
387  rewriter.setInsertionPoint(linalgOp);
388 
389  if (!inputVectorSizes.empty()) {
390  // Get the canonical vector shape from the input vector sizes provided. This
391  // path should be taken to vectorize code with dynamic shapes and when using
392  // vector sizes greater than the iteration space sizes.
393  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394  scalableVecDims.append(inputScalableVecDims.begin(),
395  inputScalableVecDims.end());
396  } else {
397  // Compute the canonical vector shape from the operation shape. If there are
398  // dynamic shapes, the operation won't be vectorized. We assume all the
399  // vector dimensions are fixed.
400  canonicalVecShape = linalgOp.getStaticLoopRanges();
401  scalableVecDims.append(linalgOp.getNumLoops(), false);
402  }
403 
404  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
406 
407  if (ShapedType::isDynamicShape(canonicalVecShape))
408  return failure();
409 
410  // Initialize iteration space static sizes.
411  initIterSpaceStaticSizes(linalgOp);
412 
413  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
414  // all the static and dynamic dimensions of the iteration space, needed to
415  // compute a mask during vectorization.
416  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
417  return failure();
418 
419  return success();
420 }
421 
422 /// Create or retrieve an existing mask value to mask `opToMask` in the
423 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
424 /// is permuted using that permutation map. If a new mask is created, it will be
425 /// cached for future users.
426 Value VectorizationState::getOrCreateMaskFor(
427  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
428  std::optional<AffineMap> maybeMaskingMap) {
429 
430  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431  "Ill-formed masking map.");
432 
433  // No mask is needed if the operation is not maskable.
434  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
435  if (!maskableOp)
436  return Value();
437 
438  assert(!maskableOp.isMasked() &&
439  "Masking an operation that is already masked");
440 
441  // If no masking map was provided, use an identity map with the loop dims.
442  assert((!maybeMaskingMap || *maybeMaskingMap) &&
443  "Unexpected null mask permutation map");
444  AffineMap maskingMap =
445  maybeMaskingMap ? *maybeMaskingMap
446  : AffineMap::getMultiDimIdentityMap(
447  linalgOp.getNumLoops(), rewriter.getContext());
448 
449  LDBG() << "Masking map: " << maskingMap;
450 
451  // Return the active mask for the masking map of this operation if it was
452  // already created.
453  auto activeMaskIt = activeMaskCache.find(maskingMap);
454  if (activeMaskIt != activeMaskCache.end()) {
455  Value mask = activeMaskIt->second;
456  LDBG() << "Reusing mask: " << mask;
457  return mask;
458  }
459 
460  // Compute permuted projection of the iteration space to be masked and the
461  // corresponding mask shape. If the resulting iteration space dimensions are
462  // static and identical to the mask shape, masking is not needed for this
463  // operation.
464  // TODO: Improve this check. Only projected permutation indexing maps are
465  // supported.
466  SmallVector<int64_t> permutedStaticSizes =
467  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
469  auto maskShape = maskType.getShape();
470 
471  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
472 
473  if (permutedStaticSizes == maskShape) {
474  LDBG() << "Masking is not needed for masking map: " << maskingMap;
475  activeMaskCache[maskingMap] = Value();
476  return Value();
477  }
478 
479  if (assumeDynamicDimsMatchVecSizes) {
480  // While for _dynamic_ dim sizes we can _assume_ that the corresponding
481  // vector sizes match, we still need to check the _static_ dim sizes. Only
482  // then we can be 100% sure that masking is not required.
483  if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
484  [](auto it) {
485  return std::get<0>(it) == ShapedType::kDynamic
486  ? true
487  : std::get<0>(it) == std::get<1>(it);
488  })) {
489  LDBG()
490  << "Dynamic + static dimensions match vector sizes, masking is not "
491  "required.";
492  activeMaskCache[maskingMap] = Value();
493  return Value();
494  }
495  }
496 
497  // Permute the iteration space value sizes to compute the mask upper bounds.
498  SmallVector<Value> upperBounds =
499  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
500  assert(!maskShape.empty() && !upperBounds.empty() &&
501  "Masked 0-d vectors are not supported yet");
502 
503  // Create the mask based on the dimension values.
504  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505  maskType, upperBounds);
506  LDBG() << "Creating new mask: " << mask;
507  activeMaskCache[maskingMap] = mask;
508  return mask;
509 }
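// Editorial example (illustrative IR): for a canonical 4x8 mask shape with
// dynamic iteration-space sizes %d0 and %d1, the op created above prints as:
//
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>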
510 
511 Operation *
512 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
513  LinalgOp linalgOp,
514  std::optional<AffineMap> maybeIndexingMap) {
515  LDBG() << "Trying to mask: " << *opToMask;
516 
517  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518  if (maybeIndexingMap)
519  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
520 
521  // Create or retrieve mask for this operation.
522  Value mask =
523  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
524 
525  if (!mask) {
526  LDBG() << "No mask required";
527  if (assumeDynamicDimsMatchVecSizes) {
528  llvm::TypeSwitch<Operation *>(opToMask)
529  .Case<vector::TransferReadOp, vector::TransferWriteOp>(
530  [&](auto xferOp) {
531  // For vector.transfer_read and vector.transfer_write, there is
532  // also the `in-bounds` attribute that has to be set explicitly
533  // to true. Otherwise, "out-of-bounds" access will be assumed
534  // and masks will be generated while lowering these.
535  LDBG() << "Assuming dynamic dimensions match vector sizes and "
536  "setting their in-bounds to true!";
537  SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
538  ShapedType xferType = xferOp.getShapedType();
539  AffineMap permMap = xferOp.getPermutationMap();
540  // Only set the in-bounds values to true for dynamic dims.
541  // Different mechanisms will set these accordingly for the
542  // static dims.
543  for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
544  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
545  // Skip broadcast dimensions.
546  if (!dimExpr)
547  continue;
548  unsigned pos = dimExpr.getPosition();
549  if (xferType.isDynamicDim(pos))
550  inBoundsMap[i] = true;
551  }
552  rewriter.modifyOpInPlace(xferOp, [&]() {
553  xferOp.setInBoundsAttr(
554  rewriter.getBoolArrayAttr(inBoundsMap));
555  });
556  })
557  .Default([](Operation *op) {
558  // No-op if the operation is not an xfer read or write.
559  });
560  }
561  return opToMask;
562  }
563 
564  // Wrap the operation with a new `vector.mask` and update D-U chain.
565  assert(opToMask && "Expected a valid operation to mask");
566  auto maskOp = cast<vector::MaskOp>(
567  mlir::vector::maskOperation(rewriter, opToMask, mask));
568  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
569 
570  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
571  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
572  maskOpTerminator);
573 
574  LDBG() << "Masked operation: " << *maskOp;
575  return maskOp;
576 }
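// Editorial example (illustrative IR): when a mask is required, the maskable
// op ends up wrapped in vector.mask and its uses are redirected to the mask
// results, e.g. for a transfer_read producing a vector<4x8xf32>:
//
//   %r = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>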
577 
578 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
579 /// projectedPermutation, compress the unused dimensions to serve as a
580 /// permutation_map for a vector transfer operation.
581 /// For example, given a linalg op such as:
582 ///
583 /// ```
584 /// %0 = linalg.generic {
585 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
586 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
587 /// }
588 /// ins(%0 : tensor<2x3x4xf32>)
589 /// outs(%1 : tensor<5x6xf32>)
590 /// ```
591 ///
592 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
593 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
594 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
595 static AffineMap reindexIndexingMap(AffineMap map) {
596  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
597  "expected projected permutation");
598  auto res = compressUnusedDims(map);
599  assert(res.getNumDims() ==
600  (res.getNumResults() - res.getNumOfZeroResults()) &&
601  "expected reindexed map with same number of dims and results");
602  return res;
603 }
604 
605 /// Helper enum to represent conv1d input traversal order.
606 enum class Conv1DOpOrder {
607  W, // Corresponds to non-channeled 1D convolution operation.
608  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
609  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
610 };
611 
612 /// Helper data structure to represent the result of vectorization for a single
613 /// operation. In certain specific cases, like terminators, we do not want to
614 /// propagate.
615 enum VectorizationHookStatus {
616  /// Op failed to vectorize.
617  Failure = 0,
618  /// Op vectorized and custom function took care of replacement logic.
619  NoReplace,
620  /// Op vectorized into a new Op whose results will replace original Op's
621  /// results.
622  NewOp
623  // TODO: support values if Op vectorized to Many-Ops whose results we need to
624  // aggregate for replacement.
625 };
626 /// VectorizationHookResult contains the vectorized op returned from a
627 /// CustomVectorizationHook. This is an internal implementation detail of
628 /// linalg vectorization, not to be confused with VectorizationResult.
629 struct VectorizationHookResult {
630  /// Return status from vectorizing the current op.
631  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
632  /// New vectorized operation to replace the current op.
633  /// Replacement behavior is specified by `status`.
634  Operation *newOp;
635 };
636 
637 std::optional<vector::CombiningKind>
638 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
639  using ::mlir::vector::CombiningKind;
640 
641  if (!combinerOp)
642  return std::nullopt;
643  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
644  .Case<arith::AddIOp, arith::AddFOp>(
645  [&](auto op) { return CombiningKind::ADD; })
646  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
647  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
648  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
649  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
650  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
651  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
652  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
653  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
654  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
655  .Case<arith::MulIOp, arith::MulFOp>(
656  [&](auto op) { return CombiningKind::MUL; })
657  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
658  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
659  .Default([&](auto op) { return std::nullopt; });
660 }
661 
662 /// Check whether `outputOperand` is a reduction with a single combiner
663 /// operation. Return the combiner operation of the reduction. Return
664 /// nullptr otherwise. Multiple reduction operations would impose an
665 /// ordering between reduction dimensions and are currently unsupported in
666 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
667 /// max(min(X))
668 // TODO: use in LinalgOp verification, there is a circular dependency atm.
669 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
670  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
671  unsigned outputPos =
672  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
673  // Only single combiner operations are supported for now.
674  SmallVector<Operation *, 4> combinerOps;
675  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
676  combinerOps.size() != 1)
677  return nullptr;
678 
679  // Return the combiner operation.
680  return combinerOps[0];
681 }
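// Editorial example (illustrative IR): for a sum reduction such as
//
//   linalg.generic {indexing_maps = [...], iterator_types = ["reduction"]}
//       ins(%in : tensor<8xf32>) outs(%acc : tensor<f32>) {
//   ^bb0(%a: f32, %b: f32):
//     %s = arith.addf %a, %b : f32
//     linalg.yield %s : f32
//   }
//
// matchLinalgReduction returns the arith.addf, which getCombinerOpKind then
// maps to vector::CombiningKind::ADD.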
682 
683 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
684 /// unchanged otherwise.
685 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
686  auto dstVecType = dyn_cast<VectorType>(dstType);
687  // If no shape to broadcast to, just return `value`.
688  if (dstVecType.getRank() == 0)
689  return value;
690  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
691  vector::BroadcastableToResult::Success)
692  return value;
693  Location loc = b.getInsertionPoint()->getLoc();
694  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
695 }
696 
697 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
698 /// assumes that `reductionOp` has two operands and one of them is the reduction
699 /// initial value.
700 // Note: this is a true builder that notifies the OpBuilder listener.
701 // TODO: Consider moving as a static helper on the ReduceOp.
702 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
703  Value valueToReduce, Value acc,
704  ArrayRef<bool> dimsToMask) {
705  auto maybeKind = getCombinerOpKind(reduceOp);
706  assert(maybeKind && "Failed precondition: could not get reduction kind");
707  return vector::MultiDimReductionOp::create(
708  b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
709 }
710 
711 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
712  return llvm::to_vector(
713  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
714 }
715 
716 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
717 /// reduction iterator.
718 static bool hasReductionIterator(LinalgOp &op) {
719  return isa<linalg::ReduceOp>(op) ||
720  (isa<linalg::GenericOp>(op) &&
721  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
722 }
723 
724 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
725 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
726 /// currently being vectorized. If `dest` has null rank, build a memref.store.
727 /// Return the produced value or null if no value is produced.
728 // Note: this is a true builder that notifies the OpBuilder listener.
729 // TODO: Consider moving as a static helper on the ReduceOp.
730 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
731  OpOperand *outputOperand,
732  VectorizationState &state) {
733  Location loc = value.getLoc();
734  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
735  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
736 
737  // Compute the vector type of the value to store. This type should be an
738  // identity or projection of the canonical vector type without any permutation
739  // applied, given that any permutation in a transfer write happens as part of
740  // the write itself.
741  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
742  opOperandMap.getContext(), opOperandMap.getNumInputs(),
743  [&](AffineDimExpr dimExpr) -> bool {
744  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
745  });
746  auto vectorType = state.getCanonicalVecType(
747  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
748 
749  Operation *write;
750  if (vectorType.getRank() > 0) {
751  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
752  SmallVector<Value> indices(
753  linalgOp.getRank(outputOperand),
754  arith::ConstantIndexOp::create(rewriter, loc, 0));
755  value = broadcastIfNeeded(rewriter, value, vectorType);
756  assert(value.getType() == vectorType && "Incorrect type");
757  write = vector::TransferWriteOp::create(
758  rewriter, loc, value, outputOperand->get(), indices, writeMap);
759  } else {
760  // 0-d case is still special: do not invert the reindexing writeMap.
761  if (!isa<VectorType>(value.getType()))
762  value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
763  assert(value.getType() == vectorType && "Incorrect type");
764  write = vector::TransferWriteOp::create(rewriter, loc, value,
765  outputOperand->get(), ValueRange{});
766  }
767 
768  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
769 
770  // If masked, set in-bounds to true. Masking guarantees that the access will
771  // be in-bounds.
772  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
773  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
774  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
775  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
776  }
777 
778  LDBG() << "vectorized op: " << *write;
779  if (!write->getResults().empty())
780  return write->getResult(0);
781  return Value();
782 }
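// Editorial example (illustrative IR): for a rank-2 output with an 8x16
// canonical vector shape, the helper above emits something along the lines of:
//
//   %w = vector.transfer_write %vec, %init[%c0, %c0]
//          : vector<8x16xf32>, tensor<8x16xf32>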
783 
784 // Custom vectorization precondition function type. This is intended to be used
785 // with CustomVectorizationHook. Returns success if the corresponding custom
786 // hook can vectorize the op.
787 using CustomVectorizationPrecondition =
788  std::function<LogicalResult(Operation *, bool)>;
789 
790 // Custom vectorization function type. Produce a vector form of Operation*
791 // assuming all its vectorized operands are already in the IRMapping.
792 // Return nullptr if the Operation cannot be vectorized.
793 using CustomVectorizationHook =
794  std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
795 
796 /// Helper function to vectorize the terminator of a `linalgOp`. New result
797 /// vector values are appended to `newResults`. Return
798 /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
799 /// that it should not try to map produced operations and instead return the
800 /// results using the `newResults` vector making them available to the
801 /// vectorization algorithm for RAUW. This function is meant to be used as a
802 /// CustomVectorizationHook.
803 static VectorizationHookResult
804 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
805  const IRMapping &bvm, VectorizationState &state,
806  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
807  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
808  if (!yieldOp)
809  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
810  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
811  // TODO: Scan for an opportunity for reuse.
812  // TODO: use a map.
813  Value vectorValue = bvm.lookup(output.value());
814  Value newResult =
815  buildVectorWrite(rewriter, vectorValue,
816  linalgOp.getDpsInitOperand(output.index()), state);
817  if (newResult)
818  newResults.push_back(newResult);
819  }
820 
821  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
822 }
823 
824 /// Helper function to vectorize the index operations of a `linalgOp`. Return
825 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
826 /// should map the produced operations. This function is meant to be used as a
827 /// CustomVectorizationHook.
828 static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
829  VectorizationState &state,
830  Operation *op,
831  LinalgOp linalgOp) {
832  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
833  if (!indexOp)
834  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
835  auto loc = indexOp.getLoc();
836  // Compute the static loop sizes of the index op.
837  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
838  auto dim = indexOp.getDim();
839  // Compute a one-dimensional index vector for the index op dimension.
840  auto indexVectorType =
841  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
842  state.getScalableVecDims()[dim]);
843  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
844  // Return the one-dimensional index vector if it lives in the trailing
845  // dimension of the iteration space since the vectorization algorithm in this
846  // case can handle the broadcast.
847  if (dim == targetShape.size() - 1)
848  return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
849  // Otherwise permute the targetShape to move the index dimension last,
850  // broadcast the one-dimensional index vector to the permuted shape, and
851  // finally transpose the broadcasted index vector to undo the permutation.
852  auto permPattern =
853  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
854  std::swap(permPattern[dim], permPattern.back());
855  auto permMap =
856  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
857 
858  auto broadCastOp = vector::BroadcastOp::create(
859  rewriter, loc,
860  state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
861  SmallVector<int64_t> transposition =
862  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
863  std::swap(transposition.back(), transposition[dim]);
864  auto transposeOp =
865  vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
866  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
867 }
868 
869 /// Helper function to check if the tensor.extract can be vectorized by the
870 /// custom hook vectorizeTensorExtract.
871 static LogicalResult
872 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
873  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
874  if (!extractOp)
875  return failure();
876 
877  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
878  return failure();
879 
880  // Check the index type, but only for non 0-d tensors (for which we do need
881  // access indices).
882  if (not extractOp.getIndices().empty()) {
883  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
884  return failure();
885  }
886 
887  if (!llvm::all_of(extractOp->getResultTypes(),
888  VectorType::isValidElementType)) {
889  return failure();
890  }
891 
892  return success();
893 }
894 
895 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
896 /// generated from `tensor.extract`. The offset is calculated as follows
897 /// (example using scalar values):
898 ///
899 /// offset = extractOp.indices[0]
900 /// for (i = 1; i < numIndices; i++)
901 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
902 ///
903 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
904 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
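///          = 1233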
905 static Value calculateGatherOffset(RewriterBase &rewriter,
906  VectorizationState &state,
907  tensor::ExtractOp extractOp,
908  const IRMapping &bvm) {
909  // The vector of indices for GatherOp should be shaped as the output vector.
910  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
911  auto loc = extractOp.getLoc();
912 
913  Value offset = broadcastIfNeeded(
914  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
915 
916  const size_t numIndices = extractOp.getIndices().size();
917  for (size_t i = 1; i < numIndices; i++) {
918  Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
919 
920  auto dimSize = broadcastIfNeeded(
921  rewriter,
922  tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
923  indexVecType);
924 
925  offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
926 
927  auto extractOpIndex = broadcastIfNeeded(
928  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
929 
930  offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
931  }
932 
933  return offset;
934 }
935 
936 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
937
938 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
939 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
940 /// represents a contiguous load operation.
941 ///
942 /// Note that when calling this hook, it is assumed that the output vector is
943 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
944 /// labelled as a gather load before entering this method.
945 ///
946 /// Following on from the above, it is assumed that:
947 /// * for statically shaped loops, when no masks are used, only one dim is !=
948 /// 1 (that's what the shape of the output vector is based on).
949 /// * for dynamically shaped loops, there might be more non-unit dims
950 /// as the output vector type is user-specified.
951 ///
952 /// TODO: Statically shaped loops + vector masking
953 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
954  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
955  assert(
956  (linalgOp.hasDynamicShape() ||
957  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
958  "For statically shaped Linalg Ops, only one "
959  "non-unit loop dim is expected");
960  assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
961 
962  size_t idx = loopRanges.size() - 1;
963  for (; idx != 0; idx--)
964  if (loopRanges[idx] != 1)
965  break;
966 
967  return idx;
968 }
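// Editorial examples: for static loop ranges {1, 8, 1} the scan above returns
// 1 (the trailing non-unit dim); for {4, 1} it falls through to index 0.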
969 
970 /// Checks whether `val` can be used for calculating a loop invariant index.
971 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
972  VectorType resType) {
973 
974  assert(((llvm::count_if(resType.getShape(),
975  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
976  "n-D vectors are not yet supported");
977 
978  // Block arguments defined outside _this_ linalg.generic are effectively loop
979  // invariant. However, analysing block arguments of _this_ linalg.generic Op
980  // is a bit tricky. Just bail out in the latter case.
981  // TODO: We could try analysing the corresponding affine map here.
982  auto *block = linalgOp.getBlock();
983  if (isa<BlockArgument>(val))
984  return !llvm::is_contained(block->getArguments(), val);
985 
986  Operation *defOp = val.getDefiningOp();
987  assert(defOp && "This is neither a block argument nor an operation result");
988 
989  // IndexOp is loop invariant as long as its result remains constant across
990  // iterations. Note that for dynamic shapes, the corresponding dim will also
991  // be conservatively treated as != 1.
992  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
993  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
994  }
995 
996  auto *ancestor = block->findAncestorOpInBlock(*defOp);
997 
999  // Values defined outside `linalgOp` are loop invariant.
999  if (!ancestor)
1000  return true;
1001 
1002  // Values defined inside `linalgOp`, which are constant, are loop invariant.
1003  if (isa<arith::ConstantOp>(ancestor))
1004  return true;
1005 
1006  bool result = true;
1007  for (auto op : ancestor->getOperands())
1008  result &= isLoopInvariantIdx(linalgOp, op, resType);
1009 
1010  return result;
1011 }
1012 
1013 /// Check whether `val` could be used for calculating the trailing index for a
1014 /// contiguous load operation.
1015 ///
1016 /// There are currently 3 types of values that are allowed here:
1017 /// 1. loop-invariant values,
1018 /// 2. values that increment by 1 with every loop iteration,
1019 /// 3. results of basic arithmetic operations (linear and continuous)
1020 /// involving 1., 2. and 3.
1021 /// This method returns True if indeed only such values are used in calculating
1022 /// `val.`
1023 ///
1024 /// Additionally, the trailing index for a contiguous load operation should
1025 /// increment by 1 with every loop iteration, i.e. be based on:
1026 /// * `linalg.index <dim>` ,
1027 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
1028 /// `linalg.index <dim>` increments by 1 with every loop iteration).
1029 /// `foundIndexOp` is updated to `true` when such Op is found.
1030 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
1031  bool &foundIndexOp, VectorType resType) {
1032 
1033  assert(((llvm::count_if(resType.getShape(),
1034  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1035  "n-D vectors are not yet supported");
1036 
1037  // Block arguments defined outside _this_ linalg.generic are effectively loop
1038  // invariant. However, analysing block arguments of _this_ linalg.generic Op
1039  // is a bit tricky. Just bail out in the latter case.
1040  // TODO: We could try analysing the corresponding affine map here.
1041  auto *block = linalgOp.getBlock();
1042  if (isa<BlockArgument>(val))
1043  return !llvm::is_contained(block->getArguments(), val);
1044 
1045  Operation *defOp = val.getDefiningOp();
1046  assert(defOp && "This is neither a block argument nor an operation result");
1047 
1048  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1049  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1050 
1051  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1052  return true;
1053  }
1054 
1055  auto *ancestor = block->findAncestorOpInBlock(*defOp);
1056 
1057  if (!ancestor)
1058  return false;
1059 
1060  // Conservatively reject Ops that could lead to indices with stride other
1061  // than 1.
1062  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1063  return false;
1064 
1065  bool result = false;
1066  for (auto op : ancestor->getOperands())
1067  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1068 
1069  return result;
1070 }
1071 
1072 /// Infer the memory access pattern for the input ExtractOp
1073 ///
1074 /// Based on the ExtractOp result shape and the access indices, decides whether
1075 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1076 /// or a gather load. When analysing the ExtractOp indices (to identify
1077 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1078 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1079 /// Op).
1080 ///
1081 /// Note that it is always safe to use gather load operations for contiguous
1082 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1083 /// that `extractOp` is a gather load.
1084 static VectorMemoryAccessKind
1085 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1086  LinalgOp &linalgOp, VectorType resType) {
1087 
1088  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1089 
1090  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1091  if (inputShape.getShape().empty())
1092  return VectorMemoryAccessKind::ScalarBroadcast;
1093
1094  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1095  // otherwise.
1096  bool isOutput1DVector =
1097  (llvm::count_if(resType.getShape(),
1098  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1099  // 1. Assume that it's a gather load when reading non-1D vector.
1100  if (!isOutput1DVector)
1101  return VectorMemoryAccessKind::Gather;
1102
1103  bool leadingIdxsLoopInvariant = true;
1104 
1105  // 2. Analyze the leading indices of `extractOp`.
1106  // Look at the way each index is calculated and decide whether it is suitable
1107  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1108  // gather load.
1109  auto indices = extractOp.getIndices();
1110  auto leadIndices = indices.drop_back(1);
1111 
1112  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1113  if (inputShape.getShape()[i] == 1)
1114  continue;
1115 
1116  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1117  }
1118 
1119  if (!leadingIdxsLoopInvariant) {
1120  LDBG() << "Found gather load: " << extractOp;
1121  return VectorMemoryAccessKind::Gather;
1122  }
1123 
1124  // 3. Analyze the trailing index for `extractOp`.
1125  // At this point we know that the leading indices are loop invariant. This
1126  // means that this is potentially a scalar or a contiguous load. We can decide
1127  // based on the trailing idx.
1128  auto extractOpTrailingIdx = indices.back();
1129 
1130  // 3a. Scalar broadcast load
1131  // If the trailing index is loop invariant then this is a scalar load.
1132  if (leadingIdxsLoopInvariant &&
1133  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1134  LDBG() << "Found scalar broadcast load: " << extractOp;
1135 
1136  return VectorMemoryAccessKind::ScalarBroadcast;
1137  }
1138 
1139  // 3b. Contiguous loads
1140  // The trailing `extractOp` index should increment with every loop iteration.
1141  // This effectively means that it must be based on the trailing loop index.
1142  // This is what the following bool captures.
1143  bool foundIndexOp = false;
1144  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1145  foundIndexOp, resType);
1146  // TODO: Support generating contiguous loads for column vectors - that will
1147  // require adding a permutation map to transfer_read Ops.
1148  bool isRowVector = resType.getShape().back() != 1;
1149  isContiguousLoad &= (foundIndexOp && isRowVector);
1150 
1151  if (isContiguousLoad) {
1152  LDBG() << "Found contigous load: " << extractOp;
1154  }
1155 
1156  // 4. Fallback case - gather load.
1157  LDBG() << "Found gather load: " << extractOp;
1158  return VectorMemoryAccessKind::Gather;
1159 }
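// Editorial examples (illustrative, for a body reading %src : tensor<8x128xf32>
// inside a linalg.generic):
//   * tensor.extract %src[%c0, %i], where %i comes from linalg.index over the
//     trailing non-unit loop dim -> contiguous load;
//   * tensor.extract %src[%i, %c0], with the same %i used as a leading index
//     -> gather load;
//   * tensor.extract %src[%c0, %c1], with loop-invariant indices
//     -> scalar broadcast load.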
1160 
1161 /// Helper function to vectorize the tensor.extract operations. Returns
1162 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1163 /// should map the produced operations. This function is meant to be used as a
1164 /// CustomVectorizationHook.
1165 static VectorizationHookResult
1166 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1167  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1168  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1169  if (!extractOp)
1170  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1171  auto loc = extractOp.getLoc();
1172 
1173  // Compute the static loop sizes of the extract op.
1174  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1175  auto maskConstantOp = arith::ConstantOp::create(
1176  rewriter, loc,
1177  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1178  /*value=*/true));
1179  auto passThruConstantOp = arith::ConstantOp::create(
1180  rewriter, loc, rewriter.getZeroAttr(resultType));
1181 
1182  // Base indices are currently set to 0. We will need to re-visit if more
1183  // generic scenarios are to be supported.
1184  SmallVector<Value> baseIndices(
1185  extractOp.getIndices().size(),
1186  arith::ConstantIndexOp::create(rewriter, loc, 0));
1187 
1188  VectorMemoryAccessKind memAccessKind =
1189  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1190 
1191  // 1. Handle gather access
1192  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1193  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1194 
1195  // Generate the gather load
1196  Operation *gatherOp = vector::GatherOp::create(
1197  rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1198  maskConstantOp, passThruConstantOp);
1199  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1200 
1201  LDBG() << "Vectorised as gather load: " << extractOp;
1202  return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1203  }
1204 
1205  // 2. Handle:
1206  // a. scalar loads + broadcast,
1207  // b. contiguous loads.
1208  // Both cases use vector.transfer_read.
1209 
1210  // Collect indices for `vector.transfer_read`. At this point, the indices will
1211  // either be scalars or would have been broadcast to vectors matching the
1212  // result type. For indices that are vectors, there are two options:
1213  // * for non-trailing indices, all elements are identical (contiguous
1214  // loads are identified by looking for non-trailing indices that are
1215  // invariant with respect to the corresponding linalg.generic), or
1216  // * for trailing indices, the index vector will contain values with stride
1217  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1218  // needed.
1219  // This means that
1220  // * for scalar indices - just re-use it,
1221  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1222  // (0th) element and use that.
1223  SmallVector<Value> transferReadIdxs;
1224  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1225  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1226  if (idx.getType().isIndex()) {
1227  transferReadIdxs.push_back(idx);
1228  continue;
1229  }
1230 
1231  auto indexAs1dVector = vector::ShapeCastOp::create(
1232  rewriter, loc,
1233  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1234  resultType.getScalableDims().back()),
1235  idx);
1236  transferReadIdxs.push_back(
1237  vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1238  }
1239 
1240  // `tensor.extract` is always in-bounds, hence the following holds.
1241  auto dstRank = resultType.getRank();
1242  auto srcRank = extractOp.getTensor().getType().getRank();
1243  SmallVector<bool> inBounds(dstRank, true);
1244 
1245  // 2a. Handle scalar broadcast access.
1246  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1247  MLIRContext *ctx = rewriter.getContext();
1248  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1249  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1250 
1251  auto transferReadOp = vector::TransferReadOp::create(
1252  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1253  /*padding=*/std::nullopt, permutationMap, inBounds);
1254 
1255  // Mask this broadcasting xfer_read here rather than relying on the generic
1256  // path (the generic path assumes identity masking map, which wouldn't be
1257  // valid here).
1258  SmallVector<int64_t> readMaskShape = {1};
1259  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1260  auto allTrue = vector::ConstantMaskOp::create(
1261  rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1262  auto *maskedReadOp =
1263  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1264 
1265  LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1266  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1267  maskedReadOp};
1268  }
1269 
1270  // 2b. Handle contiguous access.
1271  auto permutationMap = AffineMap::getMinorIdentityMap(
1272  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1273 
1274  int32_t rankDiff = dstRank - srcRank;
1275  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1276  // dims so that the ranks match. This is done by extending the map with 0s.
1277  // For example, for dstRank = 3, srcRank = 2, the following map created
1278  // above:
1279  // (d0, d1) --> (d0, d1)
1280  // is extended as:
1281  // (d0, d1) --> (0, d0, d1)
1282  while (rankDiff > 0) {
1283  permutationMap = permutationMap.insertResult(
1284  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1285  rankDiff--;
1286  }
1287 
1288  auto transferReadOp = vector::TransferReadOp::create(
1289  rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1290  /*padding=*/std::nullopt, permutationMap, inBounds);
1291 
1292  LDBG() << "Vectorised as contiguous load: " << extractOp;
1293  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1294  transferReadOp};
1295 }
1296 
1297 /// Emit reduction operations if the shape of the value to reduce is different
1298 /// from the result shape.
1299 // Note: this is a true builder that notifies the OpBuilder listener.
1300 // TODO: Consider moving as a static helper on the ReduceOp.
1301 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1302  Value reduceValue, Value initialValue,
1303  const IRMapping &bvm) {
1304  Value reduceVec = bvm.lookup(reduceValue);
1305  Value outputVec = bvm.lookup(initialValue);
1306  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1307  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1308  // Reduce only if needed as the value may already have been reduced for
1309  // contraction vectorization.
1310  if (!reduceType ||
1311  (outputType && reduceType.getShape() == outputType.getShape()))
1312  return nullptr;
1313  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1314  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1315 }
1316 
1317 /// Generic vectorization for a single operation `op`, given already vectorized
1318 /// operands carried by `bvm`. Vectorization occurs as follows:
1319 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1320 /// result on success.
1321 /// 2. Clone any constant in the current scope without vectorization: each
1322 /// consumer of the constant will later determine the shape to which the
1323 /// constant needs to be broadcast.
1324 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1325 /// of the `customVectorizationHooks` to cover such cases.
1326 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1327 /// operand of maximal rank. Other operands have smaller rank and are
1328 /// broadcast accordingly. It is assumed this broadcast is always legal,
1329 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1330 ///
1331 /// This function assumes all operands of `op` have been vectorized and are in
1332 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1333 /// a topologically-sorted list of ops.
1334 /// This function does not update `bvm` but returns a VectorizationHookStatus
1335 /// that instructs the caller what `bvm` update needs to occur.
1336 static VectorizationHookResult
1337 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1338  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1339  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1340  LDBG() << "vectorize op " << *op;
1341 
1342  // 1. Try to apply any CustomVectorizationHook.
1343  if (!customVectorizationHooks.empty()) {
1344  for (auto &customFunc : customVectorizationHooks) {
1345  VectorizationHookResult result = customFunc(op, bvm);
1346  if (result.status == VectorizationHookStatus::Failure)
1347  continue;
1348  return result;
1349  }
1350  }
1351 
1352  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1353  // Clone so that the constant is not confined to the linalgOp block.
1354  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1355  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1356  rewriter.clone(*op)};
1357 
1358  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1359  if (!OpTrait::hasElementwiseMappableTraits(op))
1360  return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1361
1362  // 4. Check if the operation is a reduction.
1363  SmallVector<std::pair<Value, Value>> reductionOperands;
1364  for (Value operand : op->getOperands()) {
1365  auto blockArg = dyn_cast<BlockArgument>(operand);
1366  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1367  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1368  continue;
1369  SmallVector<Operation *> reductionOps;
1370  Value reduceValue = matchReduction(
1371  linalgOp.getRegionOutputArgs(),
1372  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1373  if (!reduceValue)
1374  continue;
1375  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1376  }
1377  if (!reductionOperands.empty()) {
1378  assert(reductionOperands.size() == 1);
1379  Operation *reduceOp =
1380  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1381  reductionOperands[0].second, bvm);
1382  if (reduceOp)
1383  return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1384  }
1385 
1386  // 5. Generic vectorization path for ElementwiseMappable ops.
1387  // a. Get the first max ranked shape.
1388  VectorType firstMaxRankedType;
1389  for (Value operand : op->getOperands()) {
1390  auto vecOperand = bvm.lookup(operand);
1391  assert(vecOperand && "Vector operand couldn't be found");
1392 
1393  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1394  if (vecType && (!firstMaxRankedType ||
1395  firstMaxRankedType.getRank() < vecType.getRank()))
1396  firstMaxRankedType = vecType;
1397  }
1398  // b. Broadcast each operand if needed.
1399  SmallVector<Value> vecOperands;
1400  for (Value scalarOperand : op->getOperands()) {
1401  Value vecOperand = bvm.lookup(scalarOperand);
1402  assert(vecOperand && "Vector operand couldn't be found");
1403 
1404  if (firstMaxRankedType) {
1405  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1406  getElementTypeOrSelf(vecOperand.getType()),
1407  firstMaxRankedType.getScalableDims());
1408  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1409  } else {
1410  vecOperands.push_back(vecOperand);
1411  }
1412  }
1413  // c. For elementwise ops, the result is a vector with the firstMaxRankedShape.
1414  SmallVector<Type> resultTypes;
1415  for (Type resultType : op->getResultTypes()) {
1416  resultTypes.push_back(
1417  firstMaxRankedType
1418  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1419  firstMaxRankedType.getScalableDims())
1420  : resultType);
1421  }
1422  // d. Build and return the new op.
1423  return VectorizationHookResult{
1424  VectorizationHookStatus::NewOp,
1425  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1426  resultTypes, op->getAttrs())};
1427 }
1428 
1429 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1430 /// vector form. Generic vectorization proceeds as follows:
1431 /// 1. Verify the `linalgOp` has one non-empty region.
1432 /// 2. Values defined above the region are mapped to themselves and will be
1433 /// broadcasted on a per-need basis by their consumers.
1434 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1435 /// load).
1436 /// TODO: Reuse opportunities for RAR dependencies.
1437 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1438 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1439 /// iteration indices.
1440 /// 5. Iteratively call vectorizeOneOp on the region operations.
1441 ///
1442 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1443 /// performed to the maximal common vector size implied by the `linalgOp`
1444 /// iteration space. This eager broadcasting is introduced in the
1445 /// permutation_map of the vector.transfer_read operations. The eager
1446 /// broadcasting makes it trivial to determine where broadcast, transposes and
1447 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1448 /// the absence of good canonicalizations, the amount of work increases.
1449 /// This is not deemed a problem as we expect canonicalizations and foldings to
1450 /// aggressively clean up the useless work.
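///
/// As a hypothetical end-to-end sketch (not from the upstream docs; static
/// shapes chosen only for illustration), an elementwise addition expressed as
/// a linalg op on two tensor<4x8xf32> operands would roughly become:
///
///   %va = vector.transfer_read %a[%c0, %c0], %cst
///       : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %b[%c0, %c0], %cst
///       : tensor<4x8xf32>, vector<4x8xf32>
///   %vr = arith.addf %va, %vb : vector<4x8xf32>
///   %r  = vector.transfer_write %vr, %init[%c0, %c0]
///       : vector<4x8xf32>, tensor<4x8xf32>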
1451 static LogicalResult
1452 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1453  LinalgOp linalgOp,
1454  SmallVectorImpl<Value> &newResults) {
1455  LDBG() << "Vectorizing operation as linalg generic";
1456  Block *block = linalgOp.getBlock();
1457 
1458  // 2. Values defined above the region can only be broadcast for now. Make them
1459  // map to themselves.
1460  IRMapping bvm;
1461  SetVector<Value> valuesSet;
1462  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1463  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1464 
1465  if (linalgOp.getNumDpsInits() == 0)
1466  return failure();
1467 
1468  // 3. Turn all BBArgs into vector.transfer_read / load.
1469  Location loc = linalgOp.getLoc();
1470  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1471  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1472  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1473  if (linalgOp.isScalar(opOperand)) {
1474  bvm.map(bbarg, opOperand->get());
1475  continue;
1476  }
1477 
1478  // 3.a. Convert the indexing map for this input/output to a transfer read
1479  // permutation map and masking map.
1480  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1481 
1482  AffineMap readMap;
1483  VectorType readType;
1484  Type elemType = getElementTypeOrSelf(opOperand->get());
1485  if (linalgOp.isDpsInput(opOperand)) {
1486  // 3.a.i. For input reads we use the canonical vector shape.
1487  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1488  readType = state.getCanonicalVecType(elemType);
1489  } else {
1490  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1491  // reductions), the vector shape is computed by mapping the canonical
1492  // vector shape to the output domain and back to the canonical domain.
1493  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1494  readType =
1495  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1496  }
1497 
1498  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1499 
1500  Operation *read = vector::TransferReadOp::create(
1501  rewriter, loc, readType, opOperand->get(), indices,
1502  /*padding=*/std::nullopt, readMap);
1503  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1504  Value readValue = read->getResult(0);
1505 
1506  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1507  // will be in-bounds.
1508  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1509  SmallVector<bool> inBounds(readType.getRank(), true);
1510  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1511  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1512  }
1513 
1514  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1515  // TODO: remove this.
1516  if (readType.getRank() == 0)
1517  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1518  ArrayRef<int64_t>());
1519 
1520  LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1521  << "): " << readValue;
1522  bvm.map(bbarg, readValue);
1523  bvm.map(opOperand->get(), readValue);
1524  }
1525 
1526  SmallVector<CustomVectorizationHook> hooks;
1527  // 4a. Register CustomVectorizationHook for yieldOp.
1528  CustomVectorizationHook vectorizeYield =
1529  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1530  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1531  };
1532  hooks.push_back(vectorizeYield);
1533 
1534  // 4b. Register CustomVectorizationHook for indexOp.
1535  CustomVectorizationHook vectorizeIndex =
1536  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1537  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1538  };
1539  hooks.push_back(vectorizeIndex);
1540 
1541  // 4c. Register CustomVectorizationHook for extractOp.
1542  CustomVectorizationHook vectorizeExtract =
1543  [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1544  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1545  };
1546  hooks.push_back(vectorizeExtract);
1547 
1548  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1549  for (Operation &op : block->getOperations()) {
1550  VectorizationHookResult result =
1551  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1552  if (result.status == VectorizationHookStatus::Failure) {
1553  LDBG() << "failed to vectorize: " << op;
1554  return failure();
1555  }
1556  if (result.status == VectorizationHookStatus::NewOp) {
1557  Operation *maybeMaskedOp =
1558  state.maskOperation(rewriter, result.newOp, linalgOp);
1559  LDBG() << "New vector op: " << *maybeMaskedOp;
1560  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1561  }
1562  }
1563 
1564  return success();
1565 }
1566 
1567 /// Given a linalg::PackOp, return the `dest` shape before any packing
1568 /// permutations.
1569 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1570  ArrayRef<int64_t> destShape) {
1571  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1572 }
1573 
1574 /// Determines whether a mask for xfer_write is trivially "all true"
1575 ///
1576 /// Given all the inputs required to generate a mask (mask sizes and shapes),
1577 /// and an xfer_write operation (write indices and the destination tensor
1578 /// shape), determines whether the corresponding mask would be trivially
1579 /// foldable (i.e., trivially "all true").
1580 ///
1581 /// Use this method to avoid generating spurious masks and relying on
1582 /// vectorization post-processing to remove them.
1583 ///
1584 /// Pre-conditions for a mask to be trivially foldable:
1585 /// * All involved shapes (mask + destination tensor) are static.
1586 /// * All write indices are constant.
1587 /// * All mask sizes are constant (including `arith.constant`).
1588 ///
1589 /// If the pre-conditions are met, the method checks for each destination
1590 /// dimension `d`:
1591 /// (1) destDimSize[rankDiff + d] >= maskShape[d]
1592 /// (2) destDimSize[rankDiff + d] >= writeIndex[d] + maskSize[d]
1593 ///
1594 /// rankDiff = rank(dest) - rank(mask).
1595 ///
1596 /// This method takes a conservative view: it may return false even if the mask
1597 /// is technically foldable.
1598 ///
1599 /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1600 /// of the dest tensor):
1601 /// %c0 = arith.constant 0 : index
1602 /// %mask = vector.create_mask 5, 1
1603 /// vector.mask %mask {
1604 /// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1605 /// {in_bounds = [true, true]}
1606 /// : vector<5x1xi32>, tensor<5x1xi32>
1607 /// }
1608 ///
1609 /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1610 /// mask is required to avoid out-of-bounds write):
1611 /// %c0 = arith.constant 0 : index
1612 /// %mask = vector.create_mask 5, 1
1613 /// vector.mask %mask {
1614 /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1615 /// {in_bounds = [true, true]}
1616 /// : vector<8x1xi32>, tensor<5x1xi32>
1617 /// }
1618 ///
1619 /// TODO: Re-use in createReadOrMaskedRead
1620 static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1621  SmallVector<Value> &writeIdxs,
1622  ArrayRef<int64_t> destShape,
1623  ArrayRef<int64_t> maskShape) {
1624  // Masking is unavoidable in the case of dynamic tensors.
1625  if (ShapedType::isDynamicShape(destShape))
1626  return false;
1627 
1628  // Collect all constant mask sizes.
1629  SmallVector<int64_t, 4> cstMaskSizes;
1630  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1631  if (auto intSize = getConstantIntValue(dimSize)) {
1632  cstMaskSizes.push_back(*intSize);
1633  }
1634  }
1635 
1636  // If any of the mask sizes is non-constant, bail out.
1637  if (cstMaskSizes.size() != maskShape.size())
1638  return false;
1639 
1640  // Collect all constant write indices.
1641  SmallVector<int64_t, 4> cstWriteIdxs;
1642  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1643  APSInt intVal;
1644  if (matchPattern(idx, m_ConstantInt(&intVal))) {
1645  cstWriteIdxs.push_back(intVal.getSExtValue());
1646  }
1647  }
1648 
1649  // If any of the write indices is non-constant, bail out.
1650  if (cstWriteIdxs.size() != destShape.size())
1651  return false;
1652 
1653  // Go over all destination dims and check (1) and (2). Take into account that:
1654  // * The number of mask sizes will match the rank of the vector to store.
1655  // This could be lower than the rank of the destination tensor.
1656  // * Mask sizes could be larger than the corresponding mask shape (hence
1657  // `clamp`).
1658  // TODO: The 2nd item should be rejected by the verifier.
1659  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1660  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1661  if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1662  /*(2)*/ destShape[rankDiff + i] <
1663  (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1664  cstWriteIdxs[i]))
1665  return false;
1666  }
1667 
1668  return true;
1669 }
1670 
1671 /// Creates an optionally masked TransferWriteOp
1672 ///
1673 /// Generates the following operation:
1674 /// %res = vector.transfer_write %vecToStore into %dest
1675 ///
1676 /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1677 ///
1678 /// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1679 /// %res = vector.mask %mask {
1680 /// vector.transfer_write %vecToStore into %dest
1681 /// }
1682 ///
1683 /// The mask shape is identical to `vecToStore` (with the element type ==
1684 /// i1), and the mask values are based on the shape of the `dest` tensor.
1685 ///
1686 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1687 /// is used instead of masking:
1688 ///
1689 /// in_bounds_flags = (...)
1690 /// (computed from the static shapes of `vecToStore` and `dest`)
1691 /// %res = vector.transfer_write %vecToStore into %dest
1692 /// {in_bounds = in_bounds_flags}
1693 ///
1694 /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1695 /// are set to 0.
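///
/// As a hypothetical concrete instance (not from the upstream docs): writing a
/// vector<4x8xf32> into a tensor<?x8xf32> whose leading size %d is dynamic
/// would produce roughly:
///
///   %mask = vector.create_mask %d, %c8 : vector<4x8xi1>
///   %res  = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0, %c0]
///         : vector<4x8xf32>, tensor<?x8xf32>
///   } : vector<4x8xi1> -> tensor<?x8xf32>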
1696 static Operation *
1697 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1698  Value dest, SmallVector<Value> writeIndices = {},
1699  bool useInBoundsInsteadOfMasking = false) {
1700 
1701  ShapedType destType = cast<ShapedType>(dest.getType());
1702  int64_t destRank = destType.getRank();
1703  auto destShape = destType.getShape();
1704 
1705  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1706  int64_t vecToStoreRank = vecToStoreType.getRank();
1707  auto vecToStoreShape = vecToStoreType.getShape();
1708 
1709  // Compute the in_bounds attribute
1710  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1711  if (useInBoundsInsteadOfMasking) {
1712  // Update the inBounds attribute.
1713  // FIXME: This computation is too weak - it ignores the write indices.
1714  for (unsigned i = 0; i < vecToStoreRank; i++)
1715  inBoundsVal[i] =
1716  (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1717  ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1718  }
1719 
1720  // If missing, initialize the write indices to 0.
1721  assert((writeIndices.empty() ||
1722  writeIndices.size() == static_cast<size_t>(destRank)) &&
1723  "Invalid number of write indices!");
1724  if (writeIndices.empty()) {
1725  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1726  writeIndices.assign(destRank, zero);
1727  }
1728 
1729  // Generate the xfer_write Op
1730  Operation *write = vector::TransferWriteOp::create(builder, loc,
1731  /*vector=*/vecToStore,
1732  /*source=*/dest,
1733  /*indices=*/writeIndices,
1734  /*inBounds=*/inBoundsVal);
1735 
1736  // If masking is disabled, exit.
1737  if (useInBoundsInsteadOfMasking)
1738  return write;
1739 
1740  // Check if masking is needed. If not, exit.
1741  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1742  return write;
1743 
1744  // Compute the mask and mask the write Op.
1745  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1746  vecToStoreType.getScalableDims());
1747 
1748  SmallVector<OpFoldResult> destSizes =
1749  isa<MemRefType>(dest.getType())
1750  ? memref::getMixedSizes(builder, loc, dest)
1751  : tensor::getMixedSizes(builder, loc, dest);
1752  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1753  destSizes.end());
1754 
1755  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1756  vecToStoreShape))
1757  return write;
1758 
1759  Value maskForWrite =
1760  builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1761  return mlir::vector::maskOperation(builder, write, maskForWrite);
1762 }
1763 
1764 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1765 /// padding value and (3) input vector sizes into:
1766 ///
1767 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1768 ///
1769 /// As in the following example:
1770 /// %pack = linalg.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1771 /// into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1772 ///
1773 /// This pack would be vectorized to:
1774 ///
1775 /// %load = vector.mask %mask {
1776 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1777 /// {in_bounds = [true, true, true]} :
1778 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1779 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1780 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1781 /// to vector<32x4x2x1x16xf32>
1782 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1783 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1784 /// %write = vector.transfer_write %transpose,
1785 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1786 /// {in_bounds = [true, true, true, true, true]}
1787 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1788 ///
1789 /// If the (3) input vector sizes are not provided, the vector sizes are
1790 /// determined by the result tensor shape and the `in_bounds`
1791 /// attribute is used instead of masking to mark out-of-bounds accesses.
1792 ///
1793 /// NOTE: The input vector sizes specify the dimensions corresponding to the
1794 /// outer dimensions of the output tensor. The remaining dimensions are
1795 /// computed based on, e.g., the static inner tiles.
1796 /// Supporting dynamic inner tiles will require the user to specify the
1797 /// missing vector sizes. This is left as a TODO.
1798 static LogicalResult
1799 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1800  ArrayRef<int64_t> inputVectorSizes,
1801  SmallVectorImpl<Value> &newResults) {
1802  // TODO: Introduce a parent class that will handle the insertion point update.
1803  OpBuilder::InsertionGuard g(rewriter);
1804  rewriter.setInsertionPoint(packOp);
1805 
1806  Location loc = packOp.getLoc();
1807  std::optional<Value> padValue = packOp.getPaddingValue()
1808  ? std::optional(packOp.getPaddingValue())
1809  : std::nullopt;
1810 
1811  // If the input vector sizes are not provided, then the vector sizes are
1812  // determined by the result tensor shape. In case the vector sizes aren't
1813  // provided, we update the inBounds attribute instead of masking.
1814  bool useInBoundsInsteadOfMasking = false;
1815  if (inputVectorSizes.empty()) {
1816  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1817  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1818  useInBoundsInsteadOfMasking = true;
1819  }
1820 
1821  // Create masked TransferReadOp.
1822  SmallVector<int64_t> inputShape(inputVectorSizes);
1823  auto innerTiles = packOp.getStaticInnerTiles();
1824  auto innerDimsPos = packOp.getInnerDimsPos();
1825  auto outerDimsPerm = packOp.getOuterDimsPerm();
1826  if (!outerDimsPerm.empty())
1827  applyPermutationToVector(inputShape,
1828  invertPermutationVector(outerDimsPerm));
1829  for (auto [idx, size] : enumerate(innerTiles))
1830  inputShape[innerDimsPos[idx]] *= size;
1831  auto maskedRead = vector::createReadOrMaskedRead(
1832  rewriter, loc, packOp.getSource(), inputShape, padValue,
1833  useInBoundsInsteadOfMasking,
1834  /*inputScalableVecSizes=*/{});
1835 
1836  // Create ShapeCastOp.
1837  SmallVector<int64_t> destShape(inputVectorSizes);
1838  destShape.append(innerTiles.begin(), innerTiles.end());
1839  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1840  packOp.getDestType().getElementType());
1841  auto shapeCastOp =
1842  vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
1843 
1844  // Create TransposeOp.
1845  auto destPermutation =
1846  invertPermutationVector(getPackInverseDestPerm(packOp));
1847  auto transposeOp = vector::TransposeOp::create(
1848  rewriter, loc, shapeCastOp.getResult(), destPermutation);
1849 
1850  // Create TransferWriteOp.
1851  Operation *write = createWriteOrMaskedWrite(
1852  rewriter, loc, transposeOp.getResult(), packOp.getDest());
1853  newResults.push_back(write->getResult(0));
1854  return success();
1855 }
1856 
1857 /// Given the re-associations, "collapses" the input Vector type
1858 ///
1859 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1860 /// differences:
1861 /// * We can safely assume that there are no dynamic sizes.
1862 /// * Scalable flags are updated alongside regular dims.
1863 ///
1864 /// When collapsing scalable flags, conservatively avoids cases with two
1865 /// scalable dims. We could re-visit this in the future.
1866 ///
1867 /// EXAMPLE:
1868 /// type = vector<4x16x[8]x16xf32>
1869 /// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1870 /// (d0, d1, d2, d3) -> (d2, d3)]
1871 /// Result:
1872 /// vector<64x[128]xf32>
1873 static VectorType getCollapsedVecType(VectorType type,
1874  ArrayRef<AffineMap> reassociation) {
1875  assert(type.getNumScalableDims() < 2 &&
1876  "Collapsing more than 1 scalable dim is not supported ATM");
1877 
1878  // Use the fact that reassociation is valid to simplify the logic: only use
1879  // each map's rank.
1880  assert(isReassociationValid(reassociation) && "invalid reassociation");
1881 
1882  auto shape = type.getShape();
1883  auto scalableFlags = type.getScalableDims();
1884  SmallVector<int64_t> newShape;
1885  SmallVector<bool> newScalableFlags;
1886 
1887  unsigned currentDim = 0;
1888  for (AffineMap m : reassociation) {
1889  unsigned dim = m.getNumResults();
1890  int64_t size = 1;
1891  bool flag = false;
1892  for (unsigned d = 0; d < dim; ++d) {
1893  size *= shape[currentDim + d];
1894  flag |= scalableFlags[currentDim + d];
1895  }
1896  newShape.push_back(size);
1897  newScalableFlags.push_back(flag);
1898  currentDim += dim;
1899  }
1900 
1901  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1902 }
1903 
1904 /// Vectorize `linalg.unpack` as:
1905 /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1906 ///
1907 /// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1908 /// for the xfer_read operation). This is sufficient to infer the other vector
1909 /// sizes required here.
1910 ///
1911 /// If the vector sizes are not provided:
1912 /// * the vector sizes are determined from the input tensor static shape.
1913 /// * the inBounds attribute is used instead of masking.
1914 ///
1915 /// EXAMPLE (no vector sizes):
1916 /// ```
1917 /// %unpack = linalg.unpack %src
1918 /// inner_dims_pos = [0, 1]
1919 /// inner_tiles = [8, 8]
1920 /// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1921 /// ```
1922 /// is vectorized as:
1923 /// ```
1924 /// %read = vector.transfer_read %src
1925 /// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1926 /// %tr = vector.transpose %read, [0, 2, 1, 3]
1927 /// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1928 /// %sc = vector.shape_cast %tr
1929 /// : vector<1x8x1x8xf32> to vector<8x8xf32>
1930 /// %vector = vector.transfer_write %sc into %dest
1931 /// : vector<8x8xf32>, tensor<8x8xf32>
1932 /// ```
1933 static LogicalResult
1934 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1935  ArrayRef<int64_t> inputVectorSizes,
1936  ArrayRef<bool> inputScalableVecDims,
1937  SmallVectorImpl<Value> &newResults) {
1938  if (!inputVectorSizes.empty()) {
1939  assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1940  "Invalid number of input vector sizes!");
1941  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1942  "Incompatible number of vector sizes and vector scalable flags!");
1943  }
1944 
1945  // TODO: Introduce a parent class that will handle the insertion point update.
1946  OpBuilder::InsertionGuard g(rewriter);
1947  rewriter.setInsertionPoint(unpackOp);
1948 
1949  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1950 
1951  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1952  bool useInBoundsInsteadOfMasking = false;
1953 
1954  Location loc = unpackOp->getLoc();
1955 
1956  // Obtain vector sizes for the read operation.
1957  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1958  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1959 
1960  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1961  if (inputVectorSizes.empty()) {
1962  if (ShapedType::isDynamicShape(sourceShape))
1963  return failure();
1964 
1965  readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1966  useInBoundsInsteadOfMasking = true;
1967  }
1968 
1969  // -- Generate the read operation --
1970  Value readResult = vector::createReadOrMaskedRead(
1971  rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
1972  useInBoundsInsteadOfMasking, readScalableVectorFlags);
1973 
1974  // -- Generate the transpose operation --
1975  PackingMetadata packMetadata;
1976  SmallVector<int64_t> lastDimToInsertPosPerm =
1977  getUnPackInverseSrcPerm(unpackOp, packMetadata);
1978  vector::TransposeOp transposeOp = vector::TransposeOp::create(
1979  rewriter, loc, readResult, lastDimToInsertPosPerm);
1980 
1981  // -- Generate the shape_cast operation --
1982  VectorType collapsedVecType = getCollapsedVecType(
1983  transposeOp.getType(),
1984  getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1985  rewriter.getContext(), packMetadata.reassociations)));
1986  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1987  rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1988 
1989  // -- Generate the write operation --
1990  Operation *write = createWriteOrMaskedWrite(
1991  rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1992  /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1993 
1994  newResults.push_back(write->getResult(0));
1995  return success();
1996 }
1997 
1998 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1999 /// and (3) all-zero lowPad to
2000 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
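///
/// As a hypothetical example (not from the upstream docs), padding a
/// dynamically shaped source to a static 4x8 result with input vector sizes
/// [4, 8] would produce roughly:
///
///   %read  = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %pad
///         : tensor<?x?xf32>, vector<4x8xf32>
///   } : vector<4x8xi1> -> vector<4x8xf32>
///   %empty = tensor.empty() : tensor<4x8xf32>
///   %res   = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x8xf32>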
2001 static LogicalResult
2002 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2003  ArrayRef<int64_t> inputVectorSizes,
2004  SmallVectorImpl<Value> &newResults) {
2005  auto padValue = padOp.getConstantPaddingValue();
2006  Location loc = padOp.getLoc();
2007 
2008  // TODO: Introduce a parent class that will handle the insertion point update.
2009  OpBuilder::InsertionGuard g(rewriter);
2010  rewriter.setInsertionPoint(padOp);
2011 
2012  ReifiedRankedShapedTypeDims reifiedReturnShapes;
2013  LogicalResult status =
2014  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2015  .reifyResultShapes(rewriter, reifiedReturnShapes);
2016  (void)status; // prevent unused variable warning on non-assert builds
2017  assert(succeeded(status) && "failed to reify result shapes");
2018  auto maskedRead = vector::createReadOrMaskedRead(
2019  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2020  /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
2021 
2022  // Create Xfer write Op
2023  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2024  padOp.getResultType().getElementType());
2025  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2026  newResults.push_back(write->getResult(0));
2027  return success();
2028 }
2029 
2030 // TODO: probably need some extra checks for reduction followed by consumer
2031 // ops that may not commute (e.g. linear reduction + non-linear instructions).
2032 static LogicalResult reductionPreconditions(LinalgOp op) {
2033  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2034  LDBG() << "reduction precondition failed: no reduction iterator";
2035  return failure();
2036  }
2037  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2038  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2039  if (indexingMap.isPermutation())
2040  continue;
2041 
2042  Operation *reduceOp = matchLinalgReduction(&opOperand);
2043  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2044  LDBG() << "reduction precondition failed: reduction detection failed";
2045  return failure();
2046  }
2047  }
2048  return success();
2049 }
2050 
2051 static LogicalResult
2052 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2053  bool flatten1DDepthwiseConv) {
2054  if (flatten1DDepthwiseConv) {
2055  LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2056  "supported";
2057  return failure();
2058  }
2059 
2060  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2061  LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2062  return failure();
2063  }
2064 
2065  // Support dynamic shapes in 1D depthwise convolution, but only in the
2066  // _channel_ dimension.
2067  Value lhs = conv.getDpsInputOperand(0)->get();
2068  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2069  auto shapeWithoutCh = lhsShape.drop_back(1);
2070  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2071  LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2072  "channel dim can be dynamic";
2073  return failure();
2074  }
2075 
2076  return success();
2077 }
2078 
2079 static LogicalResult
2080 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2081  bool flatten1DDepthwiseConv) {
2082  if (isa<ConvolutionOpInterface>(op.getOperation()))
2083  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2084 
2085  if (hasReductionIterator(op))
2086  return reductionPreconditions(op);
2087 
2088  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2089  // linalg.copy ops and ops that implement ContractionOpInterface for now.
2090  if (!isElementwise(op) &&
2091  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2092  op.getOperation()))
2093  return failure();
2094 
2095  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2096  return success();
2097 }
2098 
2099 /// This hook considers two cases:
2100 /// (1) If the input-vector-sizes are empty, then the vector sizes will be
2101 /// inferred. This is only possible when all shapes are static.
2102 /// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2103 /// carry out basic sanity-checking.
2104 static LogicalResult
2105 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2106  ArrayRef<int64_t> inputVectorSizes) {
2107  // If there are no input vector sizes and all shapes are static, there is
2108  // nothing left to check.
2109  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2110  unpackOp.getSourceType().hasStaticShape())
2111  return success();
2112 
2113  // The number of input vector sizes must be equal to:
2114  // * read-vector-rank
2115  if (!inputVectorSizes.empty() &&
2116  (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2117  LDBG() << "Incorrect number of input vector sizes";
2118  return failure();
2119  }
2120 
2121  // Check the vector sizes for the read operation.
2122  if (failed(vector::isValidMaskedInputVector(
2123  unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2124  LDBG() << "Invalid vector sizes for the read operation";
2125  return failure();
2126  }
2127 
2128  return success();
2129 }
2130 
2131 static LogicalResult
2132 vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2133  ArrayRef<int64_t> inputVectorSizes) {
2134 
2135  TypedValue<RankedTensorType> source = sliceOp.getSource();
2136  auto sourceType = source.getType();
2137  if (!VectorType::isValidElementType(sourceType.getElementType()))
2138  return failure();
2139 
2140  // Get the pad value.
2141  // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2142  // scalar padding value. Note that:
2143  // * for in-bounds accesses,
2144  // the value is actually irrelevant. There are 2 cases in which xfer.read
2145  // accesses are known to be in-bounds:
2146  // 1. The source shape is static (output vector sizes would be based on
2147  // the source shape and hence all memory accesses would be in-bounds),
2148  // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2149  // this case it is safe to assume that all memory accesses are in-bounds.
2150  //
2151  // When the value is not known and not needed, use 0. Otherwise, bail out.
2152  Value padValue = getStaticPadVal(sliceOp);
2153  bool isOutOfBoundsRead =
2154  !sourceType.hasStaticShape() && inputVectorSizes.empty();
2155 
2156  if (!padValue && isOutOfBoundsRead) {
2157  LDBG() << "Failed to get a pad value for out-of-bounds read access";
2158  return failure();
2159  }
2160  return success();
2161 }
2162 
2163 /// Vectorize a named linalg contraction op into:
2164 /// vector::TransferReadOp - Reads vectors from the operands
2165 /// vector::ContractionOp - Performs contraction
2166 /// vector::TransferWriteOp - Write the result vector back to the
2167 /// destination
2168 /// The operands shapes are preserved and loaded directly into vectors.
2169 /// Any further permutations or numerical casting remain within contraction op.
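///
/// As a hypothetical sketch (not from the upstream docs; static shapes chosen
/// only for illustration), a linalg.matmul on 4x8 * 8x16 operands would
/// roughly become:
///
///   %va = vector.transfer_read %A[%c0, %c0] ... : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %B[%c0, %c0] ... : tensor<8x16xf32>, vector<8x16xf32>
///   %vc = vector.transfer_read %C[%c0, %c0] ... : tensor<4x16xf32>, vector<4x16xf32>
///   %r  = vector.contract {indexing_maps = [...],
///             iterator_types = ["parallel", "parallel", "reduction"],
///             kind = #vector.kind<add>} %va, %vb, %vc
///         : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   vector.transfer_write %r, %C[%c0, %c0] ... : vector<4x16xf32>, tensor<4x16xf32>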
2170 static LogicalResult
2171 vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2172  LinalgOp linalgOp,
2173  SmallVectorImpl<Value> &newResults) {
2174  Location loc = linalgOp.getLoc();
2175  MLIRContext *ctx = linalgOp.getContext();
2176 
2177  // For simplicity, contraction vectorization is limited to linalg named ops.
2178  // Generic op is ignored as not every arbitrary contraction body can be
2179  // expressed by a vector.contract.
2180  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2181  return failure();
2182 
2183  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2184  Operation *reduceOp = matchLinalgReduction(outOperand);
2185  auto maybeKind = getCombinerOpKind(reduceOp);
2186  if (!maybeKind) {
2187  LDBG() << "Failed to determine contraction combining kind.";
2188  return failure();
2189  }
2190 
2191  // Check that all dimensions are present in the input operands.
2192  // Arbitrary broadcasts are not supported by the vector contraction.
2193  // Broadcasts are expected to be decomposed before vectorization.
2194  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2195  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2196  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2197  LDBG() << "Contractions with broadcasts are not supported.";
2198  return failure();
2199  }
2200 
2201  // Load operands.
2202  SmallVector<Value> vecOperands;
2203  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2204  // The operand vector shape is computed by mapping the canonical vector
2205  // shape to the operand's domain. Further permutations are left as a part of
2206  // the contraction.
2207  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2208  AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2209  indexingMap.getNumResults(), rewriter.getContext());
2210  Type elemType = getElementTypeOrSelf(opOperand.get());
2211  VectorType readType =
2212  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2213 
2214  Value read = vector::createReadOrMaskedRead(
2215  rewriter, loc, opOperand.get(), readType.getShape(),
2216  /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2217  /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
2218  vecOperands.push_back(read);
2219  }
2220 
2221  // Remap iterators from linalg to vector.
2222  SmallVector<Attribute> iterAttrs;
2223  auto iterators = linalgOp.getIteratorTypesArray();
2224  for (utils::IteratorType iter : iterators) {
2225  auto vecIter = iter == utils::IteratorType::parallel
2226  ? vector::IteratorType::parallel
2227  : vector::IteratorType::reduction;
2228  iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2229  }
2230 
2231  // Create contraction.
2232  Operation *contractOp = vector::ContractionOp::create(
2233  rewriter, loc, /*lhs=*/vecOperands[0],
2234  /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2235  linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2236  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2237 
2238  // Store result.
2239  Operation *write = createWriteOrMaskedWrite(
2240  rewriter, loc, contractOp->getResult(0), outOperand->get());
2241 
2242  // Finalize.
2243  if (!write->getResults().empty())
2244  newResults.push_back(write->getResult(0));
2245 
2246  return success();
2247 }
2248 
2249 namespace {
2250 enum class ConvOperationKind { Conv, Pool };
2251 } // namespace
2252 
2253 static bool isCastOfBlockArgument(Operation *op) {
2254  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2255  isa<BlockArgument>(op->getOperand(0));
2256 }
2257 
2258 // Returns the ConvOperationKind of the op using reduceOp of the generic
2259 // payload. If it is neither a convolution nor a pooling, it returns
2260 // std::nullopt.
2261 //
2262 // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2263 // + yield) and rhs is not used) then it is the body of a pooling
2264 // If conv, check for single `mul` predecessor. The `mul` operands must be
2265 // block arguments or extension of block arguments.
2266 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2267 // must be block arguments or extension of block arguments.
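//
// As a hypothetical illustration (not from the upstream docs), a convolution
// body matched by this hook typically looks like:
//   %mul = arith.mulf %in, %filter : f32
//   %acc = arith.addf %out, %mul : f32
//   linalg.yield %acc : f32
// whereas a pooling body reduces the input directly, e.g.:
//   %acc = arith.maximumf %out, %in : f32
//   linalg.yield %acc : f32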
2268 static std::optional<ConvOperationKind>
2269 getConvOperationKind(Operation *reduceOp) {
2270  int numBlockArguments =
2271  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2272 
2273  switch (numBlockArguments) {
2274  case 1: {
2275  // Will be convolution if feeder is a MulOp.
2276  // A strength reduced version of MulOp for i1 type is AndOp which is also
2277  // supported. Otherwise, it can be pooling. This strength reduction logic
2278  // is in `buildBinaryFn` helper in the Linalg dialect.
2279  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2280  llvm::IsaPred<BlockArgument>);
2281  assert(feedValIt != reduceOp->operand_end() &&
2282  "Expected a non-block argument operand");
2283  Operation *feedOp = (*feedValIt).getDefiningOp();
2284  if (isCastOfBlockArgument(feedOp)) {
2285  return ConvOperationKind::Pool;
2286  }
2287 
2288  if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2289  (isa<arith::AndIOp>(feedOp) &&
2290  feedOp->getResultTypes()[0].isInteger(1))) &&
2291  llvm::all_of(feedOp->getOperands(), [](Value v) {
2292  if (isa<BlockArgument>(v))
2293  return true;
2294  if (Operation *op = v.getDefiningOp())
2295  return isCastOfBlockArgument(op);
2296  return false;
2297  }))) {
2298  return std::nullopt;
2299  }
2300 
2301  return ConvOperationKind::Conv;
2302  }
2303  case 2:
2304  // Must be pooling
2305  return ConvOperationKind::Pool;
2306  default:
2307  return std::nullopt;
2308  }
2309 }
2310 
2311 static bool isSupportedPoolKind(vector::CombiningKind kind) {
2312  switch (kind) {
2313  case vector::CombiningKind::ADD:
2314  case vector::CombiningKind::MAXNUMF:
2315  case vector::CombiningKind::MAXIMUMF:
2316  case vector::CombiningKind::MAXSI:
2317  case vector::CombiningKind::MAXUI:
2318  case vector::CombiningKind::MINNUMF:
2319  case vector::CombiningKind::MINIMUMF:
2320  case vector::CombiningKind::MINSI:
2321  case vector::CombiningKind::MINUI:
2322  return true;
2323  default:
2324  return false;
2325  }
2326 }
2327 
2328 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2329  auto getOperandType = [&](auto operand) {
2330  return dyn_cast<ShapedType>((operand->get()).getType());
2331  };
2332  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2333  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2334  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2335  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2336  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2337  // Note that this also ensures 2D and 3D convolutions are rejected.
2338  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2339  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2340  return failure();
2341 
2342  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2343  if (!reduceOp)
2344  return failure();
2345 
2346  auto maybeOper = getConvOperationKind(reduceOp);
2347  if (!maybeOper.has_value())
2348  return failure();
2349 
2350  auto maybeKind = getCombinerOpKind(reduceOp);
2351  // Typically convolution will have a `Add` CombiningKind but for i1 type it
2352  // can get strength reduced to `OR` which is also supported. This strength
2353  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2354  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2355  *maybeKind != vector::CombiningKind::OR) &&
2356  (*maybeOper != ConvOperationKind::Pool ||
2357  !isSupportedPoolKind(*maybeKind)))) {
2358  return failure();
2359  }
2360 
2361  auto rhsRank = rhsShapedType.getRank();
2362  if (*maybeOper == ConvOperationKind::Pool) {
2363  if (rhsRank != 1)
2364  return failure();
2365  } else {
2366  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2367  return failure();
2368  }
2369 
2370  return success();
2371 }
2372 
2373 static LogicalResult vectorizeLinalgOpPrecondition(
2374  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2375  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2376  // Tensors with a dimension of size 0 cannot be vectorized.
2377  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2378  return llvm::is_contained(linalgOp.getShape(&operand), 0);
2379  }))
2380  return failure();
2381  // Check API contract for input vector sizes.
2382  if (!inputVectorSizes.empty() &&
2383  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2384  inputVectorSizes)))
2385  return failure();
2386 
2387  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2388  linalgOp, flatten1DDepthwiseConv))) {
2389  LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2390  return failure();
2391  }
2392 
2394 
2395  // Register CustomVectorizationPrecondition for extractOp.
2396  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2397 
2398  // All types in the body should be a supported element type for VectorType.
2399  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2400  // Check if any custom hook can vectorize the inner op.
2401  if (llvm::any_of(
2402  customPreconditions,
2403  [&](const CustomVectorizationPrecondition &customPrecondition) {
2404  return succeeded(
2405  customPrecondition(&innerOp, vectorizeNDExtract));
2406  })) {
2407  continue;
2408  }
2409  if (!llvm::all_of(innerOp.getOperandTypes(),
2410  VectorType::isValidElementType)) {
2411  return failure();
2412  }
2413  if (!llvm::all_of(innerOp.getResultTypes(),
2414  VectorType::isValidElementType)) {
2415  return failure();
2416  }
2417  }
2418  if (isElementwise(linalgOp))
2419  return success();
2420 
2421  // TODO: isaConvolutionOpInterface that can also infer from generic
2422  // features. But we will still need stride/dilation attributes that will be
2423  // annoying to reverse-engineer...
2424  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2425  return vectorizeConvOpPrecondition(linalgOp);
2426 
2427  // TODO: the common vector shape is equal to the static loop sizes only when
2428  // all indexing maps are projected permutations. For convs and stencils the
2429  // logic will need to evolve.
2430  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2431  LDBG() << "precondition failed: not projected permutations";
2432  return failure();
2433  }
2434  if (failed(reductionPreconditions(linalgOp))) {
2435  LDBG() << "precondition failed: reduction preconditions";
2436  return failure();
2437  }
2438  return success();
2439 }
2440 
2441 static LogicalResult
2442 vectorizePackOpPrecondition(linalg::PackOp packOp,
2443  ArrayRef<int64_t> inputVectorSizes) {
2444  auto padValue = packOp.getPaddingValue();
2445  Attribute cstAttr;
2446  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2447  LDBG() << "pad value is not constant: " << packOp;
2448  return failure();
2449  }
2450 
2451  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2452  bool satisfyEmptyCond = true;
2453  if (inputVectorSizes.empty()) {
2454  if (!packOp.getDestType().hasStaticShape() ||
2455  !packOp.getSourceType().hasStaticShape())
2456  satisfyEmptyCond = false;
2457  }
2458 
2459  if (!satisfyEmptyCond &&
2460  failed(vector::isValidMaskedInputVector(
2461  resultTensorShape.take_front(packOp.getSourceRank()),
2462  inputVectorSizes)))
2463  return failure();
2464 
2465  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2466  return !getConstantIntValue(v).has_value();
2467  })) {
2468  LDBG() << "inner_tiles must be constant: " << packOp;
2469  return failure();
2470  }
2471 
2472  return success();
2473 }
2474 
2475 static LogicalResult
2476 vectorizePadOpPrecondition(tensor::PadOp padOp,
2477  ArrayRef<int64_t> inputVectorSizes) {
2478  auto padValue = padOp.getConstantPaddingValue();
2479  if (!padValue) {
2480  LDBG() << "pad value is not constant: " << padOp;
2481  return failure();
2482  }
2483 
2484  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2485  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2486  inputVectorSizes)))
2487  return failure();
2488 
2489  // Padding with non-zero low pad values is not supported, unless the
2490  // corresponding result dim is 1 as this would require shifting the results to
2491  // the right for the low padded dims by the required amount of low padding.
2492  // However, we do support low padding if the dims being low padded have result
2493  // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2494  // input size of that dimension will be dynamically zero (as the sum of the
2495  // low pad and input dim size has to be one) and hence we will create a zero
2496  // mask as the lowering logic just makes the mask one for the input dim size -
2497  // which is zero here. Hence we will load the pad value which is what we want
2498  // in this case. If the low pad is dynamically zero then the lowering is
2499  // correct as well as no shifts are necessary.
2500  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2501  Value padValue = en.value();
2502  unsigned pos = en.index();
2503  std::optional<int64_t> pad = getConstantIntValue(padValue);
2504  return (!pad.has_value() || pad.value() != 0) &&
2505  resultTensorShape[pos] != 1;
2506  })) {
2507  LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2508  return failure();
2509  }
2510 
2511  return success();
2512 }
2513 
2514 /// Preconditions for scalable vectors.
2515 ///
2516 /// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2517 /// models the fact that in practice we would only make selected dimensions
2518 /// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2519 /// unconditionally - we are yet to identify meaningful conditions.
2520 static LogicalResult
2521 vectorizeScalableVectorPrecondition(Operation *op,
2522  ArrayRef<int64_t> inputVectorSizes,
2523  ArrayRef<bool> inputScalableVecDims) {
2524  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2525  "Number of input vector sizes and scalable dims doesn't match");
2526 
2527  size_t numOfScalableDims =
2528  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2529 
2530  if (numOfScalableDims == 0)
2531  return success();
2532 
2533  auto linalgOp = dyn_cast<LinalgOp>(op);
2534 
2535  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2536  // exception of UnpackOp for which there is a dedicated hook.
2537  if (!linalgOp) {
2538  return success(isa<linalg::UnPackOp>(op));
2539  }
2540 
2541  // Cond 2: There's been no need for more than 2 scalable dims so far
2542  if (numOfScalableDims > 2)
2543  return failure();
2544 
2545  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2546  // it matches one of the supported cases:
2547  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2548  // (*).
2549  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2550  // parallel dims.
2551  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2552  // dim.
2553  // The 2nd restriction above means that only Matmul-like Ops are supported
2554  // when 2 dims are scalable, e.g. :
2555  // * iterators = [parallel, parallel, reduction]
2556  // * scalable flags = [true, true, false]
2557  //
2558  // (*) Unit dims get folded away in practice.
2559  // TODO: Relax these conditions as good motivating examples are identified.
2560 
2561  // Find the first scalable flag.
2562  bool seenNonUnitParallel = false;
2563  auto iterators = linalgOp.getIteratorTypesArray();
2564  SmallVector<bool> scalableFlags(inputScalableVecDims);
2565  int64_t idx = scalableFlags.size() - 1;
2566  while (!scalableFlags[idx]) {
2567  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2568  seenNonUnitParallel |=
2569  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2570 
2571  iterators.pop_back();
2572  scalableFlags.pop_back();
2573  --idx;
2574  }
2575 
2576  // Analyze the iterator corresponding to the first scalable dim.
2577  switch (iterators.back()) {
2578  case utils::IteratorType::reduction: {
2579  // Check 3. above is met.
2580  if (iterators.size() != inputVectorSizes.size()) {
2581  LDBG() << "Non-trailing reduction dim requested for scalable "
2582  "vectorization";
2583  return failure();
2584  }
2585  if (isa<linalg::MatmulOp>(op)) {
2586  LDBG()
2587  << "Scalable vectorization of the reduction dim in Matmul-like ops "
2588  "is not supported";
2589  return failure();
2590  }
2591  break;
2592  }
2593  case utils::IteratorType::parallel: {
2594  // Check 1. and 2. above are met.
2595  if (seenNonUnitParallel) {
2596  LDBG() << "Inner parallel dim not requested for scalable "
2597  "vectorization";
2598  return failure();
2599  }
2600  break;
2601  }
2602  }
2603 
2604  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2605  // supported, for which we expect the following config:
2606  // * iterators = [parallel, parallel, reduction]
2607  // * scalable flags = [true, true, false]
2608  if (numOfScalableDims == 2) {
2609  // Disallow below case which breaks 3. above:
2610  // * iterators = [..., parallel, reduction]
2611  // * scalable flags = [..., true, true]
2612  if (iterators.back() == utils::IteratorType::reduction) {
2613  LDBG() << "Higher dim than the trailing reduction dim requested for "
2614  "scalable "
2615  "vectorization";
2616  return failure();
2617  }
2618  scalableFlags.pop_back();
2619  iterators.pop_back();
2620 
2621  if (!scalableFlags.back() ||
2622  (iterators.back() != utils::IteratorType::parallel))
2623  return failure();
2624  }
2625 
2626  // Cond 4: Only the following ops are supported in the
2627  // presence of scalable vectors
2628  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2629  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2630  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2631  isa<linalg::BatchMmt4DOp>(op) ||
2632  hasReductionIterator(linalgOp));
2633 }
2634 
2635 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2636  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2637  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2638  bool flatten1DDepthwiseConv) {
2639 
2640  if (!hasVectorizationImpl(op))
2641  return failure();
2642 
2643  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2644  inputScalableVecDims)))
2645  return failure();
2646 
2647  return TypeSwitch<Operation *, LogicalResult>(op)
2648  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2649  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2650  vectorizeNDExtract,
2651  flatten1DDepthwiseConv);
2652  })
2653  .Case<tensor::PadOp>([&](auto padOp) {
2654  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2655  })
2656  .Case<linalg::PackOp>([&](auto packOp) {
2657  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2658  })
2659  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2660  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2661  })
2662  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2663  return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2664  })
2665  .Default([](auto) { return failure(); });
2666 }
2667 
2668 /// Converts affine.apply Ops to arithmetic operations.
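///
/// As a hypothetical example (not from the upstream docs):
///   %r = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
/// is expanded to:
///   %r = arith.addi %i, %j : index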
2669 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2670  OpBuilder::InsertionGuard g(rewriter);
2671  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2672 
2673  for (auto op : make_early_inc_range(toReplace)) {
2674  rewriter.setInsertionPoint(op);
2675  auto expanded = affine::expandAffineExpr(
2676  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2677  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2678  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2679  rewriter.replaceOp(op, expanded);
2680  }
2681 }
2682 
2683 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2684  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2685  tensor::InsertSliceOp>(op);
2686 }
2687 
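// A hypothetical caller-side sketch (illustrative only; `rewriter` and
// `matmulOp` are assumed to exist in the caller): requesting scalable
// vectorization of the trailing parallel dim of a matmul-like op via the
// entry point below.
//
//   SmallVector<int64_t> vecSizes = {8, 16, 4};
//   SmallVector<bool> scalableDims = {false, true, false};
//   FailureOr<VectorizationResult> res = linalg::vectorize(
//       rewriter, matmulOp, vecSizes, scalableDims,
//       /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
//       /*assumeDynamicDimsMatchVecSizes=*/false,
//       /*createNamedContraction=*/false);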
2688 FailureOr<VectorizationResult> mlir::linalg::vectorize(
2689  RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2690  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2691  bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2692  bool createNamedContraction) {
2693  LDBG() << "Attempting to vectorize: " << *op;
2694  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2695  LDBG() << "Input scalable vector dims: "
2696  << llvm::interleaved(inputScalableVecDims);
2697 
2698  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2699  vectorizeNDExtract,
2700  flatten1DDepthwiseConv))) {
2701  LDBG() << "Vectorization pre-conditions failed";
2702  return failure();
2703  }
2704 
2705  // Initialize vectorization state.
2706  VectorizationState state(rewriter);
2707  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2708  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2709  inputScalableVecDims,
2710  assumeDynamicDimsMatchVecSizes))) {
2711  LDBG() << "Vectorization state couldn't be initialized";
2712  return failure();
2713  }
2714  }
2715 
2716  SmallVector<Value> results;
2717  auto vectorizeResult =
2718  TypeSwitch<Operation *, LogicalResult>(op)
2719  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2720  // TODO: isaConvolutionOpInterface that can also infer from
2721  // generic features. Will require stride/dilation attributes
2722  // inference.
2723  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2724  FailureOr<Operation *> convOr = vectorizeConvolution(
2725  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2726  flatten1DDepthwiseConv);
2727  if (succeeded(convOr)) {
2728  llvm::append_range(results, (*convOr)->getResults());
2729  return success();
2730  }
2731 
2732  LDBG() << "Unsupported convolution can't be vectorized.";
2733  return failure();
2734  }
2735 
2736  if (createNamedContraction &&
2737  isa<ContractionOpInterface>(linalgOp.getOperation()))
2738  return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2739  results);
2740 
2741  LDBG()
2742  << "Vectorize generic by broadcasting to the canonical vector "
2743  "shape";
2744 
2745  // Pre-process before proceeding.
2746  convertAffineApply(rewriter, linalgOp);
2747 
2748  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2749  // to 'OpBuilder' when it is passed over to some methods like
2750  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2751  // erase an op within these methods, the actual rewriter won't be
2752  // notified and we will end up with read-after-free issues!
2753  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2754  })
2755  .Case<tensor::PadOp>([&](auto padOp) {
2756  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2757  results);
2758  })
2759  .Case<linalg::PackOp>([&](auto packOp) {
2760  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2761  results);
2762  })
2763  .Case<linalg::UnPackOp>([&](auto unpackOp) {
2764  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2765  inputVectorSizes,
2766  inputScalableVecDims, results);
2767  })
2768  .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2769  return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2770  results);
2771  })
2772  .Default([](auto) { return failure(); });
2773 
2774  if (failed(vectorizeResult)) {
2775  LDBG() << "Vectorization failed";
2776  return failure();
2777  }
2778 
2779  return VectorizationResult{results};
2780 }
2781 
2782 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2783  memref::CopyOp copyOp) {
2784  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2785  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2786  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2787  return failure();
2788 
2789  auto srcElementType = getElementTypeOrSelf(srcType);
2790  auto dstElementType = getElementTypeOrSelf(dstType);
2791  if (!VectorType::isValidElementType(srcElementType) ||
2792  !VectorType::isValidElementType(dstElementType))
2793  return failure();
2794 
2795  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2796  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2797 
2798  Location loc = copyOp->getLoc();
2799  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2800  SmallVector<Value> indices(srcType.getRank(), zero);
2801 
2802  Value readValue = vector::TransferReadOp::create(
2803  rewriter, loc, readType, copyOp.getSource(), indices,
2804  /*padding=*/std::nullopt,
2805  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2806  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2807  readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2808  ArrayRef<int64_t>());
2809  readValue =
2810  vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2811  }
2812  Operation *writeValue = vector::TransferWriteOp::create(
2813  rewriter, loc, readValue, copyOp.getTarget(), indices,
2814  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2815  rewriter.replaceOp(copyOp, writeValue->getResults());
2816  return success();
2817 }
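// Illustrative example of the rewrite performed above; the shapes and value
// names are made up for exposition and are not taken from a real test:
//
//   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
//
// becomes, roughly:
//
//   %c0 = arith.constant 0 : index
//   %v  = vector.transfer_read %src[%c0, %c0], %pad
//           : memref<4x8xf32>, vector<4x8xf32>
//   vector.transfer_write %v, %dst[%c0, %c0]
//           : vector<4x8xf32>, memref<4x8xf32>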
2818 
2819 //----------------------------------------------------------------------------//
2820 // Misc. vectorization patterns.
2821 //----------------------------------------------------------------------------//
2822 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2823 /// given operation type OpTy.
2824 template <typename OpTy>
2825 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2826   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2827 
2828  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2829  PatternRewriter &rewriter) const final {
2830  bool changed = false;
2831  // Collect the users in a vector first, because some of them may be replaced/removed.
2832  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2833  if (auto op = dyn_cast<OpTy>(user))
2834  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2835  return success(changed);
2836  }
2837 
2838 protected:
2839  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2840  tensor::PadOp padOp, OpTy op) const = 0;
2841 };
2842 
2843 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2844 /// ```
2845 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2846 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2847 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2848 /// ```
2849 /// is rewritten to:
2850 /// ```
2851 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2852 /// {in_bounds = [true, true]}
2853 /// : tensor<?x?xf32>, vector<17x5xf32>
2854 /// ```
2855 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2856 /// sure that the original padding value %cst was never used.
2857 ///
2858 /// This rewrite is possible if:
2859 /// - `xferOp` has no out-of-bounds dims or mask.
2860 /// - Low padding is static 0.
2861 /// - Single, scalar padding value.
2862 struct PadOpVectorizationWithTransferReadPattern
2863     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2864   using VectorizePadOpUserPattern<
2865       vector::TransferReadOp>::VectorizePadOpUserPattern;
2866 
2867  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2868  vector::TransferReadOp xferOp) const override {
2869  // Low padding must be static 0.
2870  if (!padOp.hasZeroLowPad())
2871  return failure();
2872  // Pad value must be a constant.
2873  auto padValue = padOp.getConstantPaddingValue();
2874  if (!padValue)
2875  return failure();
2876  // Padding value of existing `xferOp` is unused.
2877  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2878  return failure();
2879 
2880  rewriter.modifyOpInPlace(xferOp, [&]() {
2881  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2882  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2883  rewriter.getBoolArrayAttr(inBounds));
2884  xferOp.getBaseMutable().assign(padOp.getSource());
2885  xferOp.getPaddingMutable().assign(padValue);
2886  });
2887 
2888  return success();
2889  }
2890 };
2891 
2892 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2893 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2894 /// value, where the same amount of padding is immediately removed again after
2895 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2896 /// tensor value and apply out-of-bounds masking. E.g.:
2897 /// ```
2898 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2899 /// : tensor<...> to tensor<?x?xf32>
2900 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2901 /// %2 = vector.transfer_write %vec, %1[...]
2902 /// : vector<17x5xf32>, tensor<17x5xf32>
2903 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2904 /// : tensor<17x5xf32> to tensor<?x?xf32>
2905 /// ```
2906 /// is rewritten to:
2907 /// ```
2908 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2909 /// : tensor<...> to tensor<?x?xf32>
2910 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2911 /// tensor<?x?xf32>
2912 /// ```
2913 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2914 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2915 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2916 /// from %r's old dimensions.
2917 ///
2918 /// This rewrite is possible if:
2919 /// - Low padding is static 0.
2920 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2921 /// ExtractSliceOp trims the same amount of padding that was added
2922 /// beforehand.
2923 /// - Single, scalar padding value.
2924 struct PadOpVectorizationWithTransferWritePattern
2925     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2926   using VectorizePadOpUserPattern<
2927       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2928 
2929  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2930  vector::TransferWriteOp xferOp) const override {
2931  // TODO: support 0-d corner case.
2932  if (xferOp.getTransferRank() == 0)
2933  return failure();
2934 
2935  // Low padding must be static 0.
2936  if (!padOp.hasZeroLowPad())
2937  return failure();
2938  // Pad value must be a constant.
2939  auto padValue = padOp.getConstantPaddingValue();
2940  if (!padValue)
2941  return failure();
2942  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2943  if (!xferOp->hasOneUse())
2944  return failure();
2945  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2946  if (!trimPadding)
2947  return failure();
2948  // Only static zero offsets supported when trimming padding.
2949  if (!trimPadding.hasZeroOffset())
2950  return failure();
2951  // trimPadding must remove the amount of padding that was added earlier.
2952  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2953  return failure();
2954 
2955  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2956  rewriter.setInsertionPoint(xferOp);
2957 
2958  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2959  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2960  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2961  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2962  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2963  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2964 
2965  return success();
2966  }
2967 
2968  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2969  /// i.e., same dimensions.
2970  ///
2971  /// Dimensions may be static, dynamic, or a mix of both. In case of dynamic
2972  /// dimensions, this function tries to infer the (static) tensor size by
2973  /// looking at the defining op and utilizing op-specific knowledge.
2974  ///
2975  /// This is a conservative analysis. In case equal tensor sizes cannot be
2976  /// proven statically, this analysis returns `false` even though the tensor
2977  /// sizes may turn out to be equal at runtime.
2978  bool hasSameTensorSize(Value beforePadding,
2979  tensor::ExtractSliceOp afterTrimming) const {
2980  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2981  // result and CastOp operand.
2982  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2983  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2984  return true;
2985 
2986  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2987  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2988  // Only RankedTensorType supported.
2989  if (!t1 || !t2)
2990  return false;
2991  // Rank of both values must be the same.
2992  if (t1.getRank() != t2.getRank())
2993  return false;
2994 
2995  // All static dimensions must be the same. Mixed cases (e.g., dimension
2996  // static in `t1` but dynamic in `t2`) are not supported.
2997  for (unsigned i = 0; i < t1.getRank(); ++i) {
2998  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2999  return false;
3000  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
3001  return false;
3002  }
3003 
3004  // Nothing more to check if all dimensions are static.
3005  if (t1.getNumDynamicDims() == 0)
3006  return true;
3007 
3008  // All dynamic sizes must be the same. The only supported case at the
3009  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
3010  // thereof).
3011 
3012  // Apart from CastOp, only ExtractSliceOp is supported.
3013  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
3014  if (!beforeSlice)
3015  return false;
3016 
3017  assert(static_cast<size_t>(t1.getRank()) ==
3018  beforeSlice.getMixedSizes().size());
3019  assert(static_cast<size_t>(t2.getRank()) ==
3020  afterTrimming.getMixedSizes().size());
3021 
3022  for (unsigned i = 0; i < t1.getRank(); ++i) {
3023  // Skip static dimensions.
3024  if (!t1.isDynamicDim(i))
3025  continue;
3026  auto size1 = beforeSlice.getMixedSizes()[i];
3027  auto size2 = afterTrimming.getMixedSizes()[i];
3028 
3029  // Case 1: Same value or same constant int.
3030  if (isEqualConstantIntOrValue(size1, size2))
3031  continue;
3032 
3033  // Other cases: Take a deeper look at defining ops of values.
3034  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3035  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3036  if (!v1 || !v2)
3037  return false;
3038 
3039  // Case 2: Both values are identical AffineMinOps. (Should not happen if
3040  // CSE is run.)
3041  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3042  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3043  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3044  minOp1.getOperands() == minOp2.getOperands())
3045  continue;
3046 
3047  // Add additional cases as needed.
3048  }
3049 
3050  // All tests passed.
3051  return true;
3052  }
3053 };
3054 
3055 /// Returns the effective Pad value for the input op, provided it's a scalar.
3056 ///
3057 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3058 /// this Op performs padding, retrieve the padding value provided that it's
3059 /// a scalar and static/fixed for all the padded values. Returns an empty value
3060 /// otherwise.
3061 ///
3062 /// TODO: This is used twice (when checking vectorization pre-conditions and
3063 /// when vectorizing). Cache results instead of re-running.
3064 static Value getStaticPadVal(Operation *op) {
3065   if (!op)
3066  return {};
3067 
3068  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
3069  // being broadcast, provided that it's a scalar.
3070  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3071  auto source = bcast.getSource();
3072  if (llvm::dyn_cast<VectorType>(source.getType()))
3073  return {};
3074 
3075  return source;
3076  }
3077 
3078  // 2. linalg.fill - use the scalar input value that used to fill the output
3079  // tensor.
3080  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3081  return fill.getInputs()[0];
3082  }
3083 
3084  // 3. tensor.generateOp - can't guarantee the value is fixed without
3085  // analysing, bail out.
3086  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3087  return {};
3088  }
3089 
3090  // 4. vector.transfer_write - inspect the vector that is being written. If
3091  // it contains a single value that has been broadcast (e.g. via
3092  // vector.broadcast), extract it, fail otherwise.
3093  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3094  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3095 
3096  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3097  // than the input tensor, then, provided it's constant, we'll extract the
3098  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3099  // TODO: Clarify the semantics when the input tensor is larger than the
3100  // destination.
3101  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3102  return getStaticPadVal(slice.getDest().getDefiningOp());
3103 
3104  return {};
3105 }
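// Illustrative example (hypothetical IR): when an insert_slice destination is
// produced by linalg.fill, the fill's scalar input is the effective pad value:
//
//   %filled = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>)
//               -> tensor<8x8xf32>
//   %slice  = tensor.insert_slice %src into %filled[0, 0] [2, 2] [1, 1]
//               : tensor<2x2xf32> into tensor<8x8xf32>
//
// Calling getStaticPadVal on the insert_slice follows its destination to the
// linalg.fill (case 5 -> case 2 above) and returns %cst.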
3106 
3107 static LogicalResult
3108 vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3109  ArrayRef<int64_t> inputVectorSizes,
3110  SmallVectorImpl<Value> &newResults) {
3111  // TODO: Introduce a parent class that will handle the insertion point update.
3112  OpBuilder::InsertionGuard g(rewriter);
3113  rewriter.setInsertionPoint(sliceOp);
3114 
3115  TypedValue<RankedTensorType> source = sliceOp.getSource();
3116  auto sourceType = source.getType();
3117  auto resultType = sliceOp.getResultType();
3118 
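  // 1. Get the pad value. Use the op's static pad value if one can be
  // inferred, otherwise pad with zero.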
3119  Value padValue = getStaticPadVal(sliceOp);
3120 
3121  if (!padValue) {
3122  auto elemType = sourceType.getElementType();
3123  padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3124  rewriter.getZeroAttr(elemType));
3125  }
3126 
3127  // 2. Get the vector shape
3128  SmallVector<int64_t> vecShape;
3129  size_t rankDiff = resultType.getRank() - sourceType.getRank();
3130  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3131  if (!inputVectorSizes.empty()) {
3132  vecShape.push_back(inputVectorSizes[i]);
3133  } else if (!sourceType.isDynamicDim(i)) {
3134  vecShape.push_back(sourceType.getDimSize(i));
3135  } else if (!resultType.isDynamicDim(i)) {
3136  // Source shape is not statically known, but result shape is.
3137  // Vectorize with size of result shape. This may be larger than the
3138  // source size.
3139  // FIXME: Using rankDiff implies that the source tensor is inserted at
3140  // the end of the destination tensor. However, that's not required.
3141  vecShape.push_back(resultType.getDimSize(rankDiff + i));
3142  } else {
3143  // Neither the source nor the result dim is static. Cannot vectorize
3144  // the copy.
3145  return failure();
3146  }
3147  }
3148  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3149 
3150  // 3. Generate TransferReadOp + TransferWriteOp
3151  auto loc = sliceOp.getLoc();
3152 
3153  // Create read
3154  SmallVector<Value> readIndices(
3155  vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3156   Value read = vector::createReadOrMaskedRead(
3157       rewriter, loc, source, vecType.getShape(), padValue,
3158  /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
3159  /*inputScalableVecSizes=*/{});
3160 
3161  // Create write
3162  auto writeIndices =
3163  getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3164  Operation *write =
3165  createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3166  writeIndices, inputVectorSizes.empty());
3167 
3168  // 4. Finalize
3169  newResults.push_back(write->getResult(0));
3170 
3171  return success();
3172 }
3173 
3174 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3175 /// ```
3176 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3177 /// %r = tensor.insert_slice %0
3178 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3179 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3180 /// ```
3181 /// is rewritten to:
3182 /// ```
3183 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
3184 /// : tensor<?x?xf32>, vector<17x5xf32>
3185 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3186 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3187 /// ```
3188 ///
3189 /// This rewrite is possible if:
3190 /// - Low padding is static 0.
3191 /// - `padOp` result shape is static.
3192 /// - The entire padded tensor is inserted.
3193 /// (Implies that sizes of `insertOp` are all static.)
3194 /// - Only unit strides in `insertOp`.
3195 /// - Single, scalar padding value.
3196 /// - `padOp` result not used as destination.
3197 struct PadOpVectorizationWithInsertSlicePattern
3198     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3199   using VectorizePadOpUserPattern<
3200       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3201 
3202  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3203  tensor::InsertSliceOp insertOp) const override {
3204  // Low padding must be static 0.
3205  if (!padOp.hasZeroLowPad())
3206  return failure();
3207  // Only unit stride supported.
3208  if (!insertOp.hasUnitStride())
3209  return failure();
3210  // Pad value must be a constant.
3211  auto padValue = padOp.getConstantPaddingValue();
3212  if (!padValue)
3213  return failure();
3214  // Dynamic shapes not supported.
3215  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3216  return failure();
3217  // Pad result not used as destination.
3218  if (insertOp.getDest() == padOp.getResult())
3219  return failure();
3220 
3221  auto vecType = VectorType::get(padOp.getType().getShape(),
3222  padOp.getType().getElementType());
3223  unsigned vecRank = vecType.getRank();
3224  unsigned tensorRank = insertOp.getType().getRank();
3225 
3226  // Check if sizes match: Insert the entire tensor into most minor dims.
3227  // (No permutations allowed.)
3228  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3229  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3230  if (!llvm::all_of(
3231  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3232  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3233  }))
3234  return failure();
3235 
3236  // Insert the TransferReadOp and TransferWriteOp at the position of the
3237  // InsertSliceOp.
3238  rewriter.setInsertionPoint(insertOp);
3239 
3240  // Generate TransferReadOp: Read entire source tensor and add high
3241  // padding.
3242  SmallVector<Value> readIndices(
3243  vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3244  auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3245  vecType, padOp.getSource(),
3246  readIndices, padValue);
3247 
3248  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3249  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3250  // source must fit into the destination at the specified offsets.
3251  auto writeIndices = getValueOrCreateConstantIndexOp(
3252  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3253  SmallVector<bool> inBounds(vecRank, true);
3254  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3255  insertOp, read, insertOp.getDest(), writeIndices,
3256  ArrayRef<bool>{inBounds});
3257 
3258  return success();
3259  }
3260 };
3261 
3262 void mlir::linalg::populatePadOpVectorizationPatterns(
3263     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3264   patterns.add<PadOpVectorizationWithTransferReadPattern,
3265                PadOpVectorizationWithTransferWritePattern,
3266                PadOpVectorizationWithInsertSlicePattern>(
3267       patterns.getContext(), baseBenefit.getBenefit() + 1);
3268 }
3269 
3270 //----------------------------------------------------------------------------//
3271 // Forwarding patterns
3272 //----------------------------------------------------------------------------//
3273 
3274 /// Check whether there is any interleaved use of any `values` between
3275 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3276 /// is in a different block.
3277 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3278  ValueRange values) {
3279  if (firstOp->getBlock() != secondOp->getBlock() ||
3280  !firstOp->isBeforeInBlock(secondOp)) {
3281  LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3282  << ", second op: " << *secondOp;
3283  return true;
3284  }
3285  for (auto v : values) {
3286  for (auto &u : v.getUses()) {
3287  Operation *owner = u.getOwner();
3288  if (owner == firstOp || owner == secondOp)
3289  continue;
3290  // TODO: this is too conservative, use dominance info in the future.
3291  if (owner->getBlock() == firstOp->getBlock() &&
3292  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3293  continue;
3294  LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3295  << ", second op: " << *secondOp;
3296  return true;
3297  }
3298  }
3299  return false;
3300 }
3301 
3302 /// Return the unique subview use of `v` if it is indeed unique, null
3303 /// otherwise.
3304 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3305  memref::SubViewOp subViewOp;
3306  for (auto &u : v.getUses()) {
3307  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3308  if (subViewOp)
3309  return memref::SubViewOp();
3310  subViewOp = newSubViewOp;
3311  }
3312  }
3313  return subViewOp;
3314 }
3315 
3316 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3317 /// when available.
3318 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3319     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3320 
3321  // TODO: support mask.
3322  if (xferOp.getMask())
3323  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3324 
3325  // Transfer into `view`.
3326  Value viewOrAlloc = xferOp.getBase();
3327  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3328  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3329  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3330 
3331  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3332  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3333  if (!subViewOp)
3334  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3335  Value subView = subViewOp.getResult();
3336 
3337  // Find the copy into `subView` without interleaved uses.
3338  memref::CopyOp copyOp;
3339  for (auto &u : subView.getUses()) {
3340  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3341  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3342  if (newCopyOp.getTarget() != subView)
3343  continue;
3344  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3345  continue;
3346  copyOp = newCopyOp;
3347  break;
3348  }
3349  }
3350  if (!copyOp)
3351  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3352 
3353  // Find the fill into `viewOrAlloc` without interleaved uses before the
3354  // copy.
3355  FillOp maybeFillOp;
3356  for (auto &u : viewOrAlloc.getUses()) {
3357  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3358  assert(isa<MemRefType>(newFillOp.output().getType()));
3359  if (newFillOp.output() != viewOrAlloc)
3360  continue;
3361  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3362  continue;
3363  maybeFillOp = newFillOp;
3364  break;
3365  }
3366  }
3367  // Ensure padding matches.
3368  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3369  return rewriter.notifyMatchFailure(xferOp,
3370  "padding value does not match fill");
3371 
3372  // `in` is the value that memref.copy reads from; read from it directly instead.
3373  Value in = copyOp.getSource();
3374 
3375  // memref.copy + linalg.fill can be used to create a padded local buffer.
3376  // The `masked` attribute is only valid on this padded buffer.
3377  // When forwarding to vector.transfer_read, the attribute must be reset
3378  // conservatively.
3379  auto vectorType = xferOp.getVectorType();
3380  Value res = vector::TransferReadOp::create(
3381  rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3382  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3383  rewriter.getBoolArrayAttr(
3384  SmallVector<bool>(vectorType.getRank(), false)));
3385 
3386  if (maybeFillOp)
3387  rewriter.eraseOp(maybeFillOp);
3388  rewriter.eraseOp(copyOp);
3389  rewriter.replaceOp(xferOp, res);
3390 
3391  return success();
3392 }
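// Illustrative shape of the forwarding above (hypothetical IR); the fill and
// copy that materialize a padded local buffer are folded away and the read is
// redirected to the copy source, with in_bounds conservatively reset:
//
//   %sv = memref.subview %alloc[...]
//   linalg.fill ins(%cst : f32) outs(%alloc : memref<?x?xf32>)
//   memref.copy %in, %sv
//   %v = vector.transfer_read %alloc[%c0, %c0], %cst
//          : memref<?x?xf32>, vector<4x8xf32>
//
// becomes, roughly:
//
//   %v = vector.transfer_read %in[%c0, %c0], %cst {in_bounds = [false, false]}
//          : memref<?x?xf32>, vector<4x8xf32>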
3393 
3394 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3395 /// when available.
3396 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3397     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3398  // TODO: support mask.
3399  if (xferOp.getMask())
3400  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3401 
3402  // Transfer into `viewOrAlloc`.
3403  Value viewOrAlloc = xferOp.getBase();
3404  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3405  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3406  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3407 
3408  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3409  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3410  if (!subViewOp)
3411  return rewriter.notifyMatchFailure(xferOp, "no subview found");
3412  Value subView = subViewOp.getResult();
3413 
3414  // Find the copy from `subView` without interleaved uses.
3415  memref::CopyOp copyOp;
3416  for (auto &u : subViewOp.getResult().getUses()) {
3417  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3418  if (newCopyOp.getSource() != subView)
3419  continue;
3420  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3421  continue;
3422  copyOp = newCopyOp;
3423  break;
3424  }
3425  }
3426  if (!copyOp)
3427  return rewriter.notifyMatchFailure(xferOp, "no copy found");
3428 
3429  // `out` is the buffer that the copy writes into; we write there directly instead.
3430  assert(isa<MemRefType>(copyOp.getTarget().getType()));
3431  Value out = copyOp.getTarget();
3432 
3433  // Forward vector.transfer into copy.
3434  // memref.copy + linalg.fill can be used to create a padded local buffer.
3435  // The `masked` attribute is only valid on this padded buffer.
3436  // When forwarding to vector.transfer_write, the attribute must be reset
3437  // conservatively.
3438  auto vector = xferOp.getVector();
3439  vector::TransferWriteOp::create(
3440  rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3441  xferOp.getPermutationMapAttr(), xferOp.getMask(),
3442       rewriter.getBoolArrayAttr(SmallVector<bool>(
3443           dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3444 
3445  rewriter.eraseOp(copyOp);
3446  rewriter.eraseOp(xferOp);
3447 
3448  return success();
3449 }
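// Illustrative shape of the symmetric write forwarding (hypothetical IR):
//
//   %sv = memref.subview %alloc[...]
//   vector.transfer_write %vec, %alloc[%c0, %c0]
//     : vector<4x8xf32>, memref<?x?xf32>
//   memref.copy %sv, %out
//
// becomes, roughly:
//
//   vector.transfer_write %vec, %out[%c0, %c0] {in_bounds = [false, false]}
//     : vector<4x8xf32>, memref<?x?xf32>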
3450 
3451 //===----------------------------------------------------------------------===//
3452 // Convolution vectorization patterns
3453 //===----------------------------------------------------------------------===//
3454 
3455 template <int N>
3456 static void bindShapeDims(ShapedType shapedType) {}
3457 
3458 template <int N, typename IntTy, typename... IntTy2>
3459 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3460  val = shapedType.getShape()[N];
3461  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3462 }
3463 
3464 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3465 template <typename... IntTy>
3466 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3467  bindShapeDims<0>(shapedType, vals...);
3468 }
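// Example usage (as done for the convolution shapes below):
//   int64_t n, w, f;
//   bindShapeDims(resShapedType, n, w, f); // n = dim 0, w = dim 1, f = dim 2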
3469 
3470 namespace {
3471 /// Generate a vector implementation for either:
3472 /// ```
3473 /// Op def: ( w, kw )
3474 /// Iters: ({Par(), Red()})
3475 /// Layout: {{w + kw}, {kw}, {w}}
3476 /// ```
3477 /// kw is unrolled.
3478 ///
3479 /// or
3480 ///
3481 /// ```
3482 /// Op def: ( n, w, c, kw, f )
3483 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3484 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3485 /// ```
3486 /// kw is unrolled, w is unrolled iff dilationW > 1.
3487 ///
3488 /// or
3489 ///
3490 /// ```
3491 /// Op def: ( n, c, w, f, kw )
3492 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3493 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3494 /// ```
3495 /// kw is unrolled, w is unrolled iff dilationW > 1.
3496 ///
3497 /// or
3498 ///
3499 /// ```
3500 /// Op def: ( n, w, c, kw )
3501 /// Iters: ({Par(), Par(), Par(), Red()})
3502 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3503 /// ```
3504 /// kw is unrolled, w is unrolled iff dilationW > 1.
3505 struct Conv1DGenerator
3506  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3507  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3508  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3509 
3510  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3511  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3512  resShaped = linalgOp.getDpsInitOperand(0)->get();
3513  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3514  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3515  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3516 
3517  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3518  redOp = reduceOp->getName().getIdentifier();
3519 
3520  setConvOperationKind(reduceOp);
3521 
3522  auto maybeKind = getCombinerOpKind(reduceOp);
3523  reductionKind = maybeKind.value();
3524 
3525  // The ConvolutionOpInterface gives us guarantees of existence for
3526  // strides/dilations. However, we do not need to rely on those; we can
3527  // simply use them if present, otherwise use the default and let the generic
3528  // conv matcher in the Conv1DGenerator succeed or fail.
3529  auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3530  auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3531  strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3532  dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3533  }
3534 
3535  /// Generate a vector implementation for:
3536  /// ```
3537  /// Op def: ( w, kw )
3538  /// Iters: ({Par(), Red()})
3539  /// Layout: {{w + kw}, {kw}, {w}}
3540  /// ```
3541  /// kw is always unrolled.
3542  ///
3543  /// or
3544  ///
3545  /// ```
3546  /// Op def: ( n, w, c, kw, f )
3547  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3548  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3549  /// ```
3550  /// kw is always unrolled.
3551  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3552  /// > 1.
3553  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3554  int64_t nSize, wSize, cSize, kwSize, fSize;
3555  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3556  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3557  switch (conv1DOpOrder) {
3558  case Conv1DOpOrder::W:
3559  // Initialize unused dimensions
3560  nSize = fSize = cSize = 0;
3561  // out{W}
3562  bindShapeDims(resShapedType, wSize);
3563  // kernel{kw}
3564  bindShapeDims(rhsShapedType, kwSize);
3565  lhsShape = {// iw = ow + kw - 1
3566  // (i.e. 16 convolved with 3 -> 14)
3567  (wSize + kwSize - 1)};
3568  rhsShape = {kwSize};
3569  resShape = {wSize};
3570  break;
3571  case Conv1DOpOrder::Nwc:
3572  // out{n, w, f}
3573  bindShapeDims(resShapedType, nSize, wSize, fSize);
3574  switch (oper) {
3575  case ConvOperationKind::Conv:
3576  // kernel{kw, c, f}
3577  bindShapeDims(rhsShapedType, kwSize, cSize);
3578  break;
3579  case ConvOperationKind::Pool:
3580  // kernel{kw}
3581  bindShapeDims(rhsShapedType, kwSize);
3582  cSize = fSize;
3583  break;
3584  }
3585  lhsShape = {nSize,
3586  // iw = ow * sw + kw * dw - 1
3587  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3588  // Perform the proper inclusive -> exclusive -> inclusive.
3589  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3590  1,
3591  cSize};
3592  switch (oper) {
3593  case ConvOperationKind::Conv:
3594  rhsShape = {kwSize, cSize, fSize};
3595  break;
3596  case ConvOperationKind::Pool:
3597  rhsShape = {kwSize};
3598  break;
3599  }
3600  resShape = {nSize, wSize, fSize};
3601  break;
3602  case Conv1DOpOrder::Ncw:
3603  // out{n, f, w}
3604  bindShapeDims(resShapedType, nSize, fSize, wSize);
3605  switch (oper) {
3606  case ConvOperationKind::Conv:
3607  // kernel{f, c, kw}
3608  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3609  break;
3610  case ConvOperationKind::Pool:
3611  // kernel{kw}
3612  bindShapeDims(rhsShapedType, kwSize);
3613  cSize = fSize;
3614  break;
3615  }
3616  lhsShape = {nSize, cSize,
3617  // iw = ow * sw + kw * dw - 1
3618  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3619  // Perform the proper inclusive -> exclusive -> inclusive.
3620  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3621  1};
3622  switch (oper) {
3623  case ConvOperationKind::Conv:
3624  rhsShape = {fSize, cSize, kwSize};
3625  break;
3626  case ConvOperationKind::Pool:
3627  rhsShape = {kwSize};
3628  break;
3629  }
3630  resShape = {nSize, fSize, wSize};
3631  break;
3632  }
3633 
3634  vector::TransferWriteOp write;
3635  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3636 
3637  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3638  // When strideW == 1, we can batch the contiguous loads and avoid
3639  // unrolling
3640  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3641 
3642  Type lhsEltType = lhsShapedType.getElementType();
3643  Type rhsEltType = rhsShapedType.getElementType();
3644  Type resEltType = resShapedType.getElementType();
3645  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3646  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3647  auto resType = VectorType::get(resShape, resEltType);
3648  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3649  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3650  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3651  SmallVector<Value> resPadding(resShape.size(), zero);
3652 
3653  // Read the whole lhs, rhs and res in one shot (with zero padding).
3654  Value lhs = vector::TransferReadOp::create(
3655  rewriter, loc, lhsType, lhsShaped, lhsPadding,
3656  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3657  // This is needed only for Conv.
3658  Value rhs = nullptr;
3659  if (oper == ConvOperationKind::Conv)
3660  rhs = vector::TransferReadOp::create(
3661  rewriter, loc, rhsType, rhsShaped, rhsPadding,
3662  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3663  Value res = vector::TransferReadOp::create(
3664  rewriter, loc, resType, resShaped, resPadding,
3665  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3666 
3667  // The base vectorization case for channeled convolution is input:
3668  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3669  // vectorization case, we do pre transpose on input, weight, and output.
3670  switch (conv1DOpOrder) {
3671  case Conv1DOpOrder::W:
3672  case Conv1DOpOrder::Nwc:
3673  // Base case, so no transposes necessary.
3674  break;
3675  case Conv1DOpOrder::Ncw: {
3676  // To match base vectorization case, we pre-transpose current case.
3677  // ncw -> nwc
3678  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3679  lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3680  // fcw -> wcf
3681  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3682 
3683  // This is needed only for Conv.
3684  if (oper == ConvOperationKind::Conv)
3685  rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3686  // nfw -> nwf
3687  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3688  res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3689  break;
3690  }
3691  }
3692 
3693  //===------------------------------------------------------------------===//
3694  // Begin vector-only rewrite part
3695  //===------------------------------------------------------------------===//
3696  // Unroll along kw and read slices of lhs and rhs.
3697  SmallVector<Value> lhsVals, rhsVals, resVals;
3698  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3699  kwSize, strideW, dilationW, wSizeStep,
3700  isSingleChanneled);
3701  // Do not do for pooling.
3702  if (oper == ConvOperationKind::Conv)
3703  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3704  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3705  wSizeStep, isSingleChanneled);
3706 
3707  auto linearIndex = [&](int64_t kw, int64_t w) {
3708  return kw * (wSize / wSizeStep) + w;
3709  };
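  // For example: when strideW == 1 (wSizeStep == wSize) there is a single w
  // slice per kw and linearIndex(kw, 0) == kw; when strideW > 1
  // (wSizeStep == 1) the slices are laid out kw-major and
  // linearIndex(kw, w) == kw * wSize + w.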
3710 
3711  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3712  // or perform outerproduct for non-channeled convolution or perform simple
3713  // arith operation for pooling
3714  for (int64_t kw = 0; kw < kwSize; ++kw) {
3715  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3716  switch (oper) {
3717  case ConvOperationKind::Conv:
3718  if (isSingleChanneled) {
3719  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3720  lhsVals[linearIndex(kw, w)],
3721  rhsVals[kw], resVals[w]);
3722  } else {
3723  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3724  lhsVals[linearIndex(kw, w)],
3725  rhsVals[kw], resVals[w]);
3726  }
3727  break;
3728  case ConvOperationKind::Pool:
3729  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3730  resVals[w]);
3731  break;
3732  }
3733  }
3734  }
3735 
3736  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3737  isSingleChanneled);
3738  //===------------------------------------------------------------------===//
3739  // End vector-only rewrite part
3740  //===------------------------------------------------------------------===//
3741 
3742  // The base vectorization case for channeled convolution is output:
3743  // {n,w,f} To reuse the result from base pattern vectorization case, we
3744  // post transpose the base case result.
3745  switch (conv1DOpOrder) {
3746  case Conv1DOpOrder::W:
3747  case Conv1DOpOrder::Nwc:
3748  // Base case, so no transposes necessary.
3749  break;
3750  case Conv1DOpOrder::Ncw: {
3751  // nwf -> nfw
3752  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3753  res = vector::TransposeOp::create(rewriter, loc, res, perm);
3754  break;
3755  }
3756  }
3757 
3758  return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3759  resPadding)
3760  .getOperation();
3761  }
3762 
3763  // Take a value and widen it to have the same element type as `ty`.
3764  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3765  const Type srcElementType = getElementTypeOrSelf(val.getType());
3766  const Type dstElementType = getElementTypeOrSelf(ty);
3767  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3768  if (srcElementType == dstElementType)
3769  return val;
3770 
3771  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3772  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3773  const Type dstType =
3774  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3775 
3776  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3777  return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3778  }
3779 
3780  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3781  srcWidth < dstWidth)
3782  return arith::ExtFOp::create(rewriter, loc, dstType, val);
3783 
3784  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3785  srcWidth < dstWidth)
3786  return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3787 
3788  assert(false && "unhandled promotion case");
3789  return nullptr;
3790  }
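  // For example: i8 -> f32 promotes via arith.sitofp, i8 -> i32 via
  // arith.extsi, and f16 -> f32 via arith.extf; matching element types are
  // returned unchanged.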
3791 
3792  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3793  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3794  Value lhs, Value rhs, Value res) {
3795  vector::IteratorType par = vector::IteratorType::parallel;
3796  vector::IteratorType red = vector::IteratorType::reduction;
3797  AffineExpr n, w, f, c;
3798  bindDims(ctx, n, w, f, c);
3799  lhs = promote(rewriter, loc, lhs, res.getType());
3800  rhs = promote(rewriter, loc, rhs, res.getType());
3801  auto contrationOp = vector::ContractionOp::create(
3802  rewriter, loc, lhs, rhs, res,
3803  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3804  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3805  contrationOp.setKind(reductionKind);
3806  return contrationOp;
3807  }
3808 
3809  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3810  // convolution.
3811  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3812  Value lhs, Value rhs, Value res) {
3813  return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3814  rhs, res, vector::CombiningKind::ADD);
3815  }
3816 
3817  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3818  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3819  Value res) {
3820  if (isPoolExt)
3821  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3822  return rewriter
3823  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3824  ->getResult(0);
3825  }
3826 
3827  /// Generate a vector implementation for:
3828  /// ```
3829  /// Op def: ( n, w, c, kw)
3830  /// Iters: ({Par(), Par(), Par(), Red()})
3831  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3832  /// ```
3833  /// kw is always unrolled.
3834  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3835  /// > 1.
3836  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3837  bool channelDimScalableFlag,
3838  bool flatten) {
3839  bool scalableChDim = false;
3840  bool useMasking = false;
3841  int64_t nSize, wSize, cSize, kwSize;
3842  // kernel{kw, c}
3843  bindShapeDims(rhsShapedType, kwSize, cSize);
3844  if (ShapedType::isDynamic(cSize)) {
3845  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3846  cSize = channelDimVecSize;
3847  // Scalable vectors are only used when both conditions are met:
3848  // 1. channel dim is dynamic
3849  // 2. channelDimScalableFlag is set
3850  scalableChDim = channelDimScalableFlag;
3851  useMasking = true;
3852  }
3853 
3854  assert(!(useMasking && flatten) &&
3855  "Unsupported flattened conv with dynamic shapes");
3856 
3857  // out{n, w, c}
3858  bindShapeDims(resShapedType, nSize, wSize);
3859 
3860  vector::TransferWriteOp write;
3861  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3862 
3863  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3864  // When strideW == 1, we can batch the contiguous loads and avoid
3865  // unrolling
3866  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3867 
3868  Type lhsEltType = lhsShapedType.getElementType();
3869  Type rhsEltType = rhsShapedType.getElementType();
3870  Type resEltType = resShapedType.getElementType();
3871  VectorType lhsType = VectorType::get(
3872  {nSize,
3873  // iw = ow * sw + kw * dw - 1
3874  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3875  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3876  cSize},
3877  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3878  VectorType rhsType =
3879  VectorType::get({kwSize, cSize}, rhsEltType,
3880  /*scalableDims=*/{false, scalableChDim});
3881  VectorType resType =
3882  VectorType::get({nSize, wSize, cSize}, resEltType,
3883  /*scalableDims=*/{false, false, scalableChDim});
3884 
3885  // Masks the input xfer Op along the channel dim, iff the corresponding
3886  // scalable flag is set.
3887  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3888  ArrayRef<bool> scalableDims,
3889  Operation *opToMask) {
3890  if (!useMasking)
3891  return opToMask;
3892  auto maskType =
3893  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3894 
3895  SmallVector<bool> inBounds(maskShape.size(), true);
3896  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3897  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3898  rewriter.getBoolArrayAttr(inBounds));
3899 
3900       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3901           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3902 
3903  Value maskOp =
3904  vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3905 
3906  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3907  };
3908 
3909  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3910  // 0].
3911  Value lhs = vector::TransferReadOp::create(
3912  rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3913  /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3914  auto maybeMaskedLhs = maybeMaskXferOp(
3915  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3916 
3917  // Read rhs slice of size {kw, c} @ [0, 0].
3918  Value rhs = vector::TransferReadOp::create(
3919  rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3920  /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3921  auto maybeMaskedRhs = maybeMaskXferOp(
3922  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3923 
3924  // Read res slice of size {n, w, c} @ [0, 0, 0].
3925  Value res = vector::TransferReadOp::create(
3926  rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3927  /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3928  auto maybeMaskedRes = maybeMaskXferOp(
3929  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3930 
3931  //===------------------------------------------------------------------===//
3932  // Begin vector-only rewrite part
3933  //===------------------------------------------------------------------===//
3934  // Unroll along kw and read slices of lhs and rhs.
3935  SmallVector<Value> lhsVals, rhsVals, resVals;
3936  SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3937  SmallVector<int64_t> inOutStrides = {1, 1, 1};
3938 
3939  // Extract lhs slice of size {n, wSizeStep, c}
3940  // @ [0, sw * w + dw * kw, 0].
3941  for (int64_t kw = 0; kw < kwSize; ++kw) {
3942  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3943  lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3944  rewriter, loc, maybeMaskedLhs->getResult(0),
3945  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3946  inOutSliceSizes, inOutStrides));
3947  }
3948  }
3949  // Extract rhs slice of size {c} @ [kw].
3950  for (int64_t kw = 0; kw < kwSize; ++kw) {
3951  rhsVals.push_back(
3952  vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3953  /*offsets=*/ArrayRef<int64_t>{kw}));
3954  }
3955  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3956  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3957  resVals.push_back(vector::ExtractStridedSliceOp::create(
3958  rewriter, loc, maybeMaskedRes->getResult(0),
3959  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3960  inOutStrides));
3961  }
3962 
3963  auto linearIndex = [&](int64_t kw, int64_t w) {
3964  return kw * (wSize / wSizeStep) + w;
3965  };
3966 
3967  // Note - the scalable flags are ignored as flattening combined with
3968  // scalable vectorization is not supported.
3969  SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3970  auto lhsTypeAfterFlattening =
3971  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3972  auto resTypeAfterFlattening =
3973  VectorType::get(inOutFlattenSliceSizes, resEltType);
3974 
3975  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3976  for (int64_t kw = 0; kw < kwSize; ++kw) {
3977  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3978  Value lhsVal = lhsVals[linearIndex(kw, w)];
3979  Value resVal = resVals[w];
3980  if (flatten) {
3981  // Flatten the input and output vectors (collapse the channel
3982  // dimension)
3983  lhsVal =
3984  vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3985  lhsVals[linearIndex(kw, w)]);
3986  resVal = vector::ShapeCastOp::create(
3987  rewriter, loc, resTypeAfterFlattening, resVals[w]);
3988  }
3989  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3990  rhsVals[kw], resVal, flatten);
3991  if (flatten) {
3992  // Un-flatten the output vector (restore the channel dimension)
3993  resVals[w] = vector::ShapeCastOp::create(
3994  rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
3995  resVals[w]);
3996  }
3997  }
3998  }
3999 
4000  // It's possible we failed to create the FMA.
4001  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
4002  // Manually revert (in reverse order) to avoid leaving a bad IR state.
4003  for (auto &collection :
4004  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
4005  for (Value v : collection)
4006  rewriter.eraseOp(v.getDefiningOp());
4007  return rewriter.notifyMatchFailure(op, "failed to create FMA");
4008  }
4009 
4010  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
4011  // This does not depend on kw.
4012  for (int64_t w = 0; w < wSize; w += wSizeStep) {
4013  maybeMaskedRes = vector::InsertStridedSliceOp::create(
4014  rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4015  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
4016  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
4017  }
4018  //===------------------------------------------------------------------===//
4019  // End vector-only rewrite part
4020  //===------------------------------------------------------------------===//
4021 
4022  // Write back res slice of size {n, w, c} @ [0, 0, 0].
4023  Operation *resOut = vector::TransferWriteOp::create(
4024  rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4025  ValueRange{zero, zero, zero});
4026  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4027  resOut);
4028  }
4029 
4030  /// Lower:
4031  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4032  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4033  /// to MulAcc.
4034  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4035  Value lhs, Value rhs, Value res,
4036  bool flatten) {
4037  auto rhsTy = cast<ShapedType>(rhs.getType());
4038  auto resTy = cast<ShapedType>(res.getType());
4039 
4040  // TODO(suderman): Change this to use a vector.ima intrinsic.
4041  lhs = promote(rewriter, loc, lhs, resTy);
4042 
4043  if (flatten) {
4044  // NOTE: This following logic won't work for scalable vectors. For this
4045  // reason, "flattening" is not supported when shapes are dynamic (this
4046  // should be captured by one of the pre-conditions).
4047 
4048  // There are two options for handling the filter:
4049  // * shape_cast(broadcast(filter))
4050  // * broadcast(shuffle(filter))
4051  // Opt for the option without shape_cast to simplify the codegen.
4052  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4053  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4054 
4055  SmallVector<int64_t, 16> indices;
4056  for (int i = 0; i < resSize / rhsSize; ++i) {
4057  for (int j = 0; j < rhsSize; ++j)
4058  indices.push_back(j);
4059  }
4060 
4061  rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4062  }
4063  // Broadcast the filter to match the output vector
4064  rhs = vector::BroadcastOp::create(rewriter, loc,
4065  resTy.clone(rhsTy.getElementType()), rhs);
4066 
4067  rhs = promote(rewriter, loc, rhs, resTy);
4068 
4069  if (!lhs || !rhs)
4070  return nullptr;
4071 
4072  if (isa<FloatType>(resTy.getElementType()))
4073  return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4074 
4075  auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4076  return arith::AddIOp::create(rewriter, loc, mul, res);
4077  }
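  // Illustrative result of the lowering above for f32 without flattening
  // (hypothetical shapes):
  //
  //   %rhs_b = vector.broadcast %rhs : vector<4xf32> to vector<1x8x4xf32>
  //   %out   = vector.fma %lhs, %rhs_b, %res : vector<1x8x4xf32>
  //
  // For integer element types the multiply-accumulate is emitted as
  // arith.muli followed by arith.addi instead.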
4078 
4079  /// Entry point for non-channeled convolution:
4080  /// {{w + kw}, {kw}, {w}}
4081  FailureOr<Operation *> generateNonChanneledConv() {
4082  AffineExpr w, kw;
4083  bindDims(ctx, w, kw);
4084  if (!iters({Par(), Red()}))
4085  return rewriter.notifyMatchFailure(op,
4086  "failed to match conv::W 1-par 1-red");
4087 
4088  // No transposition needed.
4089  if (layout({/*lhsIndex*/ {w + kw},
4090  /*rhsIndex*/ {kw},
4091  /*resIndex*/ {w}}))
4092  return conv(Conv1DOpOrder::W);
4093 
4094  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4095  }
4096 
4097  /// Entry point that transposes into the common form:
4098  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4099  FailureOr<Operation *> generateNwcConv() {
4100  AffineExpr n, w, f, kw, c;
4101  bindDims(ctx, n, w, f, kw, c);
4102  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4103  return rewriter.notifyMatchFailure(
4104  op, "failed to match conv::Nwc 3-par 2-red");
4105 
4106  // No transposition needed.
4107  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4108  /*rhsIndex*/ {kw, c, f},
4109  /*resIndex*/ {n, w, f}}))
4110  return conv(Conv1DOpOrder::Nwc);
4111 
4112  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4113  }
4114 
4115  /// Entry point that transposes into the common form:
4116  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4117  FailureOr<Operation *> generateNcwConv() {
4118  AffineExpr n, w, f, kw, c;
4119  bindDims(ctx, n, f, w, c, kw);
4120  if (!iters({Par(), Par(), Par(), Red(), Red()}))
4121  return rewriter.notifyMatchFailure(
4122  op, "failed to match conv::Ncw 3-par 2-red");
4123 
4124  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4125  /*rhsIndex*/ {f, c, kw},
4126  /*resIndex*/ {n, f, w}}))
4127  return conv(Conv1DOpOrder::Ncw);
4128 
4129  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4130  }
4131 
4132  /// Entry point that transposes into the common form:
4133  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4134  FailureOr<Operation *> generateNwcPooling() {
4135  AffineExpr n, w, c, kw;
4136  bindDims(ctx, n, w, c, kw);
4137  if (!iters({Par(), Par(), Par(), Red()}))
4138  return rewriter.notifyMatchFailure(op,
4139  "failed to match pooling 3-par 1-red");
4140 
4141  // No transposition needed.
4142  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4143  /*rhsIndex*/ {kw},
4144  /*resIndex*/ {n, w, c}}))
4145  return conv(Conv1DOpOrder::Nwc);
4146 
4147  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4148  }
4149 
4150  /// Entry point that transposes into the common form:
4151  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4152  FailureOr<Operation *> generateNcwPooling() {
4153  AffineExpr n, w, c, kw;
4154  bindDims(ctx, n, c, w, kw);
4155  if (!iters({Par(), Par(), Par(), Red()}))
4156  return rewriter.notifyMatchFailure(op,
4157  "failed to match pooling 3-par 1-red");
4158 
4159  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4160  /*rhsIndex*/ {kw},
4161  /*resIndex*/ {n, c, w}}))
4162  return conv(Conv1DOpOrder::Ncw);
4163 
4164  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4165  }
4166 
4167  /// Entry point that transposes into the common form:
4168  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4169  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4170  bool vecChDimScalableFlag = false,
4171  bool flatten = false) {
4172  AffineExpr n, w, c, kw;
4173  bindDims(ctx, n, w, c, kw);
4174  if (!iters({Par(), Par(), Par(), Red()}))
4175  return rewriter.notifyMatchFailure(
4176  op, "failed to match depthwise::Nwc conv 3-par 1-red");
4177 
4178  // No transposition needed.
4179  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4180  /*rhsIndex*/ {kw, c},
4181  /*resIndex*/ {n, w, c}}))
4182  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4183 
4184  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4185  }
4186 
4187 private:
4188  ConvOperationKind oper = ConvOperationKind::Conv;
4189  StringAttr redOp;
4190  StringAttr poolExtOp;
4191  bool isPoolExt = false;
4192  int strideW, dilationW;
4193  Value lhsShaped, rhsShaped, resShaped;
4194  ShapedType lhsShapedType, rhsShapedType, resShapedType;
4195  vector::CombiningKind reductionKind;
4196 
4197  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4198  void setConvOperationKind(Operation *reduceOp) {
4199  int numBlockArguments =
4200  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4201  if (numBlockArguments == 1) {
4202  // Will be convolution if feeder is a MulOp.
4203  // A strength reduced version of MulOp for i1 type is AndOp which is also
4204  // supported. Otherwise, it can be pooling. This strength reduction logic
4205  // is in `buildBinaryFn` helper in the Linalg dialect.
4206  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4207  llvm::IsaPred<BlockArgument>);
4208  Operation *feedOp = (*feedValIt).getDefiningOp();
4209  if (isCastOfBlockArgument(feedOp)) {
4210  oper = ConvOperationKind::Pool;
4211  isPoolExt = true;
4212  poolExtOp = feedOp->getName().getIdentifier();
4213  return;
4214  }
4215  oper = ConvOperationKind::Conv;
4216  return;
4217  }
 4218  // numBlockArguments == 2, so this is a pooling op.
4219  oper = ConvOperationKind::Pool;
4220  isPoolExt = false;
4221  }
4222 };
4223 } // namespace
4224 
4225 /// Helper function to vectorize a LinalgOp with convolution semantics.
4226 // TODO: extend the generic vectorization to support windows and drop this.
4227 static FailureOr<Operation *> vectorizeConvolution(
4228  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4229  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4230  Conv1DGenerator conv1dGen(rewriter, op);
4231  auto res = conv1dGen.generateNonChanneledConv();
4232  if (succeeded(res))
4233  return res;
4234  res = conv1dGen.generateNwcConv();
4235  if (succeeded(res))
4236  return res;
4237  res = conv1dGen.generateNcwConv();
4238  if (succeeded(res))
4239  return res;
4240  res = conv1dGen.generateNwcPooling();
4241  if (succeeded(res))
4242  return res;
4243  res = conv1dGen.generateNcwPooling();
4244  if (succeeded(res))
4245  return res;
4246 
 4247  // Only depthwise 1D convs are left - these can be vectorized using masks
 4248  // and scalable vectors. Note that, at the moment, the only dim that can be
 4249  // dynamic (i.e. masked/scalable) is the channel dim.
4250  uint64_t vecChDimSize = ShapedType::kDynamic;
4251  bool vecChDimScalableFlag = false;
4252  if (!inputVecSizes.empty()) {
4253  // Only use the input vector size corresponding to the channel dim. Other
4254  // vector dims will be inferred from the Ops.
4255  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4256  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4257  "Not a 1D depthwise conv!");
 4258  size_t chDimIdx =
 4259  TypeSwitch<Operation *, size_t>(op)
 4260  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
 4261  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4262 
4263  vecChDimSize = inputVecSizes[chDimIdx];
4264  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4265  }
4266  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4267  flatten1DDepthwiseConv);
4268 }
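// Illustrative driver (assumed, not part of this file): the channel vector
// size and its scalability typically reach this entry point from the
// transform dialect, e.g. (approximate syntax):
//
//   transform.structured.vectorize %conv vector_sizes [1, 8, [4], 3]
//     : !transform.any_op
//
// which maps to inputVecSizes = {1, 8, 4, 3} with the third dim marked
// scalable for a depthwise_conv_1d_nwc_wc op.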
 4269 
 4270 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
 4271  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 4272 
4273  LogicalResult matchAndRewrite(LinalgOp op,
4274  PatternRewriter &rewriter) const override {
4275  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4276  if (failed(resultOrFail))
4277  return failure();
4278  Operation *newOp = *resultOrFail;
4279  if (newOp->getNumResults() == 0) {
4280  rewriter.eraseOp(op.getOperation());
4281  return success();
4282  }
4283  assert(newOp->getNumResults() == 1 && "expected single result");
4284  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4285  return success();
4286  }
4287 };
 4288 
 4289 void mlir::linalg::populateConvolutionVectorizationPatterns(
 4290  RewritePatternSet &patterns, PatternBenefit benefit) {
 4291  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4292 }
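// Hedged usage sketch (illustrative; `funcOp` and the enclosing pass are
// assumptions, not part of this file): a client pass would typically register
// and apply these patterns greedily:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();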