1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58  ArrayRef<int64_t> inputVecSizes = {},
59  ArrayRef<bool> inputVecScalableFlags = {},
60  bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if zero or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66  OpType res;
67  block.walk([&](OpType op) {
68  if (res) {
69  res = nullptr;
70  return WalkResult::interrupt();
71  }
72  res = op;
73  return WalkResult::advance();
74  });
75  return res;
76 }
77 
78 /// Helper function to extract the input slices after filter is unrolled along
79 /// kw.
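/// As an illustrative sketch (operand names and sizes hypothetical), each
/// extracted slice in the channeled case is a `vector.extract_strided_slice`:
/// ```
///   %slice = vector.extract_strided_slice %input
///       {offsets = [0, 2, 0], sizes = [1, 4, 8], strides = [1, 1, 1]}
///       : vector<1x16x8xf32> to vector<1x4x8xf32>
/// ```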
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83  int64_t kwSize, int strideW, int dilationW,
84  int64_t wSizeStep, bool isSingleChanneled) {
85  SmallVector<Value> result;
86  if (isSingleChanneled) {
87  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88  // convolution.
89  SmallVector<int64_t> sizes{wSizeStep};
90  SmallVector<int64_t> strides{1};
91  for (int64_t kw = 0; kw < kwSize; ++kw) {
92  for (int64_t w = 0; w < wSize; w += wSizeStep) {
93  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95  }
96  }
97  } else {
98  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99  // for channeled convolution.
100  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101  SmallVector<int64_t> strides{1, 1, 1};
102  for (int64_t kw = 0; kw < kwSize; ++kw) {
103  for (int64_t w = 0; w < wSize; w += wSizeStep) {
104  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105  loc, input,
106  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107  sizes, strides));
108  }
109  }
110  }
111  return result;
112 }
113 
114 /// Helper function to extract the filter slices after filter is unrolled along
115 /// kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118  int64_t kwSize) {
119  SmallVector<Value> result;
120  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121  // non-channeled convolution] @ [kw].
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  result.push_back(rewriter.create<vector::ExtractOp>(
124  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125  }
126  return result;
127 }
128 
129 /// Helper function to extract the result slices after filter is unrolled along
130 /// kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134  int64_t wSizeStep, bool isSingleChanneled) {
135  SmallVector<Value> result;
136  if (isSingleChanneled) {
137  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138  SmallVector<int64_t> sizes{wSizeStep};
139  SmallVector<int64_t> strides{1};
140  for (int64_t w = 0; w < wSize; w += wSizeStep) {
141  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143  }
144  } else {
145  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146  // convolution.
147  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148  SmallVector<int64_t> strides{1, 1, 1};
149  for (int64_t w = 0; w < wSize; w += wSizeStep) {
150  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152  }
153  }
154  return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
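/// As an illustrative sketch (operand names and sizes hypothetical), each
/// write-back in the channeled case is a `vector.insert_strided_slice`:
/// ```
///   %acc = vector.insert_strided_slice %slice, %res
///       {offsets = [0, 4, 0], strides = [1, 1, 1]}
///       : vector<1x4x8xf32> into vector<1x16x8xf32>
/// ```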
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160  SmallVectorImpl<Value> &resVals,
161  bool isSingleChanneled) {
162 
163  if (isSingleChanneled) {
164  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165  // This does not depend on kw.
166  SmallVector<int64_t> strides{1};
167  for (int64_t w = 0; w < wSize; w += wSizeStep) {
168  res = rewriter.create<vector::InsertStridedSliceOp>(
169  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170  }
171  } else {
172  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173  // convolution. This does not depend on kw.
174  SmallVector<int64_t> strides{1, 1, 1};
175  for (int64_t w = 0; w < wSize; w += wSizeStep) {
176  res = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178  strides);
179  }
180  }
181  return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189  /// Initializes the vectorization state, including the computation of the
190  /// canonical vector shape for vectorization.
191  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192  ArrayRef<int64_t> inputVectorSizes,
193  ArrayRef<bool> inputScalableVecDims);
194 
195  /// Returns the canonical vector shape used to vectorize the iteration space.
196  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198  /// Returns the vector dimensions that are scalable in the canonical vector
199  /// shape.
200  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202  /// Returns a vector type of the provided `elementType` with the canonical
203  /// vector shape and the corresponding fixed/scalable dimensions bit. If
204  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205  /// accordingly.
206   VectorType getCanonicalVecType(
207       Type elementType,
208  std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211  if (dimPermutation.has_value()) {
212  vectorShape =
213  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214  scalableDims =
215  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216  } else {
217  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219  }
220 
221  return VectorType::get(vectorShape, elementType, scalableDims);
222  }
223 
224  /// Masks an operation with the canonical vector mask if the operation needs
225  /// masking. Returns the masked operation or the original operation if masking
226  /// is not needed. If provided, the canonical mask for this operation is
227  /// permuted using `maybeIndexingMap`.
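  /// As an illustrative sketch (values hypothetical), a masked transfer read
  /// produced through this method looks like:
  /// ```
  ///   %0 = vector.mask %mask {
  ///     vector.transfer_read %src[%c0, %c0], %pad
  ///         : tensor<?x?xf32>, vector<8x16xf32>
  ///   } : vector<8x16xi1> -> vector<8x16xf32>
  /// ```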
228  Operation *
229  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231 
232 private:
233  /// Initializes the iteration space static sizes using the Linalg op
234  /// information. This may become more complicated in the future.
235  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237  }
238 
239  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240  /// all the static and dynamic dimensions of the iteration space to be
241  /// vectorized and stores them in `iterSpaceValueSizes`.
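  /// For example (illustrative): for an iteration space of (8, ?) this emits
  /// `%c8 = arith.constant 8 : index` for the static dimension and
  /// `%d = tensor.dim %operand, %c1 : tensor<8x?xf32>` for the dynamic one.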
242  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243  LinalgOp linalgOp);
244 
245  /// Create or retrieve an existing mask value to mask `opToMask` in the
246  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
247  /// mask is permuted using that permutation map. If a new mask is created, it
248  /// will be cached for future users.
249  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250  LinalgOp linalgOp,
251  std::optional<AffineMap> maybeMaskingMap);
252 
253  /// Check whether this permutation map can be used for masking. At the
254  /// moment we only make sure that there are no broadcast dimensions, but this
255  /// might change if indexing maps evolve.
256  bool isValidMaskingMap(AffineMap maskingMap) {
257  return maskingMap.getBroadcastDims().size() == 0;
258  }
259 
260  /// Turn the input indexing map into a valid masking map.
261  ///
262  /// The input indexing map may contain "zero" results, e.g.:
263  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264  /// Applying such maps to canonical vector shapes like this one:
265  /// (1, 16, 16, 4)
266  /// would yield an invalid vector shape like this:
267  /// (16, 16, 1, 0)
268  /// Instead, drop the broadcasting dims that make no sense for masking perm.
269  /// maps:
270  /// (d0, d1, d2, d3) -> (d2, d1, d0)
271  /// This way, the corresponding vector/mask type will be:
272  /// vector<16x16x1xty>
273  /// rather than this invalid Vector type:
274  /// vector<16x16x1x0xty>
275  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
276  return indexingMap.dropZeroResults();
277  }
278 
279  // Holds the compile-time static sizes of the iteration space to vectorize.
280  // Dynamic dimensions are represented using ShapedType::kDynamic.
281  SmallVector<int64_t> iterSpaceStaticSizes;
282 
283  /// Holds the value sizes of the iteration space to vectorize. Static
284  /// dimensions are represented by 'arith.constant' and dynamic
285  /// dimensions by 'tensor/memref.dim'.
286  SmallVector<Value> iterSpaceValueSizes;
287 
288  /// Holds the canonical vector shape used to vectorize the iteration space.
289  SmallVector<int64_t> canonicalVecShape;
290 
291  /// Holds the vector dimensions that are scalable in the canonical vector
292  /// shape.
293  SmallVector<bool> scalableVecDims;
294 
295  /// Holds the active masks for permutations of the canonical vector iteration
296  /// space.
297  DenseMap<AffineMap, Value> activeMaskCache;
298 
299  /// Global vectorization guard for the incoming rewriter. It's initialized
300  /// when the vectorization state is initialized.
301   OpBuilder::InsertionGuard rewriterGuard;
302 };
303 
304 LogicalResult
305 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
306  LinalgOp linalgOp) {
307  // TODO: Support 0-d vectors.
308  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
309  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
310  // Create constant index op for static dimensions.
311  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
312  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
313  continue;
314  }
315 
316  // Find an operand defined on this dimension of the iteration space to
317  // extract the runtime dimension size.
318  Value operand;
319  unsigned operandDimPos;
320  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
321  operandDimPos)))
322  return failure();
323 
324  Value dynamicDim = linalgOp.hasPureTensorSemantics()
325  ? (Value)rewriter.create<tensor::DimOp>(
326  linalgOp.getLoc(), operand, operandDimPos)
327  : (Value)rewriter.create<memref::DimOp>(
328  linalgOp.getLoc(), operand, operandDimPos);
329  iterSpaceValueSizes.push_back(dynamicDim);
330  }
331 
332  return success();
333 }
334 
335 /// Initializes the vectorization state, including the computation of the
336 /// canonical vector shape for vectorization.
337 // TODO: Move this to the constructor when we can remove the failure cases.
338 LogicalResult
339 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
340  ArrayRef<int64_t> inputVectorSizes,
341  ArrayRef<bool> inputScalableVecDims) {
342  // Initialize the insertion point.
343  rewriter.setInsertionPoint(linalgOp);
344 
345  if (!inputVectorSizes.empty()) {
346  // Get the canonical vector shape from the input vector sizes provided. This
347  // path should be taken to vectorize code with dynamic shapes and when using
348  // vector sizes greater than the iteration space sizes.
349  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
350  scalableVecDims.append(inputScalableVecDims.begin(),
351  inputScalableVecDims.end());
352  } else {
353  // Compute the canonical vector shape from the operation shape. If there are
354  // dynamic shapes, the operation won't be vectorized. We assume all the
355  // vector dimensions are fixed.
356  canonicalVecShape = linalgOp.getStaticLoopRanges();
357  scalableVecDims.append(linalgOp.getNumLoops(), false);
358  }
359 
360  LDBG("Canonical vector shape: ");
361  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
362  LLVM_DEBUG(llvm::dbgs() << "\n");
363  LDBG("Scalable vector dims: ");
364  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
365  LLVM_DEBUG(llvm::dbgs() << "\n");
366 
367  if (ShapedType::isDynamicShape(canonicalVecShape))
368  return failure();
369 
370  // Initialize iteration space static sizes.
371  initIterSpaceStaticSizes(linalgOp);
372 
373  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
374  // all the static and dynamic dimensions of the iteration space, needed to
375  // compute a mask during vectorization.
376  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
377  return failure();
378 
379  return success();
380 }
381 
382 /// Create or retrieve an existing mask value to mask `opToMask` in the
383 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
384 /// is permuted using that permutation map. If a new mask is created, it will be
385 /// cached for future users.
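/// As an illustrative sketch (values hypothetical), the created mask is a
/// `vector.create_mask` whose operands are the permuted iteration space sizes:
/// ```
///   %mask = vector.create_mask %dimA, %dimB : vector<8x16xi1>
/// ```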
386 Value VectorizationState::getOrCreateMaskFor(
387  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
388  std::optional<AffineMap> maybeMaskingMap) {
389 
390  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
391  "Ill-formed masking map.");
392 
393  // No mask is needed if the operation is not maskable.
394  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
395  if (!maskableOp)
396  return Value();
397 
398  assert(!maskableOp.isMasked() &&
399  "Masking an operation that is already masked");
400 
401  // If no masking map was provided, use an identity map with the loop dims.
402  assert((!maybeMaskingMap || *maybeMaskingMap) &&
403  "Unexpected null mask permutation map");
404  AffineMap maskingMap =
405  maybeMaskingMap ? *maybeMaskingMap
406                       : AffineMap::getMultiDimIdentityMap(
407                             linalgOp.getNumLoops(), rewriter.getContext());
408 
409  LDBG("Masking map: " << maskingMap << "\n");
410 
411  // Return the active mask for the masking map of this operation if it was
412  // already created.
413  auto activeMaskIt = activeMaskCache.find(maskingMap);
414  if (activeMaskIt != activeMaskCache.end()) {
415  Value mask = activeMaskIt->second;
416  LDBG("Reusing mask: " << mask << "\n");
417  return mask;
418  }
419 
420  // Compute permuted projection of the iteration space to be masked and the
421  // corresponding mask shape. If the resulting iteration space dimensions are
422  // static and identical to the mask shape, masking is not needed for this
423  // operation.
424  // TODO: Improve this check. Only projected permutation indexing maps are
425  // supported.
426  SmallVector<int64_t> permutedStaticSizes =
427  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
428  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
429  auto maskShape = maskType.getShape();
430 
431  LDBG("Mask shape: ");
432  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
433  LLVM_DEBUG(llvm::dbgs() << "\n");
434 
435  if (permutedStaticSizes == maskShape) {
436  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
437  activeMaskCache[maskingMap] = Value();
438  return Value();
439  }
440 
441  // Permute the iteration space value sizes to compute the mask upper bounds.
442  SmallVector<Value> upperBounds =
443  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
444  assert(!maskShape.empty() && !upperBounds.empty() &&
445  "Masked 0-d vectors are not supported yet");
446 
447  // Create the mask based on the dimension values.
448  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
449  maskType, upperBounds);
450  LDBG("Creating new mask: " << mask << "\n");
451  activeMaskCache[maskingMap] = mask;
452  return mask;
453 }
454 
455 Operation *
456 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
457                                   LinalgOp linalgOp,
458  std::optional<AffineMap> maybeIndexingMap) {
459  LDBG("Trying to mask: " << *opToMask << "\n");
460 
461  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
462  if (maybeIndexingMap)
463  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
464 
465  // Create or retrieve mask for this operation.
466  Value mask =
467  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
468 
469  if (!mask) {
470  LDBG("No mask required\n");
471  return opToMask;
472  }
473 
474  // Wrap the operation with a new `vector.mask` and update D-U chain.
475  assert(opToMask && "Expected a valid operation to mask");
476  auto maskOp = cast<vector::MaskOp>(
477  mlir::vector::maskOperation(rewriter, opToMask, mask));
478  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
479 
480  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
481  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
482  maskOpTerminator);
483 
484  LDBG("Masked operation: " << *maskOp << "\n");
485  return maskOp;
486 }
487 
488 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
489 /// projectedPermutation, compress the unused dimensions to serve as a
490 /// permutation_map for a vector transfer operation.
491 /// For example, given a linalg op such as:
492 ///
493 /// ```
494 /// %0 = linalg.generic {
495 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
496 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
497 /// }
498 /// ins(%0 : tensor<2x3x4xf32>)
499 /// outs(%1 : tensor<5x6xf32>)
500 /// ```
501 ///
502 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
503 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
504 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
505 static AffineMap reindexIndexingMap(AffineMap map) {
506   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
507  "expected projected permutation");
508  auto res = compressUnusedDims(map);
509  assert(res.getNumDims() ==
510  (res.getNumResults() - res.getNumOfZeroResults()) &&
511  "expected reindexed map with same number of dims and results");
512  return res;
513 }
514 
515 /// Helper enum to represent conv1d input traversal order.
516 enum class Conv1DOpOrder {
517  W, // Corresponds to non-channeled 1D convolution operation.
518  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
519  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
520 };
521 
522 /// Helper data structure to represent the result of vectorization.
523 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
524 enum VectorizationStatus {
525   /// Op failed to vectorize.
526   Failure = 0,
527   /// Op vectorized and custom function took care of replacement logic.
528   NoReplace,
529   /// Op vectorized into a new Op whose results will replace original Op's
530  /// results.
531  NewOp
532  // TODO: support values if Op vectorized to Many-Ops whose results we need to
533  // aggregate for replacement.
534 };
535 struct VectorizationResult {
536   /// Return status from vectorizing the current op.
537   enum VectorizationStatus status;
538   /// New vectorized operation to replace the current op.
539   /// Replacement behavior is specified by `status`.
540   Operation *newOp;
541 };
542 
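/// Return the vector combining kind matching the scalar combiner operation
/// `combinerOp`, e.g. (illustrative) arith.addf/arith.addi map to ADD and
/// arith.mulf/arith.muli map to MUL. Returns std::nullopt for unsupported ops.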
543 std::optional<vector::CombiningKind>
544 getCombinerOpKind(Operation *combinerOp) {
545   using ::mlir::vector::CombiningKind;
546 
547  if (!combinerOp)
548  return std::nullopt;
549   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
550       .Case<arith::AddIOp, arith::AddFOp>(
551  [&](auto op) { return CombiningKind::ADD; })
552  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
553  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
554  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
555  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
556  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
557  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
558  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
559  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
560  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
561  .Case<arith::MulIOp, arith::MulFOp>(
562  [&](auto op) { return CombiningKind::MUL; })
563  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
564  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
565  .Default([&](auto op) { return std::nullopt; });
566 }
567 
568 /// Check whether `outputOperand` is a reduction with a single combiner
569 /// operation. Return the combiner operation of the reduction. Return
570 /// nullptr otherwise. Multiple reduction operations would impose an
571 /// ordering between reduction dimensions and are currently unsupported in
572 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
573 /// max(min(X))
574 // TODO: use in LinalgOp verification, there is a circular dependency atm.
575 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
576  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
577  unsigned outputPos =
578  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
579  // Only single combiner operations are supported for now.
580  SmallVector<Operation *, 4> combinerOps;
581  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
582  combinerOps.size() != 1)
583  return nullptr;
584 
585  // Return the combiner operation.
586  return combinerOps[0];
587 }
588 
589 /// Broadcast `value` to a vector of `shape` if possible. Return value
590 /// otherwise.
591 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
592  auto dstVecType = dyn_cast<VectorType>(dstType);
593  // If no shape to broadcast to, just return `value`.
594  if (dstVecType.getRank() == 0)
595  return value;
596  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
597           vector::BroadcastableToResult::Success)
598     return value;
599  Location loc = b.getInsertionPoint()->getLoc();
600  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
601 }
602 
603 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
604 /// assumes that `reductionOp` has two operands and one of them is the reduction
605 /// initial value.
606 // Note: this is a true builder that notifies the OpBuilder listener.
607 // TODO: Consider moving as a static helper on the ReduceOp.
608 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
609                                       Value valueToReduce, Value acc,
610  ArrayRef<bool> dimsToMask) {
611  auto maybeKind = getCombinerOpKind(reduceOp);
612  assert(maybeKind && "Failed precondition: could not get reduction kind");
613  return b.create<vector::MultiDimReductionOp>(
614  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
615 }
616 
617 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
618  return llvm::to_vector(
619  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
620 }
621 
622 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
623 /// reduction iterator.
624 static bool hasReductionIterator(LinalgOp &op) {
625  return isa<linalg::ReduceOp>(op) ||
626  (isa<linalg::GenericOp>(op) &&
627  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
628 }
629 
630 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
631 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
632 /// currently being vectorized. If `dest` has null rank, build a memref.store.
633 /// Return the produced value or null if no value is produced.
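/// As an illustrative sketch (operand names hypothetical), the ranked case
/// builds a write such as:
/// ```
///   %w = vector.transfer_write %value, %out[%c0, %c0]
///       : vector<4x8xf32>, tensor<4x8xf32>
/// ```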
634 // Note: this is a true builder that notifies the OpBuilder listener.
635 // TODO: Consider moving as a static helper on the ReduceOp.
636 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
637  OpOperand *outputOperand,
638  VectorizationState &state) {
639  Location loc = value.getLoc();
640  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
641  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
642 
643  // Compute the vector type of the value to store. This type should be an
644  // identity or projection of the canonical vector type without any permutation
645  // applied, given that any permutation in a transfer write happens as part of
646  // the write itself.
647   auto vectorTypeMap = AffineMap::getFilteredIdentityMap(
648       opOperandMap.getContext(), opOperandMap.getNumInputs(),
649  [&](AffineDimExpr dimExpr) -> bool {
650  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
651  });
652  auto vectorType = state.getCanonicalVecType(
653  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
654 
655  Operation *write;
656  if (vectorType.getRank() > 0) {
657  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
658  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
659  rewriter.create<arith::ConstantIndexOp>(loc, 0));
660  value = broadcastIfNeeded(rewriter, value, vectorType);
661  assert(value.getType() == vectorType && "Incorrect type");
662  write = rewriter.create<vector::TransferWriteOp>(
663  loc, value, outputOperand->get(), indices, writeMap);
664  } else {
665  // 0-d case is still special: do not invert the reindexing writeMap.
666  if (!isa<VectorType>(value.getType()))
667  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
668  assert(value.getType() == vectorType && "Incorrect type");
669  write = rewriter.create<vector::TransferWriteOp>(
670  loc, value, outputOperand->get(), ValueRange{});
671  }
672 
673  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
674 
675  // If masked, set in-bounds to true. Masking guarantees that the access will
676  // be in-bounds.
677  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
678  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
679  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
680  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
681  }
682 
683  LDBG("vectorized op: " << *write << "\n");
684  if (!write->getResults().empty())
685  return write->getResult(0);
686  return Value();
687 }
688 
689 // Custom vectorization precondition function type. This is intended to be used
690 // with CustomVectorizationHook. Returns success if the corresponding custom
691 // hook can vectorize the op.
692 using CustomVectorizationPrecondition =
693     std::function<LogicalResult(Operation *, bool)>;
694 
695 // Custom vectorization function type. Produce a vector form of Operation*
696 // assuming all its vectorized operands are already in the IRMapping.
697 // Return nullptr if the Operation cannot be vectorized.
698 using CustomVectorizationHook =
699     std::function<VectorizationResult(Operation *, const IRMapping &)>;
700 
701 /// Helper function to vectorize the terminator of a `linalgOp`. New result
702 /// vector values are appended to `newResults`. Return
703 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
704 /// should not try to map produced operations and instead return the results
705 /// using the `newResults` vector making them available to the vectorization
706 /// algorithm for RAUW. This function is meant to be used as a
707 /// CustomVectorizationHook.
708 static VectorizationResult
709 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
710                      const IRMapping &bvm, VectorizationState &state,
711  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
712  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
713  if (!yieldOp)
714     return VectorizationResult{VectorizationStatus::Failure, nullptr};
715   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
716  // TODO: Scan for an opportunity for reuse.
717  // TODO: use a map.
718  Value vectorValue = bvm.lookup(output.value());
719  Value newResult =
720  buildVectorWrite(rewriter, vectorValue,
721  linalgOp.getDpsInitOperand(output.index()), state);
722  if (newResult)
723  newResults.push_back(newResult);
724  }
725 
726   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
727 }
728 
729 /// Helper function to vectorize the index operations of a `linalgOp`. Return
730 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
731 /// should map the produced operations. This function is meant to be used as a
732 /// CustomVectorizationHook.
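/// As an illustrative sketch, with canonical vector shape (2, 4) a
/// `linalg.index 0` is vectorized roughly as:
/// ```
///   %step = vector.step : vector<2xindex>
///   %bcast = vector.broadcast %step : vector<2xindex> to vector<4x2xindex>
///   %idx = vector.transpose %bcast, [1, 0]
///       : vector<4x2xindex> to vector<2x4xindex>
/// ```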
733 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
734                                                 VectorizationState &state,
735  Operation *op,
736  LinalgOp linalgOp) {
737  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
738  if (!indexOp)
739     return VectorizationResult{VectorizationStatus::Failure, nullptr};
740   auto loc = indexOp.getLoc();
741  // Compute the static loop sizes of the index op.
742  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
743  auto dim = indexOp.getDim();
744  // Compute a one-dimensional index vector for the index op dimension.
745  auto indexVectorType =
746  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
747  state.getScalableVecDims()[dim]);
748  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
749  // Return the one-dimensional index vector if it lives in the trailing
750  // dimension of the iteration space since the vectorization algorithm in this
751  // case can handle the broadcast.
752  if (dim == targetShape.size() - 1)
753     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
754   // Otherwise permute the targetShape to move the index dimension last,
755  // broadcast the one-dimensional index vector to the permuted shape, and
756  // finally transpose the broadcasted index vector to undo the permutation.
757  auto permPattern =
758  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
759  std::swap(permPattern[dim], permPattern.back());
760  auto permMap =
761  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
762 
763  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
764  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
765  indexSteps);
766  SmallVector<int64_t> transposition =
767  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
768  std::swap(transposition.back(), transposition[dim]);
769  auto transposeOp =
770  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
771  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
772 }
773 
774 /// Helper function to check if the tensor.extract can be vectorized by the
775 /// custom hook vectorizeTensorExtract.
776 static LogicalResult
777 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
778  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
779  if (!extractOp)
780  return failure();
781 
782  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
783  return failure();
784 
785  // Check the index type, but only for non 0-d tensors (for which we do need
786  // access indices).
787  if (not extractOp.getIndices().empty()) {
788  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
789  return failure();
790  }
791 
792  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
793  return !VectorType::isValidElementType(type);
794  })) {
795  return failure();
796  }
797 
798  return success();
799 }
800 
801 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
802 /// generated from `tensor.extract`. The offset is calculated as follows
803 /// (example using scalar values):
804 ///
805 /// offset = extractOp.indices[0]
806 /// for (i = 1; i < numIndices; i++)
807 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
808 ///
809 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
810 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
811 static Value calculateGatherOffset(RewriterBase &rewriter,
812                                    VectorizationState &state,
813  tensor::ExtractOp extractOp,
814  const IRMapping &bvm) {
815  // The vector of indices for GatherOp should be shaped as the output vector.
816  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
817  auto loc = extractOp.getLoc();
818 
819  Value offset = broadcastIfNeeded(
820  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
821 
822  const size_t numIndices = extractOp.getIndices().size();
823  for (size_t i = 1; i < numIndices; i++) {
824  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
825 
826  auto dimSize = broadcastIfNeeded(
827  rewriter,
828  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
829  indexVecType);
830 
831  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
832 
833  auto extractOpIndex = broadcastIfNeeded(
834  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
835 
836  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
837  }
838 
839  return offset;
840 }
841 
842 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
843 
844 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
845 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
846 /// represents a contiguous load operation.
847 ///
848 /// Note that when calling this hook, it is assumed that the output vector is
849 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
850 /// labelled as a gather load before entering this method.
851 ///
852 /// Following on from the above, it is assumed that:
853 /// * for statically shaped loops, when no masks are used, only one dim is !=
854 /// 1 (that's what the shape of the output vector is based on).
855 /// * for dynamically shaped loops, there might be more non-unit dims
856 /// as the output vector type is user-specified.
857 ///
858 /// TODO: Statically shaped loops + vector masking
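/// For example (illustrative): for static loop ranges (1, 8, 1) this returns 1,
/// the index of the trailing loop dimension whose range is not 1.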
859 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
860  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
861  assert(
862  (linalgOp.hasDynamicShape() ||
863  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
864  "For statically shaped Linalg Ops, only one "
865  "non-unit loop dim is expected");
866  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
867 
868  size_t idx = loopRanges.size() - 1;
869  for (; idx != 0; idx--)
870  if (loopRanges[idx] != 1)
871  break;
872 
873  return idx;
874 }
875 
876 /// Checks whether `val` can be used for calculating a loop invariant index.
877 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
878  VectorType resType) {
879 
880  assert(((llvm::count_if(resType.getShape(),
881  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
882  "n-D vectors are not yet supported");
883 
884  // Blocks outside _this_ linalg.generic are effectively loop invariant.
885  // However, analysing block arguments for _this_ linalg.generic Op is a bit
886  // tricky. Just bail out in the latter case.
887  // TODO: We could try analysing the corresponding affine map here.
888  auto *block = linalgOp.getBlock();
889  if (isa<BlockArgument>(val))
890  return llvm::all_of(block->getArguments(),
891  [&val](Value v) { return (v != val); });
892 
893  Operation *defOp = val.getDefiningOp();
894  assert(defOp && "This is neither a block argument nor an operation result");
895 
896  // IndexOp is loop invariant as long as its result remains constant across
897  // iterations. Note that for dynamic shapes, the corresponding dim will also
898  // be conservatively treated as != 1.
899  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
900  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
901  }
902 
903  auto *ancestor = block->findAncestorOpInBlock(*defOp);
904 
905   // Values defined outside `linalgOp` are loop invariant.
906  if (!ancestor)
907  return true;
908 
909  // Values defined inside `linalgOp`, which are constant, are loop invariant.
910  if (isa<arith::ConstantOp>(ancestor))
911  return true;
912 
913  bool result = true;
914  for (auto op : ancestor->getOperands())
915  result &= isLoopInvariantIdx(linalgOp, op, resType);
916 
917  return result;
918 }
919 
920 /// Check whether `val` could be used for calculating the trailing index for a
921 /// contiguous load operation.
922 ///
923 /// There are currently 3 types of values that are allowed here:
924 /// 1. loop-invariant values,
925 /// 2. values that increment by 1 with every loop iteration,
926 /// 3. results of basic arithmetic operations (linear and continuous)
927 /// involving 1., 2. and 3.
928 /// This method returns True if indeed only such values are used in calculating
929 /// `val.`
930 ///
931 /// Additionally, the trailing index for a contiguous load operation should
932 /// increment by 1 with every loop iteration, i.e. be based on:
933 /// * `linalg.index <dim>` ,
934 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
935 /// `linalg.index <dim>` increments by 1 with every loop iteration).
936 /// `foundIndexOp` is updated to `true` when such Op is found.
937 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
938  bool &foundIndexOp, VectorType resType) {
939 
940  assert(((llvm::count_if(resType.getShape(),
941  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942  "n-D vectors are not yet supported");
943 
944  // Blocks outside _this_ linalg.generic are effectively loop invariant.
945  // However, analysing block arguments for _this_ linalg.generic Op is a bit
946  // tricky. Just bail out in the latter case.
947  // TODO: We could try analysing the corresponding affine map here.
948  auto *block = linalgOp.getBlock();
949  if (isa<BlockArgument>(val))
950  return llvm::all_of(block->getArguments(),
951  [&val](Value v) { return (v != val); });
952 
953  Operation *defOp = val.getDefiningOp();
954  assert(defOp && "This is neither a block argument nor an operation result");
955 
956  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
957  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
958 
959  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
960  return true;
961  }
962 
963  auto *ancestor = block->findAncestorOpInBlock(*defOp);
964 
965  if (!ancestor)
966  return false;
967 
968  // Conservatively reject Ops that could lead to indices with stride other
969  // than 1.
970  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
971  return false;
972 
973  bool result = false;
974  for (auto op : ancestor->getOperands())
975  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
976 
977  return result;
978 }
979 
980 /// Infer the memory access pattern for the input ExtractOp
981 ///
982 /// Based on the ExtractOp result shape and the access indices, decides whether
983 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
984 /// or a gather load. When analysing the ExtractOp indices (to identify
985 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
986 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
987 /// Op).
988 ///
989 /// Note that it is always safe to use gather load operations for contiguous
990 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
991 /// that `extractOp` is a gather load.
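/// As an illustrative sketch (names hypothetical), inside a linalg.generic
/// whose trailing non-unit loop dim is d1:
/// ```
///   %i = linalg.index 1 : index
///   %v = tensor.extract %src[%c0, %i] : tensor<4x128xf32>
/// ```
/// is treated as a contiguous load, whereas an index that is loaded from
/// another tensor (data-dependent) leads to a gather load.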
992 static VectorMemoryAccessKind
993 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
994  LinalgOp &linalgOp, VectorType resType) {
995 
996  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
997 
998  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
999  if (inputShape.getShape().empty())
1000     return VectorMemoryAccessKind::ScalarBroadcast;
1001 
1002  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1003  // otherwise.
1004  bool isOutput1DVector =
1005  (llvm::count_if(resType.getShape(),
1006  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1007  // 1. Assume that it's a gather load when reading non-1D vector.
1008  if (!isOutput1DVector)
1009     return VectorMemoryAccessKind::Gather;
1010 
1011  bool leadingIdxsLoopInvariant = true;
1012 
1013  // 2. Analyze the leading indices of `extractOp`.
1014  // Look at the way each index is calculated and decide whether it is suitable
1015  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1016  // gather load.
1017  auto indices = extractOp.getIndices();
1018  auto leadIndices = indices.drop_back(1);
1019 
1020  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1021  if (inputShape.getShape()[i] == 1)
1022  continue;
1023 
1024  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1025  }
1026 
1027  if (!leadingIdxsLoopInvariant) {
1028  LDBG("Found gather load: " << extractOp);
1029     return VectorMemoryAccessKind::Gather;
1030   }
1031 
1032  // 3. Analyze the trailing index for `extractOp`.
1033  // At this point we know that the leading indices are loop invariant. This
1034  // means that it is potentially a scalar or a contiguous load. We can decide
1035  // based on the trailing idx.
1036  auto extractOpTrailingIdx = indices.back();
1037 
1038  // 3a. Scalar broadcast load
1039  // If the trailing index is loop invariant then this is a scalar load.
1040  if (leadingIdxsLoopInvariant &&
1041  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1042  LDBG("Found scalar broadcast load: " << extractOp);
1043 
1044     return VectorMemoryAccessKind::ScalarBroadcast;
1045   }
1046 
1047  // 3b. Contiguous loads
1048  // The trailing `extractOp` index should increment with every loop iteration.
1049  // This effectively means that it must be based on the trailing loop index.
1050  // This is what the following bool captures.
1051  bool foundIndexOp = false;
1052  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1053  foundIndexOp, resType);
1054  // TODO: Support generating contiguous loads for column vectors - that will
1055  // require adding a permutation map to tranfer_read Ops.
1056  bool isRowVector = resType.getShape().back() != 1;
1057  isContiguousLoad &= (foundIndexOp && isRowVector);
1058 
1059  if (isContiguousLoad) {
1060     LDBG("Found contiguous load: " << extractOp);
1061     return VectorMemoryAccessKind::Contiguous;
1062   }
1063 
1064  // 4. Fallback case - gather load.
1065  LDBG("Found gather load: " << extractOp);
1066   return VectorMemoryAccessKind::Gather;
1067 }
1068 
1069 /// Helper function to vectorize the tensor.extract operations. Returns
1070 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1071 /// should map the produced operations. This function is meant to be used as a
1072 /// CustomVectorizationHook.
1073 static VectorizationResult
1074 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1075                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1076  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1077  if (!extractOp)
1078     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1079   auto loc = extractOp.getLoc();
1080 
1081  // Compute the static loop sizes of the extract op.
1082  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1083  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1084  loc,
1085  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1086  /*value=*/true));
1087  auto passThruConstantOp =
1088  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1089 
1090  // Base indices are currently set to 0. We will need to re-visit if more
1091  // generic scenarios are to be supported.
1092  SmallVector<Value> baseIndices(
1093  extractOp.getIndices().size(),
1094  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1095 
1096  VectorMemoryAccessKind memAccessKind =
1097  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1098 
1099  // 1. Handle gather access
1100  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1101  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1102 
1103  // Generate the gather load
1104  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1105  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1106  maskConstantOp, passThruConstantOp);
1107  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1108 
1109  LDBG("Vectorised as gather load: " << extractOp << "\n");
1110     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1111   }
1112 
1113  // 2. Handle:
1114  // a. scalar loads + broadcast,
1115  // b. contiguous loads.
1116  // Both cases use vector.transfer_read.
1117 
1118  assert(llvm::count_if(resultType.getShape(),
1119  [](uint64_t dim) { return dim != 1; }) &&
1120  "Contiguous loads and scalar loads + broadcast only support 1-D "
1121  "vectors ATM!");
1122 
1123  // Collect indices for `vector.transfer_read`. At this point, the indices will
1124  // either be scalars or would have been broadcast to vectors matching the
1125  // result type. For indices that are vectors, there are two options:
1126  // * for non-trailing indices, all elements are identical (contiguous
1127  // loads are identified by looking for non-trailing indices that are
1128  // invariant with respect to the corresponding linalg.generic), or
1129  // * for trailing indices, the index vector will contain values with stride
1130  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1131  // needed.
1132  // This means that
1133  // * for scalar indices - just re-use it,
1134  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1135  // (0th) element and use that.
1136  SmallVector<Value> transferReadIdxs;
1137  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1138  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1139  if (idx.getType().isIndex()) {
1140  transferReadIdxs.push_back(idx);
1141  continue;
1142  }
1143 
1144  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1145  loc,
1146  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1147  resultType.getScalableDims().back()),
1148  idx);
1149  transferReadIdxs.push_back(
1150  rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1151  }
1152 
1153  // `tensor.extract_element` is always in-bounds, hence the following holds.
1154  auto dstRank = resultType.getRank();
1155  auto srcRank = extractOp.getTensor().getType().getRank();
1156  SmallVector<bool> inBounds(dstRank, true);
1157 
1158  // 2a. Handle scalar broadcast access.
1159  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1160  MLIRContext *ctx = rewriter.getContext();
1161  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1162  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1163 
1164  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1165  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1166  permutationMap, inBounds);
1167 
1168  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1169  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1170  }
1171 
1172  // 2b. Handle contiguous access.
1173  auto permutationMap = AffineMap::getMinorIdentityMap(
1174  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1175 
1176  int32_t rankDiff = dstRank - srcRank;
1177  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1178  // dims so that the ranks match. This is done by extending the map with 0s.
1179  // For example, for dstRank = 3, srcRank = 2, the following map created
1180  // above:
1181  // (d0, d1) --> (d0, d1)
1182  // is extended as:
1183  // (d0, d1) --> (0, d0, d1)
1184  while (rankDiff > 0) {
1185  permutationMap = permutationMap.insertResult(
1186  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1187  rankDiff--;
1188  }
1189 
1190  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1191  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1192  inBounds);
1193 
1194  LDBG("Vectorised as contiguous load: " << extractOp);
1195  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1196 }
1197 
1198 /// Emit reduction operations if the shape of the value to reduce differs from
1199 /// the result shape.
1200 // Note: this is a true builder that notifies the OpBuilder listener.
1201 // TODO: Consider moving as a static helper on the ReduceOp.
1202 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1203  Value reduceValue, Value initialValue,
1204  const IRMapping &bvm) {
1205  Value reduceVec = bvm.lookup(reduceValue);
1206  Value outputVec = bvm.lookup(initialValue);
1207  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1208  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1209   // Reduce only if needed as the value may already have been reduced for
1210  // contraction vectorization.
1211  if (!reduceType ||
1212  (outputType && reduceType.getShape() == outputType.getShape()))
1213  return nullptr;
1214  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1215  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1216 }
1217 
1218 /// Generic vectorization for a single operation `op`, given already vectorized
1219 /// operands carried by `bvm`. Vectorization occurs as follows:
1220 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1221 /// result on success.
1222 /// 2. Clone any constant in the current scope without vectorization: each
1223 /// consumer of the constant will later determine the shape to which the
1224 /// constant needs to be broadcast.
1225 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1226 /// of the `customVectorizationHooks` to cover such cases.
1227 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1228 /// operand of maximal rank. Other operands have smaller rank and are
1229 /// broadcast accordingly. It is assumed this broadcast is always legal,
1230 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1231 ///
1232 /// This function assumes all operands of `op` have been vectorized and are in
1233 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1234 /// a topologically-sorted list of ops.
1235 /// This function does not update `bvm` but returns a VectorizationStatus that
1236 /// instructs the caller what `bvm` update needs to occur.
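/// As an illustrative example, `%y = arith.addf %a, %b : f32` inside the body
/// is rewritten to `%vy = arith.addf %va, %vb : vector<4x8xf32>` once both
/// operands have been vectorized and recorded in `bvm`.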
1237 static VectorizationResult
1238 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1239                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1240  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1241  LDBG("vectorize op " << *op << "\n");
1242 
1243  // 1. Try to apply any CustomVectorizationHook.
1244  if (!customVectorizationHooks.empty()) {
1245  for (auto &customFunc : customVectorizationHooks) {
1246  VectorizationResult result = customFunc(op, bvm);
1247  if (result.status == VectorizationStatus::Failure)
1248  continue;
1249  return result;
1250  }
1251  }
1252 
1253  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1254   // Clone so that the constant is not confined to the linalgOp block.
1255  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1256  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1257 
1258  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1259   if (!OpTrait::hasElementwiseMappableTraits(op))
1260     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1261 
1262  // 4 . Check if the operation is a reduction.
1263  SmallVector<std::pair<Value, Value>> reductionOperands;
1264  for (Value operand : op->getOperands()) {
1265  auto blockArg = dyn_cast<BlockArgument>(operand);
1266  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1267  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1268  continue;
1269  SmallVector<Operation *> reductionOps;
1270  Value reduceValue = matchReduction(
1271  linalgOp.getRegionOutputArgs(),
1272  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1273  if (!reduceValue)
1274  continue;
1275  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1276  }
1277  if (!reductionOperands.empty()) {
1278  assert(reductionOperands.size() == 1);
1279  Operation *reduceOp =
1280  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1281  reductionOperands[0].second, bvm);
1282  if (reduceOp)
1283       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1284   }
1285 
1286  // 5. Generic vectorization path for ElementwiseMappable ops.
1287  // a. Get the first max ranked shape.
1288  VectorType firstMaxRankedType;
1289  for (Value operand : op->getOperands()) {
1290  auto vecOperand = bvm.lookup(operand);
1291  assert(vecOperand && "Vector operand couldn't be found");
1292 
1293  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1294  if (vecType && (!firstMaxRankedType ||
1295  firstMaxRankedType.getRank() < vecType.getRank()))
1296  firstMaxRankedType = vecType;
1297  }
1298  // b. Broadcast each op if needed.
1299  SmallVector<Value> vecOperands;
1300  for (Value scalarOperand : op->getOperands()) {
1301  Value vecOperand = bvm.lookup(scalarOperand);
1302  assert(vecOperand && "Vector operand couldn't be found");
1303 
1304  if (firstMaxRankedType) {
1305  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1306  getElementTypeOrSelf(vecOperand.getType()),
1307  firstMaxRankedType.getScalableDims());
1308  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1309  } else {
1310  vecOperands.push_back(vecOperand);
1311  }
1312  }
1313  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1314  SmallVector<Type> resultTypes;
1315  for (Type resultType : op->getResultTypes()) {
1316  resultTypes.push_back(
1317  firstMaxRankedType
1318  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1319  firstMaxRankedType.getScalableDims())
1320  : resultType);
1321  }
1322  // d. Build and return the new op.
1323  return VectorizationResult{
1324       VectorizationStatus::NewOp,
1325       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1326  resultTypes, op->getAttrs())};
1327 }
1328 
1329 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1330 /// vector form. Generic vectorization proceeds as follows:
1331 /// 1. Verify the `linalgOp` has one non-empty region.
1332 /// 2. Values defined above the region are mapped to themselves and will be
1333 /// broadcasted on a per-need basis by their consumers.
1334 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1335 /// load).
1336 /// TODO: Reuse opportunities for RAR dependencies.
1337 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1338 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1339 /// iteration indices.
1340 /// 5. Iteratively call vectorizeOneOp on the region operations.
1341 ///
1342 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1343 /// performed to the maximal common vector size implied by the `linalgOp`
1344 /// iteration space. This eager broadcasting is introduced in the
1345 /// permutation_map of the vector.transfer_read operations. The eager
1346 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1347 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1348 /// the absence of good canonicalizations, the amount of work increases.
1349 /// This is not deemed a problem as we expect canonicalizations and foldings to
1350 /// aggressively clean up the useless work.
1351 static LogicalResult
1352 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1353                          LinalgOp linalgOp,
1354  SmallVectorImpl<Value> &newResults) {
1355  LDBG("Vectorizing operation as linalg generic\n");
1356  Block *block = linalgOp.getBlock();
1357 
1358  // 2. Values defined above the region can only be broadcast for now. Make them
1359  // map to themselves.
1360  IRMapping bvm;
1361  SetVector<Value> valuesSet;
1362  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1363  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1364 
1365  if (linalgOp.getNumDpsInits() == 0)
1366  return failure();
1367 
1368  // 3. Turn all BBArgs into vector.transfer_read / load.
1369  Location loc = linalgOp.getLoc();
1370  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1371  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1372  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1373  if (linalgOp.isScalar(opOperand)) {
1374  bvm.map(bbarg, opOperand->get());
1375  continue;
1376  }
1377 
1378  // 3.a. Convert the indexing map for this input/output to a transfer read
1379  // permutation map and masking map.
1380  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1381 
1382  AffineMap readMap;
1383  VectorType readType;
1384  Type elemType = getElementTypeOrSelf(opOperand->get());
1385  if (linalgOp.isDpsInput(opOperand)) {
1386  // 3.a.i. For input reads we use the canonical vector shape.
1387  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1388  readType = state.getCanonicalVecType(elemType);
1389  } else {
1390  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1391  // reductions), the vector shape is computed by mapping the canonical
1392  // vector shape to the output domain and back to the canonical domain.
1393  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1394  readType =
1395  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1396  }
1397 
1398  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1399 
1400  Operation *read = rewriter.create<vector::TransferReadOp>(
1401  loc, readType, opOperand->get(), indices, readMap);
1402  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1403  Value readValue = read->getResult(0);
1404 
1405  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1406  // will be in-bounds.
1407  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1408  SmallVector<bool> inBounds(readType.getRank(), true);
1409  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1410  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1411  }
1412 
1413  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1414  // TODO: remove this.
1415  if (readType.getRank() == 0)
1416  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1417  ArrayRef<int64_t>());
1418 
1419  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1420  << "\n");
1421  bvm.map(bbarg, readValue);
1422  bvm.map(opOperand->get(), readValue);
1423  }
1424 
1425   SmallVector<CustomVectorizationHook> hooks;
1426   // 4a. Register CustomVectorizationHook for yieldOp.
1427  CustomVectorizationHook vectorizeYield =
1428  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1429  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1430  };
1431  hooks.push_back(vectorizeYield);
1432 
1433  // 4b. Register CustomVectorizationHook for indexOp.
1434  CustomVectorizationHook vectorizeIndex =
1435  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1436  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1437  };
1438  hooks.push_back(vectorizeIndex);
1439 
1440  // 4c. Register CustomVectorizationHook for extractOp.
1441  CustomVectorizationHook vectorizeExtract =
1442  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1443  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1444  };
1445  hooks.push_back(vectorizeExtract);
1446 
1447  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1448  for (Operation &op : block->getOperations()) {
1449  VectorizationResult result =
1450  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1451  if (result.status == VectorizationStatus::Failure) {
1452  LDBG("failed to vectorize: " << op << "\n");
1453  return failure();
1454  }
1455  if (result.status == VectorizationStatus::NewOp) {
1456  Operation *maybeMaskedOp =
1457  state.maskOperation(rewriter, result.newOp, linalgOp);
1458  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1459  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1460  }
1461  }
1462 
1463  return success();
1464 }
1465 
1466 /// Given a tensor::PackOp, return the `dest` shape before any packing
1467 /// permutations.
1468 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1469  ArrayRef<int64_t> destShape) {
1470  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1471 }
1472 
1473 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1474 /// create an empty destination tensor and a TransferWriteOp that writes the
1475 /// input into it. If the destination shape is not the same as the
1476 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1477 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1478 /// inBounds attribute of the transfer write op instead of masking.
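///
/// Illustrative sketch (hypothetical shapes; the exact IR depends on the
/// caller): with inputVectorSizes = [8] and a destination of dynamic size %d,
/// the created IR is roughly:
///
///   %empty = tensor.empty(%d) : tensor<?xf32>
///   %mask = vector.create_mask %d : vector<8xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0]
///       {in_bounds = [true]} : vector<8xf32>, tensor<?xf32>
///   } : vector<8xi1> -> tensor<?xf32>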
1479 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1480                                            Value input,
1481  SmallVector<OpFoldResult> destSizes,
1482  ArrayRef<int64_t> inputVectorSizes,
1483  bool useInBoundsInsteadOfMasking) {
1484 
1485  auto inputType = cast<VectorType>(input.getType());
1486  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1487  inputType.getElementType());
1488  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1489  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1490  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1491  SmallVector<bool> inBoundsVal(rank, true);
1492  if (useInBoundsInsteadOfMasking) {
1493  // Update the inBounds attribute.
1494  for (unsigned i = 0; i < rank; i++)
1495  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1496  !ShapedType::isDynamic(destShape[i]);
1497  }
1498  Operation *write = builder.create<vector::TransferWriteOp>(
1499  loc,
1500  /*vector=*/input,
1501  /*source=*/dest,
1502  /*indices=*/SmallVector<Value>(rank, zero),
1503  /*inBounds=*/inBoundsVal);
1504  assert(llvm::none_of(
1505  destShape.drop_front(inputVectorSizes.size()),
1506  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1507  "Only dims aligned with inputVectorSizes may be dynamic");
1508  if (useInBoundsInsteadOfMasking)
1509  return write;
1510  bool needMaskForWrite = !llvm::equal(
1511  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1512  if (needMaskForWrite) {
1513  SmallVector<int64_t> writeMaskShape;
1514  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1515  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1516  destShape.end());
1517  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1518  Value maskForWrite =
1519  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1520  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1521  }
1522  return write;
1523 }
1524 
1525 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1526 /// padding value and (3) input vector sizes into:
1527 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1528 /// As in the following example:
1529 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1530 ///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1531 ///
1532 /// This pack would be vectorized to:
1533 ///
1534 /// %load = vector.mask %mask {
1535 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1536 /// {in_bounds = [true, true, true]} :
1537 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1538 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1539 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1540 /// to vector<32x4x2x1x16xf32>
1541 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1542 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1543 /// %write = vector.transfer_write %transpose,
1544 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1545 /// {in_bounds = [true, true, true, true, true]}
1546 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1547 ///
1548 /// If the (3) input vector sizes are not provided, the vector sizes are
1549 /// determined by the result tensor shape. Also, we update the inBounds
1550 /// attribute instead of masking.
1551 static LogicalResult
1552 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1553  ArrayRef<int64_t> inputVectorSizes,
1554  SmallVectorImpl<Value> &newResults) {
1555  OpBuilder::InsertionGuard g(rewriter);
1556  rewriter.setInsertionPoint(packOp);
1557 
1558  Location loc = packOp.getLoc();
1559  auto padValue = packOp.getPaddingValue();
1560  if (!padValue) {
1561  padValue = rewriter.create<arith::ConstantOp>(
1562  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1563  }
1564  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1565  LogicalResult status =
1566  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1567  .reifyResultShapes(rewriter, reifiedReturnShapes);
1568  (void)status; // prevent unused variable warning on non-assert builds.
1569  assert(succeeded(status) && "failed to reify result shapes");
1570 
1571   // If the input vector sizes are not provided, then the vector sizes are
1572   // determined by the result tensor shape. In that case, we also update the
1573   // inBounds attribute instead of masking.
1574  bool useInBoundsInsteadOfMasking = false;
1575  if (inputVectorSizes.empty()) {
1576  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1577  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1578  useInBoundsInsteadOfMasking = true;
1579  }
1580 
1581  // Create masked TransferReadOp.
1582  SmallVector<int64_t> inputShape(inputVectorSizes);
1583  auto innerTiles = packOp.getStaticInnerTiles();
1584  auto innerDimsPos = packOp.getInnerDimsPos();
1585  auto outerDimsPerm = packOp.getOuterDimsPerm();
1586  if (!outerDimsPerm.empty())
1587  applyPermutationToVector(inputShape,
1588  invertPermutationVector(outerDimsPerm));
1589  for (auto [idx, size] : enumerate(innerTiles))
1590  inputShape[innerDimsPos[idx]] *= size;
1591  auto maskedRead = vector::createReadOrMaskedRead(
1592  rewriter, loc, packOp.getSource(), inputShape, padValue,
1593  useInBoundsInsteadOfMasking);
1594 
1595  // Create ShapeCastOp.
1596  SmallVector<int64_t> destShape(inputVectorSizes);
1597  destShape.append(innerTiles.begin(), innerTiles.end());
1598  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1599  packOp.getDestType().getElementType());
1600  auto shapeCastOp =
1601  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1602 
1603  // Create TransposeOp.
1604  auto destPermutation =
1605       invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1606   auto transposeOp = rewriter.create<vector::TransposeOp>(
1607  loc, shapeCastOp.getResult(), destPermutation);
1608 
1609  // Create TransferWriteOp.
1610   Operation *write = createWriteOrMaskedWrite(
1611       rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1612  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1613  newResults.push_back(write->getResult(0));
1614  return success();
1615 }
1616 
1617 /// Vectorize a `tensor::UnPackOp` to these 4 Ops:
1618 ///   vector::TransferReadOp - Reads a vector from the source tensor.
1619 ///   vector::TransposeOp - Transposes the read vector.
1620 ///   vector::ShapeCastOp - Reshapes the data based on the target shape.
1621 ///   vector::TransferWriteOp - Writes the result vector back to the
1622 ///   destination tensor.
1623 /// If the vector sizes are not provided:
1624 /// * the vector sizes are determined by the input operand and attributes,
1625 /// * update the inBounds attribute instead of masking.
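///
/// Illustrative sketch (hypothetical static shapes; the exact IR may differ):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1]
///       inner_tiles = [32, 16] into %dst
///       : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
///
/// is rewritten, roughly, to:
///
///   %read = vector.transfer_read %src[...]
///             : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///             : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
///   %sc = vector.shape_cast %tr
///             : vector<8x32x8x16xf32> to vector<256x128xf32>
///   %write = vector.transfer_write %sc, %dst[...]
///             : vector<256x128xf32>, tensor<256x128xf32>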
1626 static LogicalResult
1627 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1628  ArrayRef<int64_t> inputVectorSizes,
1629  SmallVectorImpl<Value> &newResults) {
1630 
1631  OpBuilder::InsertionGuard g(rewriter);
1632  rewriter.setInsertionPoint(unpackOp);
1633 
1634  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1635 
1636  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1637  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1638  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1639  bool useInBoundsInsteadOfMasking = false;
1640  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1641 
1642  auto destSize = unpackOp.getDestRank();
1643 
1644  if (!inputVectorSizes.empty())
1645  assert(inputVectorSizes.size() == destSize &&
1646  "Incorrect number of input vector sizes");
1647 
1648   // vectorSizes is the shape of the vector that will be used for the final
1649   // write on the destination tensor. It is computed as follows. Let the
1650   // source tensor have rank 'M' and the dest tensor rank 'N', where N <= M.
1651   // Then:
1652   // 1. vectorSizes = sourceShape.take_front(N)
1653   // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1654   // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
1655   //    the corresponding innerTiles attribute value.
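  // Worked example (hypothetical shapes): with sourceShape = [8, 8, 32, 16],
  // dest rank N = 2, outer_dims_perm = [1, 0], inner_dims_pos = [0, 1] and
  // inner_tiles = [32, 16]:
  //   vectorSizes (step 1)             = [8, 8]
  //   after outer_dims_perm (step 2)   = [8, 8]
  //   after scaling by inner_tiles (3) = [8 * 32, 8 * 16] = [256, 128]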
1656  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1657  if (vectorSizes.empty()) {
1658  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1659  if (!outerDimsPerm.empty())
1660  applyPermutationToVector(vectorSizes, outerDimsPerm);
1661  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1662  vectorSizes[pos] *= innerTiles[i];
1663 
1664  useInBoundsInsteadOfMasking = true;
1665  }
1666 
1667   // readVectorSizes is the shape of the vector used to read and to apply the
1668   // mask. It is computed as follows. Let the vectorSizes (VS) array have size
1669   // 'N', the sourceShape (SS) have size 'M' with M >= N, and the inner tile
1670   // sizes (IT) have size M-N.
1671   // Then:
1672   // - initially: readVectorSizes = vectorSizes
1673   // - Divide all the readVectorSizes locations pointed to by innerDimPos
1674   //   by the corresponding innerTiles attribute value.
1675   // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1676   // - Append the remaining shape from SS.
1677   // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1678   // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are [512,
1679   // 128] and outer_dims_perm is [1, 0]; then the read shape is:
1680   // readVectorSizes(initial): [512, 128]
1681   // After the innerDimPos adjustment: [512/32, 128/16]
1682   // = [16, 8]
1683   // After applying outer_dims_perm: [8, 16]
1684   // After appending the rest of the sourceShape: [8, 16, 32, 16]
1685 
1686  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1687 
1688  for (auto [index, size] : enumerate(innerTiles)) {
1689  readVectorSizes[innerDimPos[index]] =
1690  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1691  }
1692  if (!outerDimsPerm.empty()) {
1693  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1694  }
1695  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1696  sourceShape.end());
1697 
1698  ReifiedRankedShapedTypeDims reifiedRetShapes;
1699  LogicalResult status =
1700  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1701  .reifyResultShapes(rewriter, reifiedRetShapes);
1702  if (status.failed()) {
1703  LDBG("Unable to reify result shapes of " << unpackOp);
1704  return failure();
1705  }
1706  Location loc = unpackOp->getLoc();
1707 
1708  auto padValue = rewriter.create<arith::ConstantOp>(
1709  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1710 
1711   // Read the result, masking if necessary. A mask is needed whenever the
1712   // transfer_read shape does not match the source shape.
1713   Value readResult = vector::createReadOrMaskedRead(
1714       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1715  /*useInBoundsInsteadOfMasking=*/false);
1716 
1717  PackingMetadata packMetadata;
1718  SmallVector<int64_t> lastDimToInsertPosPerm =
1719  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1720  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1721  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1722  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1723  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1724  RankedTensorType stripMineTensorType =
1725  RankedTensorType::get(stripMineShape, stripMineElemType);
1726  // Transpose the appropriate rows to match output.
1727  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1728  loc, readResult, lastDimToInsertPosPerm);
1729 
1730  // Collapse the vector to the size required by result.
1731  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1732  stripMineTensorType, packMetadata.reassociations);
1733  mlir::VectorType vecCollapsedType =
1734  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1735  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1736  loc, vecCollapsedType, transposeOp->getResult(0));
1737 
1738   // For dynamic sizes, writeVectorSizes must match the shape_cast result
1739   // shape; otherwise the verifier rejects the mask size as invalid.
1740  SmallVector<int64_t> writeVectorSizes(
1741  unpackOp.getDestType().hasStaticShape()
1742  ? vectorSizes
1743  : shapeCastOp.getResultVectorType().getShape());
1744   Operation *write = createWriteOrMaskedWrite(
1745       rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1746  writeVectorSizes, useInBoundsInsteadOfMasking);
1747  newResults.push_back(write->getResult(0));
1748  return success();
1749 }
1750 
1751 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1752 /// and (3) all-zero lowPad to
1753 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
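/// Illustrative sketch (hypothetical shapes; the exact IR may differ): with
/// inputVectorSizes = [2, 4],
///
///   %pad = tensor.pad %src low[0, 0] high[%h0, %h1] {
///          ^bb0(%i: index, %j: index):
///            tensor.yield %cst : f32
///          } : tensor<?x?xf32> to tensor<2x4xf32>
///
/// is rewritten, roughly, to:
///
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<?x?xf32>, vector<2x4xf32>
///   } : vector<2x4xi1> -> vector<2x4xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>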
1754 static LogicalResult
1755 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1756  ArrayRef<int64_t> inputVectorSizes,
1757  SmallVectorImpl<Value> &newResults) {
1758  auto padValue = padOp.getConstantPaddingValue();
1759  Location loc = padOp.getLoc();
1760 
1761  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1762  OpBuilder::InsertionGuard g(rewriter);
1763  rewriter.setInsertionPoint(padOp);
1764 
1765  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1766  LogicalResult status =
1767  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1768  .reifyResultShapes(rewriter, reifiedReturnShapes);
1769  (void)status; // prevent unused variable warning on non-assert builds
1770  assert(succeeded(status) && "failed to reify result shapes");
1771  auto maskedRead = vector::createReadOrMaskedRead(
1772  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1773  /*useInBoundsInsteadOfMasking=*/false);
1774   Operation *write = createWriteOrMaskedWrite(
1775       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1776  /*useInBoundsInsteadOfMasking=*/false);
1777  newResults.push_back(write->getResult(0));
1778  return success();
1779 }
1780 
1781 // TODO: probably need some extra checks for reduction followed by consumer
1782 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1783 static LogicalResult reductionPreconditions(LinalgOp op) {
1784  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1785  LDBG("reduction precondition failed: no reduction iterator\n");
1786  return failure();
1787  }
1788  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1789  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1790  if (indexingMap.isPermutation())
1791  continue;
1792 
1793  Operation *reduceOp = matchLinalgReduction(&opOperand);
1794  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1795  LDBG("reduction precondition failed: reduction detection failed\n");
1796  return failure();
1797  }
1798  }
1799  return success();
1800 }
1801 
1802 static LogicalResult
1803 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1804                                    bool flatten1DDepthwiseConv) {
1805  if (flatten1DDepthwiseConv) {
1806  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1807  "supported\n");
1808  return failure();
1809  }
1810 
1811  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1812  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1813  return failure();
1814  }
1815 
1816  // Support dynamic shapes in 1D depthwise convolution, but only in the
1817  // _channel_ dimension.
1818  Value lhs = conv.getDpsInputOperand(0)->get();
1819  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1820  auto shapeWithoutCh = lhsShape.drop_back(1);
1821  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1822  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1823  "channel dim can be dynamic\n");
1824  return failure();
1825  }
1826 
1827  return success();
1828 }
1829 
1830 static LogicalResult
1831 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1832                                      bool flatten1DDepthwiseConv) {
1833  if (isa<ConvolutionOpInterface>(op.getOperation()))
1834  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1835 
1836  if (hasReductionIterator(op))
1837  return reductionPreconditions(op);
1838 
1839  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1840  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1841  if (!isElementwise(op) &&
1842  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1843  op.getOperation()))
1844  return failure();
1845 
1846  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1847  return success();
1848 }
1849 
1850 /// Need to check if the inner-tiles are static/constant.
1851 static LogicalResult
1852 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1853  ArrayRef<int64_t> inputVectorSizes) {
1854 
1855  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1856  return !getConstantIntValue(res).has_value();
1857  })) {
1858  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1859  return failure();
1860  }
1861  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1862  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1863  unpackOp.getDestType().hasStaticShape() &&
1864  unpackOp.getSourceType().hasStaticShape();
1865  if (!satisfyEmptyCond &&
1866  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1867  return failure();
1868 
1869  return success();
1870 }
1871 
1872 static LogicalResult vectorizeLinalgOpPrecondition(
1873  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1874  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1875   // Tensors with a dimension of size 0 cannot be vectorized.
1876  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1877  return failure();
1878  // Check API contract for input vector sizes.
1879  if (!inputVectorSizes.empty() &&
1880  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1881  inputVectorSizes)))
1882  return failure();
1883 
1884  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1885  linalgOp, flatten1DDepthwiseConv))) {
1886  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1887  return failure();
1888  }
1889 
1890   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1891 
1892  // Register CustomVectorizationPrecondition for extractOp.
1893  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1894 
1895  // All types in the body should be a supported element type for VectorType.
1896  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1897  // Check if any custom hook can vectorize the inner op.
1898  if (llvm::any_of(
1899  customPreconditions,
1900  [&](const CustomVectorizationPrecondition &customPrecondition) {
1901  return succeeded(
1902  customPrecondition(&innerOp, vectorizeNDExtract));
1903  })) {
1904  continue;
1905  }
1906  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1907  return !VectorType::isValidElementType(type);
1908  })) {
1909  return failure();
1910  }
1911  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1912  return !VectorType::isValidElementType(type);
1913  })) {
1914  return failure();
1915  }
1916  }
1917  if (isElementwise(linalgOp))
1918  return success();
1919 
1920  // TODO: isaConvolutionOpInterface that can also infer from generic
1921  // features. But we will still need stride/dilation attributes that will be
1922  // annoying to reverse-engineer...
1923  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1924  return success();
1925  // TODO: the common vector shape is equal to the static loop sizes only when
1926  // all indexing maps are projected permutations. For convs and stencils the
1927  // logic will need to evolve.
1928  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1929  LDBG("precondition failed: not projected permutations\n");
1930  return failure();
1931  }
1932  if (failed(reductionPreconditions(linalgOp))) {
1933  LDBG("precondition failed: reduction preconditions\n");
1934  return failure();
1935  }
1936  return success();
1937 }
1938 
1939 static LogicalResult
1940 vectorizePackOpPrecondition(tensor::PackOp packOp,
1941  ArrayRef<int64_t> inputVectorSizes) {
1942  auto padValue = packOp.getPaddingValue();
1943  Attribute cstAttr;
1944  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1945  LDBG("pad value is not constant: " << packOp << "\n");
1946  return failure();
1947  }
1948  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1949  bool satisfyEmptyCond = true;
1950  if (inputVectorSizes.empty()) {
1951  if (!packOp.getDestType().hasStaticShape() ||
1952  !packOp.getSourceType().hasStaticShape())
1953  satisfyEmptyCond = false;
1954  }
1955 
1956  if (!satisfyEmptyCond &&
1957       failed(vector::isValidMaskedInputVector(
1958           resultTensorShape.take_front(packOp.getSourceRank()),
1959  inputVectorSizes)))
1960  return failure();
1961 
1962  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1963  return !getConstantIntValue(v).has_value();
1964  })) {
1965  LDBG("inner_tiles must be constant: " << packOp << "\n");
1966  return failure();
1967  }
1968 
1969  return success();
1970 }
1971 
1972 static LogicalResult
1973 vectorizePadOpPrecondition(tensor::PadOp padOp,
1974  ArrayRef<int64_t> inputVectorSizes) {
1975  auto padValue = padOp.getConstantPaddingValue();
1976  if (!padValue) {
1977  LDBG("pad value is not constant: " << padOp << "\n");
1978  return failure();
1979  }
1980 
1981  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1982  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1983  inputVectorSizes)))
1984  return failure();
1985 
1986  if (llvm::any_of(padOp.getLow(), [](Value v) {
1987  std::optional<int64_t> res = getConstantIntValue(v);
1988  return !res.has_value() || res.value() != 0;
1989  })) {
1990  LDBG("low pad must all be zero: " << padOp << "\n");
1991  return failure();
1992  }
1993 
1994  return success();
1995 }
1996 
1997 /// Preconditions for scalable vectors. This is quite restrictive - it models
1998 /// the fact that in practice we would only make selected dimensions scalable.
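///
/// E.g., for a matmul-like op with iterators = [parallel, parallel, reduction],
/// illustrative configurations of `inputScalableVecDims` are:
///   * [false, true, false] - scalable trailing parallel dim: supported
///   * [true,  true, false] - scalable last two parallel dims: supported
///   * [false, true, true ] - scalable parallel and reduction dims: rejected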
1999 static LogicalResult
2000 vectorizeScalableVectorPrecondition(Operation *op,
2001                                     ArrayRef<int64_t> inputVectorSizes,
2002  ArrayRef<bool> inputScalableVecDims) {
2003  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2004  "Number of input vector sizes and scalable dims doesn't match");
2005 
2006  size_t numOfScalableDims =
2007  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2008 
2009  if (numOfScalableDims == 0)
2010  return success();
2011 
2012  auto linalgOp = dyn_cast<LinalgOp>(op);
2013 
2014   // Cond 1: There's been no need for scalable vectorization of
2015   // non-linalg Ops so far.
2016  if (!linalgOp)
2017  return failure();
2018 
2019  // Cond 2: There's been no need for more than 2 scalable dims so far
2020  if (numOfScalableDims > 2)
2021  return failure();
2022 
2023  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2024  // it matches one of the supported cases:
2025  // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
2026  // 2. exactly 2 dims are scalable and those are the _last two adjacent_
2027  // parallel dims
2028  // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
2029  // The 2nd restriction above means that only Matmul-like Ops are supported
2030  // when 2 dims are scalable, e.g. :
2031  // * iterators = [parallel, parallel, reduction]
2032  // * scalable flags = [true, true, false]
2033 
2034  // Find the first scalable flag
2035   bool seenParallel = false;
2036  auto iterators = linalgOp.getIteratorTypesArray();
2037  SmallVector<bool> scalableFlags(inputScalableVecDims);
2038  while (!scalableFlags.back()) {
2039     seenParallel |= (iterators.back() == utils::IteratorType::parallel);
2040 
2041  iterators.pop_back();
2042  scalableFlags.pop_back();
2043  }
2044 
2045  switch (iterators.back()) {
2046  case utils::IteratorType::reduction: {
2047  // Check 3. above is met.
2048  if (iterators.size() != inputVectorSizes.size()) {
2049  LDBG("Non-trailing reduction dim requested for scalable "
2050  "vectorization\n");
2051  return failure();
2052  }
2053  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2054  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2055  "is not supported\n");
2056  return failure();
2057  }
2058  break;
2059  }
2060  case utils::IteratorType::parallel: {
2061  // Check 1. and 2. above are met.
2062     if (seenParallel) {
2063  LDBG("Inner parallel dim not requested for scalable "
2064  "vectorization\n");
2065  return failure();
2066  }
2067  break;
2068  }
2069  }
2070 
2071   // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2072   // supported, for which we expect the following config:
2073  // * iterators = [parallel, parallel, reduction]
2074  // * scalable flags = [true, true, false]
2075  if (numOfScalableDims == 2) {
2076     // Disallow the case below, which breaks 3. above:
2077  // * iterators = [..., parallel, reduction]
2078  // * scalable flags = [..., true, true]
2079  if (iterators.back() == utils::IteratorType::reduction) {
2080  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2081  "vectorization\n");
2082  return failure();
2083  }
2084  scalableFlags.pop_back();
2085  iterators.pop_back();
2086 
2087  if (!scalableFlags.back() ||
2088  (iterators.back() != utils::IteratorType::parallel))
2089  return failure();
2090  }
2091 
2092   // Bail out on matmul ops with extended semantics (user-defined indexing
2093   // maps); they are not supported by this transform.
2094  if (linalgOp.hasUserDefinedMaps())
2095  return failure();
2096 
2097  // Cond 4: Only the following ops are supported in the
2098  // presence of scalable vectors
2099  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2100  isa<linalg::MatmulTransposeAOp>(op) ||
2101  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2102  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2103 }
2104 
2105 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2106     Operation *op, ArrayRef<int64_t> inputVectorSizes,
2107  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2108  bool flatten1DDepthwiseConv) {
2109 
2110  if (!hasVectorizationImpl(op))
2111  return failure();
2112 
2113  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2114  inputScalableVecDims)))
2115  return failure();
2116 
2117   return TypeSwitch<Operation *, LogicalResult>(op)
2118       .Case<linalg::LinalgOp>([&](auto linalgOp) {
2119  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2120  vectorizeNDExtract,
2121  flatten1DDepthwiseConv);
2122  })
2123  .Case<tensor::PadOp>([&](auto padOp) {
2124  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2125  })
2126  .Case<tensor::PackOp>([&](auto packOp) {
2127  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2128  })
2129  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2130  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2131  })
2132  .Default([](auto) { return failure(); });
2133 }
2134 
2135 /// Converts affine.apply Ops to arithmetic operations.
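/// E.g. (illustrative; assuming a single-result map):
///   %r = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%s]
/// is expanded to:
///   %r = arith.addi %i, %s : index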
2136 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2137  OpBuilder::InsertionGuard g(rewriter);
2138  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2139 
2140  for (auto op : make_early_inc_range(toReplace)) {
2141  rewriter.setInsertionPoint(op);
2142  auto expanded = affine::expandAffineExpr(
2143  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2144  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2145  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2146  rewriter.replaceOp(op, expanded);
2147  }
2148 }
2149 
2150 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2151   return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2152  op);
2153 }
2154 
2155 /// Emit a suitable vector form for an operation. If provided,
2156 /// `inputVectorSizes` are used to vectorize this operation.
2157 /// `inputVectorSizes` must match the rank of the iteration space of the
2158 /// operation and the input vector sizes must be greater than or equal to
2159 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2160 /// also allows the vectorization of operations with dynamic shapes.
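///
/// A minimal usage sketch (the surrounding pattern / rewriter set-up is
/// assumed):
///
///   // Vectorize `linalgOp` with vector sizes [8, 16], making the trailing
///   // dim scalable.
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, true},
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();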
2161 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2162  ArrayRef<int64_t> inputVectorSizes,
2163  ArrayRef<bool> inputScalableVecDims,
2164  bool vectorizeNDExtract,
2165  bool flatten1DDepthwiseConv) {
2166  LDBG("Attempting to vectorize:\n" << *op << "\n");
2167  LDBG("Input vector sizes: ");
2168  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2169  LLVM_DEBUG(llvm::dbgs() << "\n");
2170  LDBG("Input scalable vector dims: ");
2171  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2172  LLVM_DEBUG(llvm::dbgs() << "\n");
2173 
2174  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2175  vectorizeNDExtract,
2176  flatten1DDepthwiseConv))) {
2177  LDBG("Vectorization pre-conditions failed\n");
2178  return failure();
2179  }
2180 
2181  // Initialize vectorization state.
2182  VectorizationState state(rewriter);
2183  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2184  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2185  inputScalableVecDims))) {
2186  LDBG("Vectorization state couldn't be initialized\n");
2187  return failure();
2188  }
2189  }
2190 
2191  SmallVector<Value> results;
2192  auto vectorizeResult =
2193       TypeSwitch<Operation *, LogicalResult>(op)
2194           .Case<linalg::LinalgOp>([&](auto linalgOp) {
2195  // TODO: isaConvolutionOpInterface that can also infer from
2196  // generic features. Will require stride/dilation attributes
2197  // inference.
2198  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2199  FailureOr<Operation *> convOr = vectorizeConvolution(
2200  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2201  flatten1DDepthwiseConv);
2202  if (succeeded(convOr)) {
2203  llvm::append_range(results, (*convOr)->getResults());
2204  return success();
2205  }
2206 
2207  LDBG("Unsupported convolution can't be vectorized.\n");
2208  return failure();
2209  }
2210 
2211  LDBG("Vectorize generic by broadcasting to the canonical vector "
2212  "shape\n");
2213 
2214  // Pre-process before proceeding.
2215  convertAffineApply(rewriter, linalgOp);
2216 
2217  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2218  // to 'OpBuilder' when it is passed over to some methods like
2219  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2220  // erase an op within these methods, the actual rewriter won't be
2221  // notified and we will end up with read-after-free issues!
2222  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2223  })
2224  .Case<tensor::PadOp>([&](auto padOp) {
2225  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2226  results);
2227  })
2228  .Case<tensor::PackOp>([&](auto packOp) {
2229  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2230  results);
2231  })
2232  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2233  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2234  inputVectorSizes, results);
2235  })
2236  .Default([](auto) { return failure(); });
2237 
2238  if (failed(vectorizeResult)) {
2239  LDBG("Vectorization failed\n");
2240  return failure();
2241  }
2242 
2243  if (!results.empty())
2244  rewriter.replaceOp(op, results);
2245  else
2246  rewriter.eraseOp(op);
2247 
2248  return success();
2249 }
2250 
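/// Vectorize a static memref.copy as a full-size transfer_read/transfer_write
/// pair. Illustrative sketch (hypothetical shapes; the exact IR may differ):
///
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
///
/// becomes, roughly:
///
///   %v = vector.transfer_read %src[%c0, %c0], %cst
///          : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///          : vector<4x8xf32>, memref<4x8xf32>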
2251 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2252                                           memref::CopyOp copyOp) {
2253  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2254  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2255  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2256  return failure();
2257 
2258  auto srcElementType = getElementTypeOrSelf(srcType);
2259  auto dstElementType = getElementTypeOrSelf(dstType);
2260  if (!VectorType::isValidElementType(srcElementType) ||
2261  !VectorType::isValidElementType(dstElementType))
2262  return failure();
2263 
2264  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2265  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2266 
2267  Location loc = copyOp->getLoc();
2268  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2269  SmallVector<Value> indices(srcType.getRank(), zero);
2270 
2271  Value readValue = rewriter.create<vector::TransferReadOp>(
2272  loc, readType, copyOp.getSource(), indices,
2273  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2274  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2275  readValue =
2276  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2277  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2278  }
2279  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2280  loc, readValue, copyOp.getTarget(), indices,
2281  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2282  rewriter.replaceOp(copyOp, writeValue->getResults());
2283  return success();
2284 }
2285 
2286 //----------------------------------------------------------------------------//
2287 // Misc. vectorization patterns.
2288 //----------------------------------------------------------------------------//
2289 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2290 /// given operation type OpTy.
2291 template <typename OpTy>
2292 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2293   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2294 
2295  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2296  PatternRewriter &rewriter) const final {
2297  bool changed = false;
2298     // Collect the users in a vector first, because some of them may be
2299     // replaced or removed during the rewrite.
2299  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2300  if (auto op = dyn_cast<OpTy>(user))
2301  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2302  return success(changed);
2303  }
2304 
2305 protected:
2306  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2307  tensor::PadOp padOp, OpTy op) const = 0;
2308 };
2309 
2310 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2311 /// ```
2312 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2313 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2314 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2315 /// ```
2316 /// is rewritten to:
2317 /// ```
2318 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2319 /// {in_bounds = [true, true]}
2320 /// : tensor<?x?xf32>, vector<17x5xf32>
2321 /// ```
2322 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2323 /// sure that the original padding value %cst was never used.
2324 ///
2325 /// This rewrite is possible if:
2326 /// - `xferOp` has no out-of-bounds dims or mask.
2327 /// - Low padding is static 0.
2328 /// - Single, scalar padding value.
2329 struct PadOpVectorizationWithTransferReadPattern
2330     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2331   using VectorizePadOpUserPattern<
2332       vector::TransferReadOp>::VectorizePadOpUserPattern;
2333 
2334  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2335  vector::TransferReadOp xferOp) const override {
2336  // Low padding must be static 0.
2337  if (!padOp.hasZeroLowPad())
2338  return failure();
2339  // Pad value must be a constant.
2340  auto padValue = padOp.getConstantPaddingValue();
2341  if (!padValue)
2342  return failure();
2343  // Padding value of existing `xferOp` is unused.
2344  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2345  return failure();
2346 
2347  rewriter.modifyOpInPlace(xferOp, [&]() {
2348  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2349  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2350  rewriter.getBoolArrayAttr(inBounds));
2351  xferOp.getSourceMutable().assign(padOp.getSource());
2352  xferOp.getPaddingMutable().assign(padValue);
2353  });
2354 
2355  return success();
2356  }
2357 };
2358 
2359 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2360 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2361 /// value, where the same amount of padding is immediately removed again after
2362 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2363 /// tensor value and apply out-of-bounds masking. E.g.:
2364 /// ```
2365 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2366 /// : tensor<...> to tensor<?x?xf32>
2367 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2368 /// %2 = vector.transfer_write %vec, %1[...]
2369 /// : vector<17x5xf32>, tensor<17x5xf32>
2370 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2371 /// : tensor<17x5xf32> to tensor<?x?xf32>
2372 /// ```
2373 /// is rewritten to:
2374 /// ```
2375 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2376 /// : tensor<...> to tensor<?x?xf32>
2377 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2378 /// tensor<?x?xf32>
2379 /// ```
2380 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2381 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2382 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2383 /// from %r's old dimensions.
2384 ///
2385 /// This rewrite is possible if:
2386 /// - Low padding is static 0.
2387 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2388 /// ExtractSliceOp trims the same amount of padding that was added
2389 /// beforehand.
2390 /// - Single, scalar padding value.
2391 struct PadOpVectorizationWithTransferWritePattern
2392     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2393   using VectorizePadOpUserPattern<
2394       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2395 
2396  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2397  vector::TransferWriteOp xferOp) const override {
2398  // TODO: support 0-d corner case.
2399  if (xferOp.getTransferRank() == 0)
2400  return failure();
2401 
2402  // Low padding must be static 0.
2403  if (!padOp.hasZeroLowPad())
2404  return failure();
2405  // Pad value must be a constant.
2406  auto padValue = padOp.getConstantPaddingValue();
2407  if (!padValue)
2408  return failure();
2409  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2410  if (!xferOp->hasOneUse())
2411  return failure();
2412  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2413  if (!trimPadding)
2414  return failure();
2415  // Only static zero offsets supported when trimming padding.
2416  if (!trimPadding.hasZeroOffset())
2417  return failure();
2418  // trimPadding must remove the amount of padding that was added earlier.
2419  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2420  return failure();
2421 
2422  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2423  rewriter.setInsertionPoint(xferOp);
2424 
2425  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2426  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2427  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2428  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2429  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2430  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2431 
2432  return success();
2433  }
2434 
2435  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2436  /// i.e., same dimensions.
2437  ///
2438  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2439  /// dimensions, this function tries to infer the (static) tensor size by
2440  /// looking at the defining op and utilizing op-specific knowledge.
2441  ///
2442  /// This is a conservative analysis. In case equal tensor sizes cannot be
2443  /// proven statically, this analysis returns `false` even though the tensor
2444  /// sizes may turn out to be equal at runtime.
2445  bool hasSameTensorSize(Value beforePadding,
2446  tensor::ExtractSliceOp afterTrimming) const {
2447  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2448  // result and CastOp operand.
2449  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2450  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2451  return true;
2452 
2453  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2454  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2455  // Only RankedTensorType supported.
2456  if (!t1 || !t2)
2457  return false;
2458  // Rank of both values must be the same.
2459  if (t1.getRank() != t2.getRank())
2460  return false;
2461 
2462  // All static dimensions must be the same. Mixed cases (e.g., dimension
2463  // static in `t1` but dynamic in `t2`) are not supported.
2464  for (unsigned i = 0; i < t1.getRank(); ++i) {
2465  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2466  return false;
2467  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2468  return false;
2469  }
2470 
2471  // Nothing more to check if all dimensions are static.
2472  if (t1.getNumDynamicDims() == 0)
2473  return true;
2474 
2475  // All dynamic sizes must be the same. The only supported case at the
2476  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2477  // thereof).
2478 
2479  // Apart from CastOp, only ExtractSliceOp is supported.
2480  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2481  if (!beforeSlice)
2482  return false;
2483 
2484  assert(static_cast<size_t>(t1.getRank()) ==
2485  beforeSlice.getMixedSizes().size());
2486  assert(static_cast<size_t>(t2.getRank()) ==
2487  afterTrimming.getMixedSizes().size());
2488 
2489  for (unsigned i = 0; i < t1.getRank(); ++i) {
2490  // Skip static dimensions.
2491  if (!t1.isDynamicDim(i))
2492  continue;
2493  auto size1 = beforeSlice.getMixedSizes()[i];
2494  auto size2 = afterTrimming.getMixedSizes()[i];
2495 
2496  // Case 1: Same value or same constant int.
2497  if (isEqualConstantIntOrValue(size1, size2))
2498  continue;
2499 
2500  // Other cases: Take a deeper look at defining ops of values.
2501  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2502  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2503  if (!v1 || !v2)
2504  return false;
2505 
2506  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2507  // CSE is run.)
2508  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2509  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2510  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2511  minOp1.getOperands() == minOp2.getOperands())
2512  continue;
2513 
2514  // Add additional cases as needed.
2515  }
2516 
2517  // All tests passed.
2518  return true;
2519  }
2520 };
2521 
2522 /// Returns the effective Pad value for the input op, provided it's a scalar.
2523 ///
2524 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2525 /// this Op performs padding, retrieve the padding value provided that it's
2526 /// a scalar and static/fixed for all the padded values. Returns an empty value
2527 /// otherwise.
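///
/// Illustrative example (hypothetical IR): for
///
///   %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<8x8xf32>)
///             -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill[0, 0] [4, 4] [1, 1]
///             : tensor<4x4xf32> into tensor<8x8xf32>
///
/// querying %res returns %pad, recovered by walking the destination's
/// defining linalg.fill.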
2528 static Value getStaticPadVal(Operation *op) {
2529   if (!op)
2530  return {};
2531 
2532  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2533  // being broadcast, provided that it's a scalar.
2534  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2535  auto source = bcast.getSource();
2536  if (llvm::dyn_cast<VectorType>(source.getType()))
2537  return {};
2538 
2539  return source;
2540  }
2541 
2542  // 2. linalg.fill - use the scalar input value that used to fill the output
2543  // tensor.
2544  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2545  return fill.getInputs()[0];
2546  }
2547 
2548   // 3. tensor.generate - can't guarantee that the value is fixed without
2549   // further analysis; bail out.
2550  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2551  return {};
2552  }
2553 
2554   // 4. vector.transfer_write - inspect the input vector that's written from.
2555   // If it contains a single value that has been broadcast (e.g. via
2556   // vector.broadcast), extract it, fail otherwise.
2557  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2558  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2559 
2560  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2561  // than the input tensor, then, provided it's constant, we'll extract the
2562  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2563  // TODO: Clarify the semantics when the input tensor is larger than the
2564  // destination.
2565  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2566  return getStaticPadVal(slice.getDest().getDefiningOp());
2567 
2568  return {};
2569 }
2570 
2571 /// Rewrite tensor.insert_slice as a vector.transfer_read +
2572 /// vector.transfer_write pair. The vector size is inferred from the static
2573 /// dims in the input and output tensors. If a dim is dynamic in both the input
2574 /// and output tensors, bails out.
2575 ///
2576 /// Before:
2577 /// !t_in_type = tensor<1x2x3xf32>
2578 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
2579 /// !v_type = vector<1x2x3xf32>
2580 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2581 /// into !t_out_type
2582 /// After:
2583 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2584 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2585 ///
2586 /// TODO: Support masking
2587 struct InsertSliceVectorizePattern
2588     : public OpRewritePattern<tensor::InsertSliceOp> {
2589   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2590 
2591  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
2592  PatternRewriter &rewriter) const final {
2593  auto sourceType = sliceOp.getSource().getType();
2594  if (!VectorType::isValidElementType(sourceType.getElementType()))
2595  return failure();
2596 
2597  auto resultType = sliceOp.getResultType();
2598 
2599  // 1. Get the pad value.
2600  // TransferReadOp requires a scalar padding value. Note that:
2601  // * for in-bounds access, the value is actually irrelevant.
2602  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2603  // 1. The source shape is static (output vector sizes would be based on
2604  // the source shape and hence all memory accesses would be in-bounds),
2605  // 2. Masking is used (output vector sizes would be user-provided, in which
2606  // case it is assumed that all memory accesses are in-bounds). This
2607  // remains a TODO.
2608  //
2609  // When the value is not known and not needed, use 0. Otherwise, bail out.
2610  Value padValue = getStaticPadVal(sliceOp);
2611  bool isOutOfBoundsRead = !sourceType.hasStaticShape();
2612 
2613  if (!padValue && isOutOfBoundsRead) {
2614  LDBG("Failed to get a pad value for out-of-bounds read access\n");
2615  return failure();
2616  }
2617 
2618  if (!padValue) {
2619  auto elemType = sourceType.getElementType();
2620  padValue = rewriter.create<arith::ConstantOp>(
2621  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2622  }
2623 
2624  // 2. Get the vector shape and in-bounds attributes
2625  SmallVector<int64_t> vecShape;
2626  SmallVector<bool> readInBounds;
2627  SmallVector<bool> writeInBounds;
2628  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2629  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2630  if (!sourceType.isDynamicDim(i)) {
2631  vecShape.push_back(sourceType.getDimSize(i));
2632  // Source shape is statically known: Neither read nor write are
2633  // out-of-bounds.
2634  readInBounds.push_back(true);
2635  writeInBounds.push_back(true);
2636  } else if (!resultType.isDynamicDim(i)) {
2637  // Source shape is not statically known, but result shape is.
2638  // Vectorize with size of result shape. This may be larger than the
2639  // source size.
2640  // FIXME: Using rankDiff implies that the source tensor is inserted at
2641  // the end of the destination tensor. However, that's not required.
2642  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2643  // Read may be out-of-bounds because the result size could be larger
2644  // than the source size.
2645  readInBounds.push_back(false);
2646         // Write will be in-bounds provided that the corresponding write idx is 0.
2647  // To keep this logic simple, conservatively mark as out-of-bounds.
2648  writeInBounds.push_back(false);
2649  } else {
2650  // Neither source nor result dim of padOp is static. Cannot vectorize
2651  // the copy.
2652  // TODO: Add support for masking
2653  return failure();
2654  }
2655  }
2656  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2657 
2658  // 3. Generate TransferReadOp.
2659  SmallVector<Value> readIndices(
2660  vecType.getRank(),
2661  rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2662  auto read = rewriter.create<vector::TransferReadOp>(
2663  sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
2664  ArrayRef<bool>{readInBounds});
2665 
2666  // 4. Generate TransferWriteOp.
2667  auto writeIndices = getValueOrCreateConstantIndexOp(
2668  rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2669 
2670  // 5. Finalize
2671  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2672  sliceOp, read, sliceOp.getDest(), writeIndices,
2673  ArrayRef<bool>{writeInBounds});
2674 
2675  return success();
2676  }
2677 };
2678 
2679 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2680 /// ```
2681 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2682 /// %r = tensor.insert_slice %0
2683 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2684 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2685 /// ```
2686 /// is rewritten to:
2687 /// ```
2688 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2689 /// : tensor<?x?xf32>, vector<17x5xf32>
2690 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2691 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2692 /// ```
2693 ///
2694 /// This rewrite is possible if:
2695 /// - Low padding is static 0.
2696 /// - `padOp` result shape is static.
2697 /// - The entire padded tensor is inserted.
2698 /// (Implies that sizes of `insertOp` are all static.)
2699 /// - Only unit strides in `insertOp`.
2700 /// - Single, scalar padding value.
2701 /// - `padOp` result not used as destination.
2702 struct PadOpVectorizationWithInsertSlicePattern
2703     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2704   using VectorizePadOpUserPattern<
2705       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2706 
2707  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2708  tensor::InsertSliceOp insertOp) const override {
2709  // Low padding must be static 0.
2710  if (!padOp.hasZeroLowPad())
2711  return failure();
2712  // Only unit stride supported.
2713  if (!insertOp.hasUnitStride())
2714  return failure();
2715  // Pad value must be a constant.
2716  auto padValue = padOp.getConstantPaddingValue();
2717  if (!padValue)
2718  return failure();
2719  // Dynamic shapes not supported.
2720  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2721  return failure();
2722  // Pad result not used as destination.
2723  if (insertOp.getDest() == padOp.getResult())
2724  return failure();
2725 
2726  auto vecType = VectorType::get(padOp.getType().getShape(),
2727  padOp.getType().getElementType());
2728  unsigned vecRank = vecType.getRank();
2729  unsigned tensorRank = insertOp.getType().getRank();
2730 
2731  // Check if sizes match: Insert the entire tensor into most minor dims.
2732  // (No permutations allowed.)
2733  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2734  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2735  if (!llvm::all_of(
2736  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2737  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2738  }))
2739  return failure();
2740 
2741  // Insert the TransferReadOp and TransferWriteOp at the position of the
2742  // InsertSliceOp.
2743  rewriter.setInsertionPoint(insertOp);
2744 
2745  // Generate TransferReadOp: Read entire source tensor and add high
2746  // padding.
2747  SmallVector<Value> readIndices(
2748  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2749  auto read = rewriter.create<vector::TransferReadOp>(
2750  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2751 
2752  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2753  // specified offsets. The write is fully in-bounds because an InsertSliceOp's
2754  // source must fit into the destination at the specified offsets.
2755  auto writeIndices = getValueOrCreateConstantIndexOp(
2756  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2757  SmallVector<bool> inBounds(vecRank, true);
2758  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2759  insertOp, read, insertOp.getDest(), writeIndices,
2760  ArrayRef<bool>{inBounds});
2761 
2762  return success();
2763  }
2764 };
2765 
2766 void mlir::linalg::populateInsertSliceVectorizationPatterns(
2767  RewritePatternSet &patterns) {
2768  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
2769 }
2770 
2771 void mlir::linalg::populatePadOpVectorizationPatterns(
2772  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2773  // TODO: The following pattern implements "decomposition" and
2774  // optional "vectorization". Split "decomposition" out into a separate
2775  // pre-processing pattern group.
2776  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
2777 
2778  // Try these specialized patterns first before resorting to the generic one.
2779  patterns.add<PadOpVectorizationWithTransferReadPattern,
2780  PadOpVectorizationWithTransferWritePattern,
2781  PadOpVectorizationWithInsertSlicePattern>(
2782  patterns.getContext(), baseBenefit.getBenefit() + 1);
2783 }
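// Editorial note: a minimal usage sketch (not part of this file) showing how
// these pad vectorization patterns are typically applied from a pass, assuming
// a standard greedy rewrite driver and a `funcOp` root:
// ```
//   RewritePatternSet patterns(ctx);
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
// ```
// The specialized patterns registered above use `baseBenefit + 1`, so they are
// tried before the generic GeneralizePadOpPattern decomposition.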
2784 
2785 //----------------------------------------------------------------------------//
2786 // Forwarding patterns
2787 //----------------------------------------------------------------------------//
2788 
2789 /// Check whether there is any interleaved use of any `values` between
2790 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2791 /// is in a different block.
2792 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2793  ValueRange values) {
2794  if (firstOp->getBlock() != secondOp->getBlock() ||
2795  !firstOp->isBeforeInBlock(secondOp)) {
2796  LDBG("interleavedUses precondition failed, firstOp: "
2797  << *firstOp << ", second op: " << *secondOp << "\n");
2798  return true;
2799  }
2800  for (auto v : values) {
2801  for (auto &u : v.getUses()) {
2802  Operation *owner = u.getOwner();
2803  if (owner == firstOp || owner == secondOp)
2804  continue;
2805  // TODO: this is too conservative, use dominance info in the future.
2806  if (owner->getBlock() == firstOp->getBlock() &&
2807  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2808  continue;
2809  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2810  << ", second op: " << *secondOp << "\n");
2811  return true;
2812  }
2813  }
2814  return false;
2815 }
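// Editorial example (illustrative only): given
//   memref.copy %src, %sv       // firstOp
//   "other.op"(%sv) : ...       // a use of %sv between the two anchors
//   vector.transfer_read %view  // secondOp
// the middle op is an interleaved use of %sv, so this helper returns true and
// the forwarding patterns below conservatively bail out.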
2816 
2817 /// Return the unique subview use of `v` if it is indeed unique, null
2818 /// otherwise.
2819 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2820  memref::SubViewOp subViewOp;
2821  for (auto &u : v.getUses()) {
2822  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2823  if (subViewOp)
2824  return memref::SubViewOp();
2825  subViewOp = newSubViewOp;
2826  }
2827  }
2828  return subViewOp;
2829 }
2830 
2831 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2832 /// when available.
2833 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2834  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2835 
2836  // TODO: support mask.
2837  if (xferOp.getMask())
2838  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2839 
2840  // Transfer into `view`.
2841  Value viewOrAlloc = xferOp.getSource();
2842  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2843  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2844  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2845 
2846  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2847  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2848  if (!subViewOp)
2849  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2850  Value subView = subViewOp.getResult();
2851 
2852  // Find the copy into `subView` without interleaved uses.
2853  memref::CopyOp copyOp;
2854  for (auto &u : subView.getUses()) {
2855  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2856  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2857  if (newCopyOp.getTarget() != subView)
2858  continue;
2859  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2860  continue;
2861  copyOp = newCopyOp;
2862  break;
2863  }
2864  }
2865  if (!copyOp)
2866  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2867 
2868  // Find the fill into `viewOrAlloc` without interleaved uses before the
2869  // copy.
2870  FillOp maybeFillOp;
2871  for (auto &u : viewOrAlloc.getUses()) {
2872  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2873  assert(isa<MemRefType>(newFillOp.output().getType()));
2874  if (newFillOp.output() != viewOrAlloc)
2875  continue;
2876  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2877  continue;
2878  maybeFillOp = newFillOp;
2879  break;
2880  }
2881  }
2882  // Ensure padding matches.
2883  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2884  return rewriter.notifyMatchFailure(xferOp,
2885  "padding value does not match fill");
2886 
2887  // `in` is the buffer that memref.copy reads from. Read from it directly.
2888  Value in = copyOp.getSource();
2889 
2890  // memref.copy + linalg.fill can be used to create a padded local buffer.
2891  // The `in_bounds` guarantee only holds for reads of that padded buffer.
2892  // When forwarding to a vector.transfer_read of the original source, the
2893  // attribute must be reset conservatively (all dimensions out-of-bounds).
2894  auto vectorType = xferOp.getVectorType();
2895  Value res = rewriter.create<vector::TransferReadOp>(
2896  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2897  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2898  rewriter.getBoolArrayAttr(
2899  SmallVector<bool>(vectorType.getRank(), false)));
2900 
2901  if (maybeFillOp)
2902  rewriter.eraseOp(maybeFillOp);
2903  rewriter.eraseOp(copyOp);
2904  rewriter.replaceOp(xferOp, res);
2905 
2906  return success();
2907 }
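// Editorial sketch of the rewrite above (illustrative, simplified IR):
// ```
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<...>)
//   %sv = memref.subview %alloc [...]
//   memref.copy %in, %sv
//   %v = vector.transfer_read %alloc[...], %pad : memref<...>, vector<...>
// ```
// becomes a single read from the copy source, with all dimensions
// conservatively marked out-of-bounds so the padding value still applies:
// ```
//   %v = vector.transfer_read %in[...], %pad {in_bounds = [false, ...]}
// ```
// and the now-dead fill and copy are erased.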
2908 
2909 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2910 /// when available.
2911 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2912  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2913  // TODO: support mask.
2914  if (xferOp.getMask())
2915  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2916 
2917  // Transfer into `viewOrAlloc`.
2918  Value viewOrAlloc = xferOp.getSource();
2919  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2920  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2921  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2922 
2923  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2924  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2925  if (!subViewOp)
2926  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2927  Value subView = subViewOp.getResult();
2928 
2929  // Find the copy from `subView` without interleaved uses.
2930  memref::CopyOp copyOp;
2931  for (auto &u : subViewOp.getResult().getUses()) {
2932  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2933  if (newCopyOp.getSource() != subView)
2934  continue;
2935  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2936  continue;
2937  copyOp = newCopyOp;
2938  break;
2939  }
2940  }
2941  if (!copyOp)
2942  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2943 
2944  // `out` is the buffer that memref.copy writes into. Write to it directly.
2945  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2946  Value out = copyOp.getTarget();
2947 
2948  // Forward vector.transfer into copy.
2949  // memref.copy + linalg.fill can be used to create a padded local buffer.
2950  // The `in_bounds` guarantee only holds for writes into that padded buffer.
2951  // When forwarding to a vector.transfer_write of the copy target, the
2952  // attribute must be reset conservatively (all dimensions out-of-bounds).
2953  auto vector = xferOp.getVector();
2954  rewriter.create<vector::TransferWriteOp>(
2955  xferOp.getLoc(), vector, out, xferOp.getIndices(),
2956  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2957  rewriter.getBoolArrayAttr(
2958  SmallVector<bool>(vector.getType().getRank(), false)));
2959 
2960  rewriter.eraseOp(copyOp);
2961  rewriter.eraseOp(xferOp);
2962 
2963  return success();
2964 }
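// Editorial sketch (illustrative, simplified IR): the write-side forwarding
// rewrites
// ```
//   %sv = memref.subview %alloc [...]
//   vector.transfer_write %v, %alloc[...]
//   memref.copy %sv, %out
// ```
// into a single `vector.transfer_write %v, %out[...]` with conservative
// (all-false) in_bounds, erasing the intermediate copy.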
2965 
2966 //===----------------------------------------------------------------------===//
2967 // Convolution vectorization patterns
2968 //===----------------------------------------------------------------------===//
2969 
2970 template <int N>
2971 static void bindShapeDims(ShapedType shapedType) {}
2972 
2973 template <int N, typename IntTy, typename... IntTy2>
2974 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2975  val = shapedType.getShape()[N];
2976  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2977 }
2978 
2979 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2980 template <typename... IntTy>
2981 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2982  bindShapeDims<0>(shapedType, vals...);
2983 }
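// Editorial usage note: for a shaped type such as tensor<8x16x4xf32>,
// `bindShapeDims(ty, n, w, c)` assigns n = 8, w = 16, c = 4, i.e. the pack of
// integer references is bound to the leading dimensions in order.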
2984 
2985 namespace {
2986 bool isCastOfBlockArgument(Operation *op) {
2987  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2988  isa<BlockArgument>(op->getOperand(0));
2989 }
2990 
2991 bool isSupportedPoolKind(vector::CombiningKind kind) {
2992  switch (kind) {
2993  case vector::CombiningKind::ADD:
2994  case vector::CombiningKind::MAXNUMF:
2995  case vector::CombiningKind::MAXIMUMF:
2996  case vector::CombiningKind::MAXSI:
2997  case vector::CombiningKind::MAXUI:
2998  case vector::CombiningKind::MINNUMF:
2999  case vector::CombiningKind::MINIMUMF:
3000  case vector::CombiningKind::MINSI:
3001  case vector::CombiningKind::MINUI:
3002  return true;
3003  default:
3004  return false;
3005  }
3006 }
3007 
3008 /// Generate a vector implementation for either:
3009 /// ```
3010 /// Op def: ( w, kw )
3011 /// Iters: ({Par(), Red()})
3012 /// Layout: {{w + kw}, {kw}, {w}}
3013 /// ```
3014 /// kw is unrolled.
3015 ///
3016 /// or
3017 ///
3018 /// ```
3019 /// Op def: ( n, w, c, kw, f )
3020 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3021 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3022 /// ```
3023 /// kw is unrolled, w is unrolled iff dilationW > 1.
3024 ///
3025 /// or
3026 ///
3027 /// ```
3028 /// Op def: ( n, c, w, f, kw )
3029 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3030 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3031 /// ```
3032 /// kw is unrolled, w is unrolled iff dilationW > 1.
3033 ///
3034 /// or
3035 ///
3036 /// ```
3037 /// Op def: ( n, w, c, kw )
3038 /// Iters: ({Par(), Par(), Par(), Red()})
3039 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3040 /// ```
3041 /// kw is unrolled, w is unrolled iff dilationW > 1.
3042 struct Conv1DGenerator
3043  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3044  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3045  int dilationW)
3046  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3047  strideW(strideW), dilationW(dilationW) {
3048  // Determine whether `linalgOp` can be generated with this generator
3049  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
3050  return;
3051  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3052  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3053  resShaped = linalgOp.getDpsInitOperand(0)->get();
3054  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3055  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3056  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3057  if (!lhsShapedType || !rhsShapedType || !resShapedType)
3058  return;
3059  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
3060  // (non-channeled convolution -> LHS and RHS both have single dimensions).
3061  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
3062  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
3063  return;
3064 
3065  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3066  if (!reduceOp)
3067  return;
3068  redOp = reduceOp->getName().getIdentifier();
3069 
3070  if (!setOperKind(reduceOp))
3071  return;
3072  auto maybeKind = getCombinerOpKind(reduceOp);
3073  // Typically a convolution will have an `Add` CombiningKind, but for i1 it
3074  // can get strength-reduced to `OR`, which is also supported. This strength
3075  // reduction logic is in the `buildBinaryFn` helper in the Linalg dialect.
3076  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
3077  *maybeKind != vector::CombiningKind::OR) &&
3078  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
3079  return;
3080  }
3081  reductionKind = maybeKind.value();
3082 
3083  auto rhsRank = rhsShapedType.getRank();
3084  switch (oper) {
3085  case Conv:
3086  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3087  return;
3088  break;
3089  case Pool:
3090  if (rhsRank != 1)
3091  return;
3092  break;
3093  }
3094  // The op is now known to be valid.
3095  valid = true;
3096  }
3097 
3098  /// Generate a vector implementation for:
3099  /// ```
3100  /// Op def: ( w, kw )
3101  /// Iters: ({Par(), Red()})
3102  /// Layout: {{w + kw}, {kw}, {w}}
3103  /// ```
3104  /// kw is always unrolled.
3105  ///
3106  /// or
3107  ///
3108  /// ```
3109  /// Op def: ( n, w, c, kw, f )
3110  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3111  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3112  /// ```
3113  /// kw is always unrolled.
3114  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3115  /// > 1.
3116  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3117  if (!valid)
3118  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3119 
3120  int64_t nSize, wSize, cSize, kwSize, fSize;
3121  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3122  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3123  switch (conv1DOpOrder) {
3124  case Conv1DOpOrder::W:
3125  // Initialize unused dimensions
3126  nSize = fSize = cSize = 0;
3127  // out{W}
3128  bindShapeDims(resShapedType, wSize);
3129  // kernel{kw}
3130  bindShapeDims(rhsShapedType, kwSize);
3131  lhsShape = {// iw = ow + kw - 1
3132  // (i.e. 16 convolved with 3 -> 14)
3133  (wSize + kwSize - 1)};
3134  rhsShape = {kwSize};
3135  resShape = {wSize};
3136  break;
3137  case Conv1DOpOrder::Nwc:
3138  // out{n, w, f}
3139  bindShapeDims(resShapedType, nSize, wSize, fSize);
3140  switch (oper) {
3141  case Conv:
3142  // kernel{kw, c, f}
3143  bindShapeDims(rhsShapedType, kwSize, cSize);
3144  break;
3145  case Pool:
3146  // kernel{kw}
3147  bindShapeDims(rhsShapedType, kwSize);
3148  cSize = fSize;
3149  break;
3150  }
3151  lhsShape = {nSize,
3152  // iw = ow * sw + kw * dw - 1
3153  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3154  // Perform the proper inclusive -> exclusive -> inclusive.
3155  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3156  1,
3157  cSize};
3158  switch (oper) {
3159  case Conv:
3160  rhsShape = {kwSize, cSize, fSize};
3161  break;
3162  case Pool:
3163  rhsShape = {kwSize};
3164  break;
3165  }
3166  resShape = {nSize, wSize, fSize};
3167  break;
3168  case Conv1DOpOrder::Ncw:
3169  // out{n, f, w}
3170  bindShapeDims(resShapedType, nSize, fSize, wSize);
3171  switch (oper) {
3172  case Conv:
3173  // kernel{f, c, kw}
3174  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3175  break;
3176  case Pool:
3177  // kernel{kw}
3178  bindShapeDims(rhsShapedType, kwSize);
3179  cSize = fSize;
3180  break;
3181  }
3182  lhsShape = {nSize, cSize,
3183  // iw = ow * sw + kw * dw - 1
3184  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3185  // Perform the proper inclusive -> exclusive -> inclusive.
3186  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3187  1};
3188  switch (oper) {
3189  case Conv:
3190  rhsShape = {fSize, cSize, kwSize};
3191  break;
3192  case Pool:
3193  rhsShape = {kwSize};
3194  break;
3195  }
3196  resShape = {nSize, fSize, wSize};
3197  break;
3198  }
3199 
3200  vector::TransferWriteOp write;
3201  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3202 
3203  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3204  // When strideW == 1, we can batch the contiguous loads and avoid
3205  // unrolling
3206  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3207 
3208  Type lhsEltType = lhsShapedType.getElementType();
3209  Type rhsEltType = rhsShapedType.getElementType();
3210  Type resEltType = resShapedType.getElementType();
3211  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3212  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3213  auto resType = VectorType::get(resShape, resEltType);
3214  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3215  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3216  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3217  SmallVector<Value> resPadding(resShape.size(), zero);
3218 
3219  // Read the whole lhs, rhs and res in one shot (with zero padding).
3220  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3221  lhsPadding);
3222  // This is needed only for Conv.
3223  Value rhs = nullptr;
3224  if (oper == Conv)
3225  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3226  rhsPadding);
3227  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3228  resPadding);
3229 
3230  // The base vectorization case for channeled convolution is input:
3231  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse this base-case
3232  // vectorization, we pre-transpose the input, weight, and output.
3233  switch (conv1DOpOrder) {
3234  case Conv1DOpOrder::W:
3235  case Conv1DOpOrder::Nwc:
3236  // Base case, so no transposes necessary.
3237  break;
3238  case Conv1DOpOrder::Ncw: {
3239  // To match base vectorization case, we pre-transpose current case.
3240  // ncw -> nwc
3241  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3242  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3243  // fcw -> wcf
3244  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3245 
3246  // This is needed only for Conv.
3247  if (oper == Conv)
3248  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3249  // nfw -> nwf
3250  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3251  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3252  break;
3253  }
3254  }
3255 
3256  //===------------------------------------------------------------------===//
3257  // Begin vector-only rewrite part
3258  //===------------------------------------------------------------------===//
3259  // Unroll along kw and read slices of lhs and rhs.
3260  SmallVector<Value> lhsVals, rhsVals, resVals;
3261  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3262  kwSize, strideW, dilationW, wSizeStep,
3263  isSingleChanneled);
3264  // Do not do for pooling.
3265  if (oper == Conv)
3266  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3267  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3268  wSizeStep, isSingleChanneled);
3269 
3270  auto linearIndex = [&](int64_t kw, int64_t w) {
3271  return kw * (wSize / wSizeStep) + w;
3272  };
3273 
3274  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3275  // or an outer product for non-channeled convolution, or a simple
3276  // arith operation for pooling.
3277  for (int64_t kw = 0; kw < kwSize; ++kw) {
3278  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3279  switch (oper) {
3280  case Conv:
3281  if (isSingleChanneled) {
3282  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3283  lhsVals[linearIndex(kw, w)],
3284  rhsVals[kw], resVals[w]);
3285  } else {
3286  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3287  lhsVals[linearIndex(kw, w)],
3288  rhsVals[kw], resVals[w]);
3289  }
3290  break;
3291  case Pool:
3292  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3293  resVals[w]);
3294  break;
3295  }
3296  }
3297  }
3298 
3299  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3300  isSingleChanneled);
3301  //===------------------------------------------------------------------===//
3302  // End vector-only rewrite part
3303  //===------------------------------------------------------------------===//
3304 
3305  // The base vectorization case for channeled convolution is output:
3306  // {n,w,f}. To reuse the result from the base-case vectorization, we
3307  // post-transpose the base-case result.
3308  switch (conv1DOpOrder) {
3309  case Conv1DOpOrder::W:
3310  case Conv1DOpOrder::Nwc:
3311  // Base case, so no transposes necessary.
3312  break;
3313  case Conv1DOpOrder::Ncw: {
3314  // nwf -> nfw
3315  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3316  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3317  break;
3318  }
3319  }
3320 
3321  return rewriter
3322  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3323  .getOperation();
3324  }
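// Editorial sketch (illustrative, simplified IR) of what `conv` emits for an
// NWC convolution with a single kw position and strideW == 1:
// ```
//   %lhs = vector.transfer_read %in[%c0, %c0, %c0], %cst  : ..., vector<NxIWxCxf32>
//   %rhs = vector.transfer_read %flt[%c0, %c0, %c0], %cst : ..., vector<KWxCxFxf32>
//   %res = vector.transfer_read %out[%c0, %c0, %c0], %cst : ..., vector<NxWxFxf32>
//   // per unrolled kw position: extract slices, accumulate via vector.contract
//   %acc = vector.contract ... %lhsSlice, %rhsSlice, %resSlice
//   vector.transfer_write %acc, %out[%c0, %c0, %c0]
// ```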
3325 
3326  // Take a value and widen it to have the same element type as `ty`.
3327  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3328  const Type srcElementType = getElementTypeOrSelf(val.getType());
3329  const Type dstElementType = getElementTypeOrSelf(ty);
3330  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3331  if (srcElementType == dstElementType)
3332  return val;
3333 
3334  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3335  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3336  const Type dstType =
3337  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3338 
3339  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3340  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3341  }
3342 
3343  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3344  srcWidth < dstWidth)
3345  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3346 
3347  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3348  srcWidth < dstWidth)
3349  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3350 
3351  assert(false && "unhandled promotion case");
3352  return nullptr;
3353  }
3354 
3355  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3356  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3357  Value lhs, Value rhs, Value res) {
3358  vector::IteratorType par = vector::IteratorType::parallel;
3359  vector::IteratorType red = vector::IteratorType::reduction;
3360  AffineExpr n, w, f, c;
3361  bindDims(ctx, n, w, f, c);
3362  lhs = promote(rewriter, loc, lhs, res.getType());
3363  rhs = promote(rewriter, loc, rhs, res.getType());
3364  auto contractionOp = rewriter.create<vector::ContractionOp>(
3365  loc, lhs, rhs, res,
3366  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3367  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3368  contractionOp.setKind(reductionKind);
3369  return contractionOp;
3370  }
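// Editorial example (illustrative): the contraction created above corresponds
// to IR of the form
// ```
//   vector.contract {
//     indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
//                      affine_map<(n, w, f, c) -> (c, f)>,
//                      affine_map<(n, w, f, c) -> (n, w, f)>],
//     iterator_types = ["parallel", "parallel", "parallel", "reduction"],
//     kind = #vector.kind<add>
//   } %lhs, %rhs, %res : ...
// ```
// with `kind` taken from `reductionKind` (add, or `or` for i1 convolutions).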
3371 
3372  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3373  // convolution.
3374  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3375  Value lhs, Value rhs, Value res) {
3376  return rewriter.create<vector::OuterProductOp>(
3377  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3378  }
3379 
3380  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3381  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3382  Value res) {
3383  if (isPoolExt)
3384  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3385  return rewriter
3386  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3387  ->getResult(0);
3388  }
3389 
3390  /// Generate a vector implementation for:
3391  /// ```
3392  /// Op def: ( n, w, c, kw)
3393  /// Iters: ({Par(), Par(), Par(), Red()})
3394  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3395  /// ```
3396  /// kw is always unrolled.
3397  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3398  /// > 1.
3399  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3400  bool channelDimScalableFlag,
3401  bool flatten) {
3402  if (!valid)
3403  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3404 
3405  bool scalableChDim = false;
3406  bool useMasking = false;
3407  int64_t nSize, wSize, cSize, kwSize;
3408  // kernel{kw, c}
3409  bindShapeDims(rhsShapedType, kwSize, cSize);
3410  if (ShapedType::isDynamic(cSize)) {
3411  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3412  cSize = channelDimVecSize;
3413  // Scalable vectors are only used when both conditions are met:
3414  // 1. channel dim is dynamic
3415  // 2. channelDimScalableFlag is set
3416  scalableChDim = channelDimScalableFlag;
3417  useMasking = true;
3418  }
3419 
3420  assert(!(useMasking && flatten) &&
3421  "Unsupported flattened conv with dynamic shapes");
3422 
3423  // out{n, w, c}
3424  bindShapeDims(resShapedType, nSize, wSize);
3425 
3426  vector::TransferWriteOp write;
3427  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3428 
3429  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3430  // When strideW == 1, we can batch the contiguous loads and avoid
3431  // unrolling
3432  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3433 
3434  Type lhsEltType = lhsShapedType.getElementType();
3435  Type rhsEltType = rhsShapedType.getElementType();
3436  Type resEltType = resShapedType.getElementType();
3437  VectorType lhsType = VectorType::get(
3438  {nSize,
3439  // iw = ow * sw + kw * dw - 1
3440  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3441  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3442  cSize},
3443  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3444  VectorType rhsType =
3445  VectorType::get({kwSize, cSize}, rhsEltType,
3446  /*scalableDims=*/{false, scalableChDim});
3447  VectorType resType =
3448  VectorType::get({nSize, wSize, cSize}, resEltType,
3449  /*scalableDims=*/{false, false, scalableChDim});
3450 
3451  // Masks the xfer Op along the channel dim, iff masking is required (i.e.
3452  // the channel dim is dynamic).
3453  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3454  ArrayRef<bool> scalableDims,
3455  Operation *opToMask) {
3456  if (!useMasking)
3457  return opToMask;
3458  auto maskType =
3459  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3460 
3461  SmallVector<bool> inBounds(maskShape.size(), true);
3462  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3463  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3464  rewriter.getBoolArrayAttr(inBounds));
3465 
3466  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3467  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3468 
3469  Value maskOp =
3470  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3471 
3472  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3473  };
3474 
3475  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3476  // 0].
3477  Value lhs = rewriter.create<vector::TransferReadOp>(
3478  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3479  auto maybeMaskedLhs = maybeMaskXferOp(
3480  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3481 
3482  // Read rhs slice of size {kw, c} @ [0, 0].
3483  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3484  ValueRange{zero, zero});
3485  auto maybeMaskedRhs = maybeMaskXferOp(
3486  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3487 
3488  // Read res slice of size {n, w, c} @ [0, 0, 0].
3489  Value res = rewriter.create<vector::TransferReadOp>(
3490  loc, resType, resShaped, ValueRange{zero, zero, zero});
3491  auto maybeMaskedRes = maybeMaskXferOp(
3492  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3493 
3494  //===------------------------------------------------------------------===//
3495  // Begin vector-only rewrite part
3496  //===------------------------------------------------------------------===//
3497  // Unroll along kw and read slices of lhs and rhs.
3498  SmallVector<Value> lhsVals, rhsVals, resVals;
3499  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3500  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3501 
3502  // Extract lhs slice of size {n, wSizeStep, c}
3503  // @ [0, sw * w + dw * kw, 0].
3504  for (int64_t kw = 0; kw < kwSize; ++kw) {
3505  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3506  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3507  loc, maybeMaskedLhs->getResult(0),
3508  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3509  inOutSliceSizes, inOutStrides));
3510  }
3511  }
3512  // Extract rhs slice of size {c} @ [kw].
3513  for (int64_t kw = 0; kw < kwSize; ++kw) {
3514  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3515  loc, maybeMaskedRhs->getResult(0),
3516  /*offsets=*/ArrayRef<int64_t>{kw}));
3517  }
3518  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3519  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3520  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3521  loc, maybeMaskedRes->getResult(0),
3522  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3523  inOutStrides));
3524  }
3525 
3526  auto linearIndex = [&](int64_t kw, int64_t w) {
3527  return kw * (wSize / wSizeStep) + w;
3528  };
3529 
3530  // Note - the scalable flags are ignored as flattening combined with
3531  // scalable vectorization is not supported.
3532  auto inOutFlattenSliceSizes =
3533  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3534  auto lhsTypeAfterFlattening =
3535  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3536  auto resTypeAfterFlattening =
3537  VectorType::get(inOutFlattenSliceSizes, resEltType);
3538 
3539  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3540  for (int64_t kw = 0; kw < kwSize; ++kw) {
3541  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3542  Value lhsVal = lhsVals[linearIndex(kw, w)];
3543  Value resVal = resVals[w];
3544  if (flatten) {
3545  // Flatten the input and output vectors (collapse the channel
3546  // dimension)
3547  lhsVal = rewriter.create<vector::ShapeCastOp>(
3548  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3549  resVal = rewriter.create<vector::ShapeCastOp>(
3550  loc, resTypeAfterFlattening, resVals[w]);
3551  }
3552  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3553  rhsVals[kw], resVal, flatten);
3554  if (flatten) {
3555  // Un-flatten the output vector (restore the channel dimension)
3556  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3557  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3558  }
3559  }
3560  }
3561 
3562  // It's possible we failed to create the FMA.
3563  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3564  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3565  for (auto &collection :
3566  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3567  for (Value v : collection)
3568  rewriter.eraseOp(v.getDefiningOp());
3569  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3570  }
3571 
3572  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3573  // This does not depend on kw.
3574  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3575  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3576  loc, resVals[w], maybeMaskedRes->getResult(0),
3577  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3578  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3579  }
3580  //===------------------------------------------------------------------===//
3581  // End vector-only rewrite part
3582  //===------------------------------------------------------------------===//
3583 
3584  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3585  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3586  loc, maybeMaskedRes->getResult(0), resShaped,
3587  ValueRange{zero, zero, zero});
3588  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3589  resOut);
3590  }
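// Editorial sketch (illustrative only): when the channel dim is dynamic, the
// transfer ops above are wrapped in vector.mask, e.g. for the lhs read with a
// scalable channel dim:
// ```
//   %mask = vector.create_mask %n, %iw, %ch : vector<1x6x[4]xi1>
//   %lhs = vector.mask %mask {
//     vector.transfer_read %in[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]}
//       : memref<1x6x?xf32>, vector<1x6x[4]xf32>
//   } : vector<1x6x[4]xi1> -> vector<1x6x[4]xf32>
// ```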
3591 
3592  /// Lower:
3593  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3594  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3595  /// to MulAcc.
3596  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3597  Value lhs, Value rhs, Value res,
3598  bool flatten) {
3599  auto rhsTy = cast<ShapedType>(rhs.getType());
3600  auto resTy = cast<ShapedType>(res.getType());
3601 
3602  // TODO(suderman): Change this to use a vector.ima intrinsic.
3603  lhs = promote(rewriter, loc, lhs, resTy);
3604 
3605  if (flatten) {
3606  // NOTE: The following logic won't work for scalable vectors. For this
3607  // reason, "flattening" is not supported when shapes are dynamic (this
3608  // should be captured by one of the pre-conditions).
3609 
3610  // There are two options for handling the filter:
3611  // * shape_cast(broadcast(filter))
3612  // * broadcast(shuffle(filter))
3613  // Opt for the option without shape_cast to simplify the codegen.
3614  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3615  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3616 
3617  SmallVector<int64_t, 16> indices;
3618  for (int i = 0; i < resSize / rhsSize; ++i) {
3619  for (int j = 0; j < rhsSize; ++j)
3620  indices.push_back(j);
3621  }
3622 
3623  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3624  }
3625  // Broadcast the filter to match the output vector
3626  rhs = rewriter.create<vector::BroadcastOp>(
3627  loc, resTy.clone(rhsTy.getElementType()), rhs);
3628 
3629  rhs = promote(rewriter, loc, rhs, resTy);
3630 
3631  if (!lhs || !rhs)
3632  return nullptr;
3633 
3634  if (isa<FloatType>(resTy.getElementType()))
3635  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3636 
3637  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3638  return rewriter.create<arith::AddIOp>(loc, mul, res);
3639  }
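// Editorial worked example (illustrative): with flattening enabled, a filter
// slice rhs of size 4 and a flattened output row of size 8 (wSizeStep = 2,
// c = 4) produce shuffle indices [0, 1, 2, 3, 0, 1, 2, 3], i.e. the filter is
// tiled along the flattened w*c dimension before the broadcast and
// multiply-accumulate.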
3640 
3641  /// Entry point for non-channeled convolution:
3642  /// {{w + kw}, {kw}, {w}}
3643  FailureOr<Operation *> generateNonChanneledConv() {
3644  AffineExpr w, kw;
3645  bindDims(ctx, w, kw);
3646  if (!iters({Par(), Red()}))
3647  return rewriter.notifyMatchFailure(op,
3648  "failed to match conv::W 1-par 1-red");
3649 
3650  // No transposition needed.
3651  if (layout({/*lhsIndex*/ {w + kw},
3652  /*rhsIndex*/ {kw},
3653  /*resIndex*/ {w}}))
3654  return conv(Conv1DOpOrder::W);
3655 
3656  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3657  }
3658 
3659  /// Entry point that transposes into the common form:
3660  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3661  FailureOr<Operation *> generateNwcConv() {
3662  AffineExpr n, w, f, kw, c;
3663  bindDims(ctx, n, w, f, kw, c);
3664  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3665  return rewriter.notifyMatchFailure(
3666  op, "failed to match conv::Nwc 3-par 2-red");
3667 
3668  // No transposition needed.
3669  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3670  /*rhsIndex*/ {kw, c, f},
3671  /*resIndex*/ {n, w, f}}))
3672  return conv(Conv1DOpOrder::Nwc);
3673 
3674  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3675  }
3676 
3677  /// Entry point that transposes into the common form:
3678  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3679  FailureOr<Operation *> generateNcwConv() {
3680  AffineExpr n, w, f, kw, c;
3681  bindDims(ctx, n, f, w, c, kw);
3682  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3683  return rewriter.notifyMatchFailure(
3684  op, "failed to match conv::Ncw 3-par 2-red");
3685 
3686  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3687  /*rhsIndex*/ {f, c, kw},
3688  /*resIndex*/ {n, f, w}}))
3689  return conv(Conv1DOpOrder::Ncw);
3690 
3691  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3692  }
3693 
3694  /// Entry point that transposes into the common form:
3695  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3696  FailureOr<Operation *> generateNwcPooling() {
3697  AffineExpr n, w, c, kw;
3698  bindDims(ctx, n, w, c, kw);
3699  if (!iters({Par(), Par(), Par(), Red()}))
3700  return rewriter.notifyMatchFailure(op,
3701  "failed to match pooling 3-par 1-red");
3702 
3703  // No transposition needed.
3704  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3705  /*rhsIndex*/ {kw},
3706  /*resIndex*/ {n, w, c}}))
3707  return conv(Conv1DOpOrder::Nwc);
3708 
3709  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3710  }
3711 
3712  /// Entry point that transposes into the common form:
3713  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3714  FailureOr<Operation *> generateNcwPooling() {
3715  AffineExpr n, w, c, kw;
3716  bindDims(ctx, n, c, w, kw);
3717  if (!iters({Par(), Par(), Par(), Red()}))
3718  return rewriter.notifyMatchFailure(op,
3719  "failed to match pooling 3-par 1-red");
3720 
3721  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3722  /*rhsIndex*/ {kw},
3723  /*resIndex*/ {n, c, w}}))
3724  return conv(Conv1DOpOrder::Ncw);
3725 
3726  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3727  }
3728 
3729  /// Entry point that transposes into the common form:
3730  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3731  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3732  bool vecChDimScalableFlag = false,
3733  bool flatten = false) {
3734  AffineExpr n, w, c, kw;
3735  bindDims(ctx, n, w, c, kw);
3736  if (!iters({Par(), Par(), Par(), Red()}))
3737  return rewriter.notifyMatchFailure(
3738  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3739 
3740  // No transposition needed.
3741  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3742  /*rhsIndex*/ {kw, c},
3743  /*resIndex*/ {n, w, c}}))
3744  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3745 
3746  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3747  }
3748 
3749 private:
3750  enum OperKind { Conv, Pool };
3751  bool valid = false;
3752  OperKind oper = Conv;
3753  StringAttr redOp;
3754  StringAttr poolExtOp;
3755  bool isPoolExt = false;
3756  int strideW, dilationW;
3757  Value lhsShaped, rhsShaped, resShaped;
3758  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3759  vector::CombiningKind reductionKind;
3760 
3761  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3762  // Returns true iff it is a valid conv/pooling op.
3763  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3764  // reduction + yield) and rhs is not used, then it is the body of a pooling.
3765  // If it is a conv, check for a single `mul` predecessor. The `mul` operands
3766  // must be block arguments or extensions of block arguments.
3767  // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
3768  // must be block arguments or extensions of block arguments.
3769  bool setOperKind(Operation *reduceOp) {
3770  int numBlockArguments =
3771  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3772  switch (numBlockArguments) {
3773  case 1: {
3774  // Will be convolution if feeder is a MulOp.
3775  // A strength reduced version of MulOp for i1 type is AndOp which is also
3776  // supported. Otherwise, it can be pooling. This strength reduction logic
3777  // is in the `buildBinaryFn` helper in the Linalg dialect.
3778  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3779  llvm::IsaPred<BlockArgument>);
3780  Operation *feedOp = (*feedValIt).getDefiningOp();
3781  if (isCastOfBlockArgument(feedOp)) {
3782  oper = Pool;
3783  isPoolExt = true;
3784  poolExtOp = feedOp->getName().getIdentifier();
3785  } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3786  (isa<arith::AndIOp>(feedOp) &&
3787  feedOp->getResultTypes()[0].isInteger(1))) &&
3788  llvm::all_of(feedOp->getOperands(), [](Value v) {
3789  if (isa<BlockArgument>(v))
3790  return true;
3791  if (Operation *op = v.getDefiningOp())
3792  return isCastOfBlockArgument(op);
3793  return false;
3794  }))) {
3795  return false;
3796  }
3797  return true;
3798  }
3799  case 2:
3800  // Must be pooling
3801  oper = Pool;
3802  isPoolExt = false;
3803  return true;
3804  default:
3805  return false;
3806  }
3807  }
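// Editorial note (illustrative region bodies, assuming a linalg.generic-style
// region): a convolution body matched here typically looks like
//   %m = arith.mulf %in, %flt : f32
//   %a = arith.addf %out, %m  : f32
//   linalg.yield %a : f32
// whereas a pooling body reduces the input directly, e.g.
//   %a = arith.addf %out, %in : f32
//   linalg.yield %a : f32
// optionally with a single cast of the input block argument feeding the
// reduction.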
3808 };
3809 } // namespace
3810 
3811 /// Helper function to vectorize a LinalgOp with convolution semantics.
3812 // TODO: extend the generic vectorization to support windows and drop this.
3813 static FailureOr<Operation *> vectorizeConvolution(
3814  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3815  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3816  // The ConvolutionOpInterface gives us guarantees of existence for
3817  // strides/dilations. However, we do not need to rely on those; we can
3818  // simply use them if present, otherwise use the default and let the generic
3819  // conv matcher in the Conv1DGenerator succeed or fail.
3820  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3821  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3822  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3823  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3824  Conv1DGenerator e(rewriter, op, stride, dilation);
3825  auto res = e.generateNonChanneledConv();
3826  if (succeeded(res))
3827  return res;
3828  res = e.generateNwcConv();
3829  if (succeeded(res))
3830  return res;
3831  res = e.generateNcwConv();
3832  if (succeeded(res))
3833  return res;
3834  res = e.generateNwcPooling();
3835  if (succeeded(res))
3836  return res;
3837  res = e.generateNcwPooling();
3838  if (succeeded(res))
3839  return res;
3840 
3841  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3842  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3843  // masked/scalable) is the channel dim (i.e. the trailing dim).
3844  uint64_t vecChDimSize = ShapedType::kDynamic;
3845  bool vecChDimScalableFlag = false;
3846  if (!inputVecSizes.empty()) {
3847  // Only use the input vector size corresponding to the channel dim. Other
3848  // vector dims will be inferred from the Ops.
3849  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3850  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3851  "Not a 1D depthwise conv!");
3852  size_t chDimIdx =
3853  TypeSwitch<Operation *, size_t>(op)
3854  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3855  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3856 
3857  vecChDimSize = inputVecSizes[chDimIdx];
3858  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3859  }
3860  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3861  flatten1DDepthwiseConv);
3862 }
3863 
3864 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3865  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3866 
3867  LogicalResult matchAndRewrite(LinalgOp op,
3868  PatternRewriter &rewriter) const override {
3869  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3870  if (failed(resultOrFail))
3871  return failure();
3872  Operation *newOp = *resultOrFail;
3873  if (newOp->getNumResults() == 0) {
3874  rewriter.eraseOp(op.getOperation());
3875  return success();
3876  }
3877  assert(newOp->getNumResults() == 1 && "expected single result");
3878  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3879  return success();
3880  }
3881 };
3882 
3883 void mlir::linalg::populateConvolutionVectorizationPatterns(
3884  RewritePatternSet &patterns, PatternBenefit benefit) {
3885  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3886 }
static std::optional< VectorShape > vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:606
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
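A minimal sketch of creating an op at an insertion point (block, loc, and ctx are assumed to exist; the constant op is only an illustrative choice):
OpBuilder builder(ctx);
builder.setInsertionPointToStart(block);
// Single-result ops created this way convert implicitly to Value.
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);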
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit; if it cannot match, the benefit is PatternBenefit::impossibleToMatch().
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:216
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
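A minimal driver sketch (not from this file; assumes rewriter is a RewriterBase, op is a linalg op with a static 8x16 iteration space, and the surrounding function returns LogicalResult):
if (failed(linalg::vectorizeOpPrecondition(op, /*inputVectorSizes=*/{8, 16})))
  return failure();
if (failed(linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{8, 16},
                             /*inputScalableVecDims=*/{false, false})))
  return failure();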
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorization patterns for tensor.insert_slice.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
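A minimal sketch (builder, maskableOp, and mask are hypothetical; the helper is assumed to live in the vector utilities namespace):
// Wrap a maskable op (e.g. a vector.transfer_read) in vector.mask.
Operation *masked = vector::maskOperation(builder, maskableOp, mask);
// If the wrapped op produces a result, it is now available on the mask op.
Value maskedResult = masked->getResult(0);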
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2448
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
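A minimal sketch matching the signature above (src and padValue are hypothetical; the vector namespace is assumed):
// Read a 4x8 vector from `src`; a mask is generated when the read is not
// provably in bounds (useInBoundsInsteadOfMasking = false).
Value vec = vector::createReadOrMaskedRead(builder, loc, src,
                                           /*readShape=*/{4, 8}, padValue,
                                           /*useInBoundsInsteadOfMasking=*/false);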
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
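A small sketch of applying a projected permutation to a value list (map is a hypothetical, pre-built AffineMap):
// With map = (d0, d1, d2) -> (d2, d0) and source = {10, 20, 30},
// the result is {30, 10}.
SmallVector<int64_t> permuted =
    applyPermutationMap(map, ArrayRef<int64_t>{10, 20, 30});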
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:815
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
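A minimal sketch of the usual binding idiom (assumes an existing MLIRContext *ctx):
AffineExpr d0, d1;
bindDims(ctx, d0, d1);
// Build (d0, d1) -> (d0 + d1).
AffineMap sum = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);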
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
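A small sketch for the permutation case (assumes an existing MLIRContext *ctx):
// The inverse of (d0, d1, d2) -> (d1, d2, d0) is (d0, d1, d2) -> (d2, d0, d1).
SmallVector<unsigned> perm = {1, 2, 0};
AffineMap inverse = inversePermutation(AffineMap::getPermutationMap(perm, ctx));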
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
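A small sketch:
// The inverse of the permutation [1, 2, 0] is [2, 0, 1].
SmallVector<int64_t> inv = invertPermutationVector({1, 2, 0});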
Rewrite tensor.insert.slice as a vector.transfer_read + vector.transfer_write pair.
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp, PatternRewriter &rewriter) const final
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
Definition: Transforms.h:1506
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.