Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Utils.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58  ArrayRef<int64_t> inputVecSizes = {},
59  ArrayRef<bool> inputVecScalableFlags = {},
60  bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if zero or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66  OpType res;
67  block.walk([&](OpType op) {
68  if (res) {
69  res = nullptr;
70  return WalkResult::interrupt();
71  }
72  res = op;
73  return WalkResult::advance();
74  });
75  return res;
76 }
77 
78 /// Helper function to extract the input slices after filter is unrolled along
79 /// kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82  int64_t nSize, int64_t wSize, int64_t cSize,
83  int64_t kwSize, int strideW, int dilationW,
84  int64_t wSizeStep, bool isSingleChanneled) {
85  SmallVector<Value> result;
86  if (isSingleChanneled) {
87  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88  // convolution.
89  SmallVector<int64_t> sizes{wSizeStep};
90  SmallVector<int64_t> strides{1};
91  for (int64_t kw = 0; kw < kwSize; ++kw) {
92  for (int64_t w = 0; w < wSize; w += wSizeStep) {
93  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95  }
96  }
97  } else {
98  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99  // for channeled convolution.
100  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101  SmallVector<int64_t> strides{1, 1, 1};
102  for (int64_t kw = 0; kw < kwSize; ++kw) {
103  for (int64_t w = 0; w < wSize; w += wSizeStep) {
104  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105  loc, input,
106  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107  sizes, strides));
108  }
109  }
110  }
111  return result;
112 }
113 
114 /// Helper function to extract the filter slices after filter is unrolled along
115 /// kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117  Location loc, Value filter,
118  int64_t kwSize) {
119  SmallVector<Value> result;
120  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121  // non-channeled convolution] @ [kw].
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  result.push_back(rewriter.create<vector::ExtractOp>(
124  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125  }
126  return result;
127 }
128 
129 /// Helper function to extract the result slices after filter is unrolled along
130 /// kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133  int64_t nSize, int64_t wSize, int64_t fSize,
134  int64_t wSizeStep, bool isSingleChanneled) {
135  SmallVector<Value> result;
136  if (isSingleChanneled) {
137  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138  SmallVector<int64_t> sizes{wSizeStep};
139  SmallVector<int64_t> strides{1};
140  for (int64_t w = 0; w < wSize; w += wSizeStep) {
141  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143  }
144  } else {
145  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146  // convolution.
147  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148  SmallVector<int64_t> strides{1, 1, 1};
149  for (int64_t w = 0; w < wSize; w += wSizeStep) {
150  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152  }
153  }
154  return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159  Value res, int64_t wSize, int64_t wSizeStep,
160  SmallVectorImpl<Value> &resVals,
161  bool isSingleChanneled) {
162 
163  if (isSingleChanneled) {
164  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165  // This does not depend on kw.
166  SmallVector<int64_t> strides{1};
167  for (int64_t w = 0; w < wSize; w += wSizeStep) {
168  res = rewriter.create<vector::InsertStridedSliceOp>(
169  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170  }
171  } else {
172  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173  // convolution. This does not depend on kw.
174  SmallVector<int64_t> strides{1, 1, 1};
175  for (int64_t w = 0; w < wSize; w += wSizeStep) {
176  res = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178  strides);
179  }
180  }
181  return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189  /// Initializes the vectorization state, including the computation of the
190  /// canonical vector shape for vectorization.
191  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192  ArrayRef<int64_t> inputVectorSizes,
193  ArrayRef<bool> inputScalableVecDims);
194 
195  /// Returns the canonical vector shape used to vectorize the iteration space.
196  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198  /// Returns the vector dimensions that are scalable in the canonical vector
199  /// shape.
200  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202  /// Returns a vector type of the provided `elementType` with the canonical
203  /// vector shape and the corresponding fixed/scalable dimensions bit. If
204  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205  /// accordingly.
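 /// For example (illustrative, not from the original source): with a
 /// canonical vector shape of (8, [16]) and a `dimPermutation` of
 /// (d0, d1) -> (d1, d0), requesting an f32 element type would yield
 /// `vector<[16]x8xf32>`.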
206  VectorType getCanonicalVecType(
207  Type elementType,
208  std::optional<AffineMap> dimPermutation = std::nullopt) const {
209  SmallVector<int64_t> vectorShape;
210  SmallVector<bool> scalableDims;
211  if (dimPermutation.has_value()) {
212  vectorShape =
213  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214  scalableDims =
215  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216  } else {
217  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219  }
220 
221  return VectorType::get(vectorShape, elementType, scalableDims);
222  }
223 
224  /// Masks an operation with the canonical vector mask if the operation needs
225  /// masking. Returns the masked operation or the original operation if masking
226  /// is not needed. If provided, the canonical mask for this operation is
227  /// permuted using `maybeIndexingMap`.
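 /// For example, a sketch (assuming a transfer read is being masked with the
 /// canonical mask; the exact IR depends on the op being masked):
 ///
 /// ```
 ///   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
 ///   %0 = vector.mask %mask {
 ///     vector.transfer_read %src[%c0, %c0], %pad
 ///       : tensor<?x?xf32>, vector<8x16xf32>
 ///   } : vector<8x16xi1> -> vector<8x16xf32>
 /// ```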
228  Operation *
229  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230  std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231 
232 private:
233  /// Initializes the iteration space static sizes using the Linalg op
234  /// information. This may become more complicated in the future.
235  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237  }
238 
239  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240  /// all the static and dynamic dimensions of the iteration space to be
241  /// vectorized and store them in `iterSpaceValueSizes`.
242  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243  LinalgOp linalgOp);
244 
245  /// Create or retrieve an existing mask value to mask `opToMask` in the
246  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
247  /// permuted using that permutation map. If a new mask is created, it will be
248  /// cached for future users.
249  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250  LinalgOp linalgOp,
251  std::optional<AffineMap> maybeMaskingMap);
252 
253  /// Check whether this permutation map can be used for masking. At the
254  /// moment we only make sure that there are no broadcast dimensions, but this
255  /// might change if indexing maps evolve.
256  bool isValidMaskingMap(AffineMap maskingMap) {
257  return maskingMap.getBroadcastDims().size() == 0;
258  }
259 
260  /// Turn the input indexing map into a valid masking map.
261  ///
262  /// The input indexing map may contain "zero" results, e.g.:
263  /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264  /// Applying such maps to canonical vector shapes like this one:
265  /// (1, 16, 16, 4)
266  /// would yield an invalid vector shape like this:
267  /// (16, 16, 1, 0)
268  /// Instead, drop the broadcasting dims that make no sense for masking perm.
269  /// maps:
270  /// (d0, d1, d2, d3) -> (d2, d1, d0)
271  /// This way, the corresponding vector/mask type will be:
272  /// vector<16x16x1xty>
273  /// rather than this invalid Vector type:
274  /// vector<16x16x1x0xty>
275  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
276  return indexingMap.dropZeroResults();
277  }
278 
279  // Holds the compile-time static sizes of the iteration space to vectorize.
280  // Dynamic dimensions are represented using ShapedType::kDynamic.
281  SmallVector<int64_t> iterSpaceStaticSizes;
282 
283  /// Holds the value sizes of the iteration space to vectorize. Static
284  /// dimensions are represented by 'arith.constant' and dynamic
285  /// dimensions by 'tensor/memref.dim'.
286  SmallVector<Value> iterSpaceValueSizes;
287 
288  /// Holds the canonical vector shape used to vectorize the iteration space.
289  SmallVector<int64_t> canonicalVecShape;
290 
291  /// Holds the vector dimensions that are scalable in the canonical vector
292  /// shape.
293  SmallVector<bool> scalableVecDims;
294 
295  /// Holds the active masks for permutations of the canonical vector iteration
296  /// space.
297  DenseMap<AffineMap, Value> activeMaskCache;
298 
299  /// Global vectorization guard for the incoming rewriter. It's initialized
300  /// when the vectorization state is initialized.
301  OpBuilder::InsertionGuard rewriterGuard;
302 };
303 
304 LogicalResult
305 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
306  LinalgOp linalgOp) {
307  // TODO: Support 0-d vectors.
308  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
309  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
310  // Create constant index op for static dimensions.
311  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
312  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
313  continue;
314  }
315 
316  // Find an operand defined on this dimension of the iteration space to
317  // extract the runtime dimension size.
318  Value operand;
319  unsigned operandDimPos;
320  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
321  operandDimPos)))
322  return failure();
323 
324  Value dynamicDim = linalgOp.hasPureTensorSemantics()
325  ? (Value)rewriter.create<tensor::DimOp>(
326  linalgOp.getLoc(), operand, operandDimPos)
327  : (Value)rewriter.create<memref::DimOp>(
328  linalgOp.getLoc(), operand, operandDimPos);
329  iterSpaceValueSizes.push_back(dynamicDim);
330  }
331 
332  return success();
333 }
334 
335 /// Initializes the vectorization state, including the computation of the
336 /// canonical vector shape for vectorization.
337 // TODO: Move this to the constructor when we can remove the failure cases.
338 LogicalResult
339 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
340  ArrayRef<int64_t> inputVectorSizes,
341  ArrayRef<bool> inputScalableVecDims) {
342  // Initialize the insertion point.
343  rewriter.setInsertionPoint(linalgOp);
344 
345  if (!inputVectorSizes.empty()) {
346  // Get the canonical vector shape from the input vector sizes provided. This
347  // path should be taken to vectorize code with dynamic shapes and when using
348  // vector sizes greater than the iteration space sizes.
349  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
350  scalableVecDims.append(inputScalableVecDims.begin(),
351  inputScalableVecDims.end());
352  } else {
353  // Compute the canonical vector shape from the operation shape. If there are
354  // dynamic shapes, the operation won't be vectorized. We assume all the
355  // vector dimensions are fixed.
356  canonicalVecShape = linalgOp.getStaticLoopRanges();
357  scalableVecDims.append(linalgOp.getNumLoops(), false);
358  }
359 
360  LDBG("Canonical vector shape: ");
361  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
362  LLVM_DEBUG(llvm::dbgs() << "\n");
363  LDBG("Scalable vector dims: ");
364  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
365  LLVM_DEBUG(llvm::dbgs() << "\n");
366 
367  if (ShapedType::isDynamicShape(canonicalVecShape))
368  return failure();
369 
370  // Initialize iteration space static sizes.
371  initIterSpaceStaticSizes(linalgOp);
372 
373  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
374  // all the static and dynamic dimensions of the iteration space, needed to
375  // compute a mask during vectorization.
376  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
377  return failure();
378 
379  return success();
380 }
381 
382 /// Create or retrieve an existing mask value to mask `opToMask` in the
383 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
384 /// using that permutation map. If a new mask is created, it will be cached for
385 /// future users.
386 Value VectorizationState::getOrCreateMaskFor(
387  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
388  std::optional<AffineMap> maybeMaskingMap) {
389 
390  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
391  "Ill-formed masking map.");
392 
393  // No mask is needed if the operation is not maskable.
394  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
395  if (!maskableOp)
396  return Value();
397 
398  assert(!maskableOp.isMasked() &&
399  "Masking an operation that is already masked");
400 
401  // If no masking map was provided, use an identity map with the loop dims.
402  assert((!maybeMaskingMap || *maybeMaskingMap) &&
403  "Unexpected null mask permutation map");
404  AffineMap maskingMap =
405  maybeMaskingMap ? *maybeMaskingMap
406  : AffineMap::getMultiDimIdentityMap(
407  linalgOp.getNumLoops(), rewriter.getContext());
408 
409  LDBG("Masking map: " << maskingMap << "\n");
410 
411  // Return the active mask for the masking map of this operation if it was
412  // already created.
413  auto activeMaskIt = activeMaskCache.find(maskingMap);
414  if (activeMaskIt != activeMaskCache.end()) {
415  Value mask = activeMaskIt->second;
416  LDBG("Reusing mask: " << mask << "\n");
417  return mask;
418  }
419 
420  // Compute permuted projection of the iteration space to be masked and the
421  // corresponding mask shape. If the resulting iteration space dimensions are
422  // static and identical to the mask shape, masking is not needed for this
423  // operation.
424  // TODO: Improve this check. Only projected permutation indexing maps are
425  // supported.
426  SmallVector<int64_t> permutedStaticSizes =
427  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
428  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
429  auto maskShape = maskType.getShape();
430 
431  LDBG("Mask shape: ");
432  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
433  LLVM_DEBUG(llvm::dbgs() << "\n");
434 
435  if (permutedStaticSizes == maskShape) {
436  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
437  activeMaskCache[maskingMap] = Value();
438  return Value();
439  }
440 
441  // Permute the iteration space value sizes to compute the mask upper bounds.
442  SmallVector<Value> upperBounds =
443  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
444  assert(!maskShape.empty() && !upperBounds.empty() &&
445  "Masked 0-d vectors are not supported yet");
446 
447  // Create the mask based on the dimension values.
448  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
449  maskType, upperBounds);
450  LDBG("Creating new mask: " << mask << "\n");
451  activeMaskCache[maskingMap] = mask;
452  return mask;
453 }
454 
455 Operation *
456 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
457  LinalgOp linalgOp,
458  std::optional<AffineMap> maybeIndexingMap) {
459  LDBG("Trying to mask: " << *opToMask << "\n");
460 
461  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
462  if (maybeIndexingMap)
463  maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
464 
465  // Create or retrieve mask for this operation.
466  Value mask =
467  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
468 
469  if (!mask) {
470  LDBG("No mask required\n");
471  return opToMask;
472  }
473 
474  // Wrap the operation with a new `vector.mask` and update D-U chain.
475  assert(opToMask && "Expected a valid operation to mask");
476  auto maskOp = cast<vector::MaskOp>(
477  mlir::vector::maskOperation(rewriter, opToMask, mask));
478  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
479 
480  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
481  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
482  maskOpTerminator);
483 
484  LDBG("Masked operation: " << *maskOp << "\n");
485  return maskOp;
486 }
487 
488 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
489 /// projectedPermutation, compress the unused dimensions to serve as a
490 /// permutation_map for a vector transfer operation.
491 /// For example, given a linalg op such as:
492 ///
493 /// ```
494 /// %0 = linalg.generic {
495 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
496 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
497 /// }
498 /// ins(%0 : tensor<2x3x4xf32>)
499 /// outs(%1 : tensor<5x6xf32>)
500 /// ```
501 ///
502 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
503 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
504 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
505 static AffineMap reindexIndexingMap(AffineMap map) {
506  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
507  "expected projected permutation");
508  auto res = compressUnusedDims(map);
509  assert(res.getNumDims() ==
510  (res.getNumResults() - res.getNumOfZeroResults()) &&
511  "expected reindexed map with same number of dims and results");
512  return res;
513 }
514 
515 /// Helper enum to represent conv1d input traversal order.
516 enum class Conv1DOpOrder {
517  W, // Corresponds to non-channeled 1D convolution operation.
518  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
519  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
520 };
521 
522 /// Helper data structure to represent the result of vectorization.
523 /// In certain specific cases, like terminators, we do not want to propagate/
524 enum VectorizationStatus {
525  /// Op failed to vectorize.
526  Failure = 0,
527  /// Op vectorized and custom function took care of replacement logic
528  NoReplace,
529  /// Op vectorized into a new Op whose results will replace original Op's
530  /// results.
531  NewOp
532  // TODO: support values if Op vectorized to Many-Ops whose results we need to
533  // aggregate for replacement.
534 };
535 struct VectorizationResult {
536  /// Return status from vectorizing the current op.
537  enum VectorizationStatus status = VectorizationStatus::Failure;
538  /// New vectorized operation to replace the current op.
539  /// Replacement behavior is specified by `status`.
540  Operation *newOp = nullptr;
541 };
542 
543 std::optional<vector::CombiningKind>
544 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
545  using ::mlir::vector::CombiningKind;
546 
547  if (!combinerOp)
548  return std::nullopt;
549  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
550  .Case<arith::AddIOp, arith::AddFOp>(
551  [&](auto op) { return CombiningKind::ADD; })
552  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
553  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
554  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
555  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
556  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
557  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
558  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
559  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
560  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
561  .Case<arith::MulIOp, arith::MulFOp>(
562  [&](auto op) { return CombiningKind::MUL; })
563  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
564  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
565  .Default([&](auto op) { return std::nullopt; });
566 }
567 
568 /// Check whether `outputOperand` is a reduction with a single combiner
569 /// operation. Return the combiner operation of the reduction. Return
570 /// nullptr otherwise. Multiple reduction operations would impose an
571 /// ordering between reduction dimensions and are currently unsupported in
572 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
573 /// max(min(X)).
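/// For example (an illustrative sketch, not from the original source), in the
/// reduction below the combiner op returned is the `arith.addf`:
///
/// ```
///   linalg.generic {indexing_maps = [...],
///                   iterator_types = ["parallel", "reduction"]}
///       ins(%in : tensor<8x16xf32>) outs(%acc : tensor<8xf32>) {
///   ^bb0(%a: f32, %b: f32):
///     %0 = arith.addf %a, %b : f32
///     linalg.yield %0 : f32
///   } -> tensor<8xf32>
/// ```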
574 // TODO: use in LinalgOp verification, there is a circular dependency atm.
575 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
576  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
577  unsigned outputPos =
578  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
579  // Only single combiner operations are supported for now.
580  SmallVector<Operation *, 4> combinerOps;
581  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
582  combinerOps.size() != 1)
583  return nullptr;
584 
585  // Return the combiner operation.
586  return combinerOps[0];
587 }
588 
589 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
590 /// otherwise.
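/// For example (illustrative, not from the original source), broadcasting an
/// f32 scalar to `vector<8x16xf32>` produces:
///
/// ```
///   %b = vector.broadcast %scalar : f32 to vector<8x16xf32>
/// ```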
591 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
592  auto dstVecType = dyn_cast<VectorType>(dstType);
593  // If no shape to broadcast to, just return `value`.
594  if (dstVecType.getRank() == 0)
595  return value;
596  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
597  vector::BroadcastableToResult::Success)
598  return value;
599  Location loc = b.getInsertionPoint()->getLoc();
600  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
601 }
602 
603 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
604 /// assumes that `reductionOp` has two operands and one of them is the reduction
605 /// initial value.
606 // Note: this is a true builder that notifies the OpBuilder listener.
607 // TODO: Consider moving as a static helper on the ReduceOp.
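// For example (illustrative, not from the original source), reducing dim 1 of
// a vector with an ADD combiner yields:
//
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//     : vector<8x16xf32> to vector<8xf32>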
608 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
609  Value valueToReduce, Value acc,
610  ArrayRef<bool> dimsToMask) {
611  auto maybeKind = getCombinerOpKind(reduceOp);
612  assert(maybeKind && "Failed precondition: could not get reduction kind");
613  return b.create<vector::MultiDimReductionOp>(
614  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
615 }
616 
617 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
618  return llvm::to_vector(
619  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
620 }
621 
622 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
623 /// reduction iterator.
624 static bool hasReductionIterator(LinalgOp &op) {
625  return isa<linalg::ReduceOp>(op) ||
626  (isa<linalg::GenericOp>(op) &&
627  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
628 }
629 
630 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
631 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
632 /// currently being vectorized. If `dest` has null rank, build a memref.store.
633 /// Return the produced value or null if no value is produced.
634 // Note: this is a true builder that notifies the OpBuilder listener.
635 // TODO: Consider moving as a static helper on the ReduceOp.
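// For example (illustrative, not from the original source), writing a
// vector<8x16xf32> value into a matching tensor operand produces:
//
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//     {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>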
636 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
637  OpOperand *outputOperand,
638  VectorizationState &state) {
639  Location loc = value.getLoc();
640  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
641  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
642 
643  // Compute the vector type of the value to store. This type should be an
644  // identity or projection of the canonical vector type without any permutation
645  // applied, given that any permutation in a transfer write happens as part of
646  // the write itself.
647  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
648  opOperandMap.getContext(), opOperandMap.getNumInputs(),
649  [&](AffineDimExpr dimExpr) -> bool {
650  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
651  });
652  auto vectorType = state.getCanonicalVecType(
653  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
654 
655  Operation *write;
656  if (vectorType.getRank() > 0) {
657  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
658  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
659  rewriter.create<arith::ConstantIndexOp>(loc, 0));
660  value = broadcastIfNeeded(rewriter, value, vectorType);
661  assert(value.getType() == vectorType && "Incorrect type");
662  write = rewriter.create<vector::TransferWriteOp>(
663  loc, value, outputOperand->get(), indices, writeMap);
664  } else {
665  // 0-d case is still special: do not invert the reindexing writeMap.
666  if (!isa<VectorType>(value.getType()))
667  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
668  assert(value.getType() == vectorType && "Incorrect type");
669  write = rewriter.create<vector::TransferWriteOp>(
670  loc, value, outputOperand->get(), ValueRange{});
671  }
672 
673  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
674 
675  // If masked, set in-bounds to true. Masking guarantees that the access will
676  // be in-bounds.
677  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
678  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
679  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
680  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
681  }
682 
683  LDBG("vectorized op: " << *write << "\n");
684  if (!write->getResults().empty())
685  return write->getResult(0);
686  return Value();
687 }
688 
689 // Custom vectorization precondition function type. This is intended to be used
690 // with CustomVectorizationHook. Returns success if the corresponding custom
691 // hook can vectorize the op.
692 using CustomVectorizationPrecondition =
693  std::function<LogicalResult(Operation *, bool)>;
694 
695 // Custom vectorization function type. Produce a vector form of Operation*
696 // assuming all its vectorized operands are already in the IRMapping.
697 // Return nullptr if the Operation cannot be vectorized.
698 using CustomVectorizationHook =
699  std::function<VectorizationResult(Operation *, const IRMapping &)>;
700 
701 /// Helper function to vectorize the terminator of a `linalgOp`. New result
702 /// vector values are appended to `newResults`. Return
703 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
704 /// should not try to map produced operations and instead return the results
705 /// using the `newResults` vector making them available to the vectorization
706 /// algorithm for RAUW. This function is meant to be used as a
707 /// CustomVectorizationHook.
708 static VectorizationResult
709 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
710  const IRMapping &bvm, VectorizationState &state,
711  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
712  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
713  if (!yieldOp)
714    return VectorizationResult{VectorizationStatus::Failure, nullptr};
715  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
716  // TODO: Scan for an opportunity for reuse.
717  // TODO: use a map.
718  Value vectorValue = bvm.lookup(output.value());
719  Value newResult =
720  buildVectorWrite(rewriter, vectorValue,
721  linalgOp.getDpsInitOperand(output.index()), state);
722  if (newResult)
723  newResults.push_back(newResult);
724  }
725 
726  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
727 }
728 
729 /// Helper function to vectorize the index operations of a `linalgOp`. Return
730 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
731 /// should map the produced operations. This function is meant to be used as a
732 /// CustomVectorizationHook.
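/// For example (illustrative, not from the original source), vectorizing
/// `linalg.index 0` in an 8x16 iteration space produces roughly:
///
/// ```
///   %step = vector.step : vector<8xindex>
///   %bcast = vector.broadcast %step : vector<8xindex> to vector<16x8xindex>
///   %idx = vector.transpose %bcast, [1, 0]
///     : vector<16x8xindex> to vector<8x16xindex>
/// ```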
733 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
734  VectorizationState &state,
735  Operation *op,
736  LinalgOp linalgOp) {
737  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
738  if (!indexOp)
739    return VectorizationResult{VectorizationStatus::Failure, nullptr};
740  auto loc = indexOp.getLoc();
741  // Compute the static loop sizes of the index op.
742  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
743  auto dim = indexOp.getDim();
744  // Compute a one-dimensional index vector for the index op dimension.
745  auto indexVectorType =
746  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
747  state.getScalableVecDims()[dim]);
748  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
749  // Return the one-dimensional index vector if it lives in the trailing
750  // dimension of the iteration space since the vectorization algorithm in this
751  // case can handle the broadcast.
752  if (dim == targetShape.size() - 1)
753    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
754  // Otherwise permute the targetShape to move the index dimension last,
755  // broadcast the one-dimensional index vector to the permuted shape, and
756  // finally transpose the broadcasted index vector to undo the permutation.
757  auto permPattern =
758  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
759  std::swap(permPattern[dim], permPattern.back());
760  auto permMap =
761  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
762 
763  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
764  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
765  indexSteps);
766  SmallVector<int64_t> transposition =
767  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
768  std::swap(transposition.back(), transposition[dim]);
769  auto transposeOp =
770  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
771  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
772 }
773 
774 /// Helper function to check if the tensor.extract can be vectorized by the
775 /// custom hook vectorizeTensorExtract.
776 static LogicalResult
777 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
778  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
779  if (!extractOp)
780  return failure();
781 
782  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
783  return failure();
784 
785  // Check the index type, but only for non 0-d tensors (for which we do need
786  // access indices).
787  if (not extractOp.getIndices().empty()) {
788  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
789  return failure();
790  }
791 
792  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
793  return !VectorType::isValidElementType(type);
794  })) {
795  return failure();
796  }
797 
798  return success();
799 }
800 
801 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
802 /// generated from `tensor.extract`. The offset is calculated as follows
803 /// (example using scalar values):
804 ///
805 /// offset = extractOp.indices[0]
806 /// for (i = 1; i < numIndices; i++)
807 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
808 ///
809 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
810 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 = 1233
811 static Value calculateGatherOffset(RewriterBase &rewriter,
812  VectorizationState &state,
813  tensor::ExtractOp extractOp,
814  const IRMapping &bvm) {
815  // The vector of indices for GatherOp should be shaped as the output vector.
816  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
817  auto loc = extractOp.getLoc();
818 
819  Value offset = broadcastIfNeeded(
820  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
821 
822  const size_t numIndices = extractOp.getIndices().size();
823  for (size_t i = 1; i < numIndices; i++) {
824  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
825 
826  auto dimSize = broadcastIfNeeded(
827  rewriter,
828  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
829  indexVecType);
830 
831  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
832 
833  auto extractOpIndex = broadcastIfNeeded(
834  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
835 
836  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
837  }
838 
839  return offset;
840 }
841 
842 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
843 
844 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
845 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
846 /// represents a contiguous load operation.
847 ///
848 /// Note that when calling this hook, it is assumed that the output vector is
849 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
850 /// labelled as a gather load before entering this method.
851 ///
852 /// Following on from the above, it is assumed that:
853 /// * for statically shaped loops, when no masks are used, only one dim is !=
854 /// 1 (that's what the shape of the output vector is based on).
855 /// * for dynamically shaped loops, there might be more non-unit dims
856 /// as the output vector type is user-specified.
857 ///
858 /// TODO: Statically shaped loops + vector masking
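/// For example (illustrative, not from the original source), for static loop
/// ranges (1, 1, 128, 1) this returns 2, the index of the only non-unit dim.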
859 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
860  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
861  assert(
862  (linalgOp.hasDynamicShape() ||
863  llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
864  "For statically shaped Linalg Ops, only one "
865  "non-unit loop dim is expected");
866  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
867 
868  size_t idx = loopRanges.size() - 1;
869  for (; idx != 0; idx--)
870  if (loopRanges[idx] != 1)
871  break;
872 
873  return idx;
874 }
875 
876 /// Checks whether `val` can be used for calculating a loop invariant index.
877 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
878  VectorType resType) {
879 
880  assert(((llvm::count_if(resType.getShape(),
881  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
882  "n-D vectors are not yet supported");
883 
884  // Blocks outside _this_ linalg.generic are effectively loop invariant.
885  // However, analysing block arguments for _this_ linalg.generic Op is a bit
886  // tricky. Just bail out in the latter case.
887  // TODO: We could try analysing the corresponding affine map here.
888  auto *block = linalgOp.getBlock();
889  if (isa<BlockArgument>(val))
890  return llvm::all_of(block->getArguments(),
891  [&val](Value v) { return (v != val); });
892 
893  Operation *defOp = val.getDefiningOp();
894  assert(defOp && "This is neither a block argument nor an operation result");
895 
896  // IndexOp is loop invariant as long as its result remains constant across
897  // iterations. Note that for dynamic shapes, the corresponding dim will also
898  // be conservatively treated as != 1.
899  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
900  return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
901  }
902 
903  auto *ancestor = block->findAncestorOpInBlock(*defOp);
904 
905  // Values define outside `linalgOp` are loop invariant.
906  if (!ancestor)
907  return true;
908 
909  // Values defined inside `linalgOp`, which are constant, are loop invariant.
910  if (isa<arith::ConstantOp>(ancestor))
911  return true;
912 
913  bool result = true;
914  for (auto op : ancestor->getOperands())
915  result &= isLoopInvariantIdx(linalgOp, op, resType);
916 
917  return result;
918 }
919 
920 /// Check whether `val` could be used for calculating the trailing index for a
921 /// contiguous load operation.
922 ///
923 /// There are currently 3 types of values that are allowed here:
924 /// 1. loop-invariant values,
925 /// 2. values that increment by 1 with every loop iteration,
926 /// 3. results of basic arithmetic operations (linear and continuous)
927 /// involving 1., 2. and 3.
928 /// This method returns True if indeed only such values are used in calculating
929 /// `val`.
930 ///
931 /// Additionally, the trailing index for a contiguous load operation should
932 /// increment by 1 with every loop iteration, i.e. be based on:
933 /// * `linalg.index <dim>` ,
934 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
935 /// `linalg.index <dim>` increments by 1 with every loop iteration).
936 /// `foundIndexOp` is updated to `true` when such Op is found.
937 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
938  bool &foundIndexOp, VectorType resType) {
939 
940  assert(((llvm::count_if(resType.getShape(),
941  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942  "n-D vectors are not yet supported");
943 
944  // Blocks outside _this_ linalg.generic are effectively loop invariant.
945  // However, analysing block arguments for _this_ linalg.generic Op is a bit
946  // tricky. Just bail out in the latter case.
947  // TODO: We could try analysing the corresponding affine map here.
948  auto *block = linalgOp.getBlock();
949  if (isa<BlockArgument>(val))
950  return llvm::all_of(block->getArguments(),
951  [&val](Value v) { return (v != val); });
952 
953  Operation *defOp = val.getDefiningOp();
954  assert(defOp && "This is neither a block argument nor an operation result");
955 
956  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
957  auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
958 
959  foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
960  return true;
961  }
962 
963  auto *ancestor = block->findAncestorOpInBlock(*defOp);
964 
965  if (!ancestor)
966  return false;
967 
968  // Conservatively reject Ops that could lead to indices with stride other
969  // than 1.
970  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
971  return false;
972 
973  bool result = false;
974  for (auto op : ancestor->getOperands())
975  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
976 
977  return result;
978 }
979 
980 /// Infer the memory access pattern for the input ExtractOp
981 ///
982 /// Based on the ExtractOp result shape and the access indices, decides whether
983 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
984 /// or a gather load. When analysing the ExtractOp indices (to identify
985 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
986 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
987 /// Op).
988 ///
989 /// Note that it is always safe to use gather load operations for contiguous
990 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
991 /// that `extractOp` is a gather load.
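/// For example (an illustrative sketch, not from the original source), with
/// `%i` produced by `linalg.index 1` over the trailing loop dimension, the
/// first extract below reads consecutive elements and can be classified as a
/// contiguous load, whereas the second, whose trailing index `%j` is itself
/// loaded from another tensor, is classified as a gather:
///
/// ```
///   %i = linalg.index 1 : index
///   %a = tensor.extract %src[%c0, %i] : tensor<4x128xf32>
///   %j = tensor.extract %indices[%i] : tensor<128xindex>
///   %b = tensor.extract %src[%c0, %j] : tensor<4x128xf32>
/// ```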
992 static VectorMemoryAccessKind
993 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
994  LinalgOp &linalgOp, VectorType resType) {
995 
996  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
997 
998  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
999  if (inputShape.getShape().empty())
1000    return VectorMemoryAccessKind::ScalarBroadcast;
1001 
1002  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1003  // otherwise.
1004  bool isOutput1DVector =
1005  (llvm::count_if(resType.getShape(),
1006  [](int64_t dimSize) { return dimSize > 1; }) == 1);
1007  // 1. Assume that it's a gather load when reading non-1D vector.
1008  if (!isOutput1DVector)
1009    return VectorMemoryAccessKind::Gather;
1010 
1011  bool leadingIdxsLoopInvariant = true;
1012 
1013  // 2. Analyze the leading indices of `extractOp`.
1014  // Look at the way each index is calculated and decide whether it is suitable
1015  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1016  // gather load.
1017  auto indices = extractOp.getIndices();
1018  auto leadIndices = indices.drop_back(1);
1019 
1020  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1021  if (inputShape.getShape()[i] == 1)
1022  continue;
1023 
1024  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1025  }
1026 
1027  if (!leadingIdxsLoopInvariant) {
1028  LDBG("Found gather load: " << extractOp);
1029    return VectorMemoryAccessKind::Gather;
1030  }
1031 
1032  // 3. Analyze the trailing index for `extractOp`.
1033  // At this point we know that the leading indices are loop invariant. This
1034  // means that is potentially a scalar or a contiguous load. We can decide
1035  // means that it is potentially a scalar or a contiguous load. We can decide
1036  auto extractOpTrailingIdx = indices.back();
1037 
1038  // 3a. Scalar broadcast load
1039  // If the trailing index is loop invariant then this is a scalar load.
1040  if (leadingIdxsLoopInvariant &&
1041  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1042  LDBG("Found scalar broadcast load: " << extractOp);
1043 
1044    return VectorMemoryAccessKind::ScalarBroadcast;
1045  }
1046 
1047  // 3b. Contiguous loads
1048  // The trailing `extractOp` index should increment with every loop iteration.
1049  // This effectively means that it must be based on the trailing loop index.
1050  // This is what the following bool captures.
1051  bool foundIndexOp = false;
1052  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1053  foundIndexOp, resType);
1054  // TODO: Support generating contiguous loads for column vectors - that will
1055  // require adding a permutation map to transfer_read Ops.
1056  bool isRowVector = resType.getShape().back() != 1;
1057  isContiguousLoad &= (foundIndexOp && isRowVector);
1058 
1059  if (isContiguousLoad) {
1060  LDBG("Found contiguous load: " << extractOp);
1061    return VectorMemoryAccessKind::Contiguous;
1062  }
1063 
1064  // 4. Fallback case - gather load.
1065  LDBG("Found gather load: " << extractOp);
1066  return VectorMemoryAccessKind::Gather;
1067 }
1068 
1069 /// Helper function to vectorize the tensor.extract operations. Returns
1070 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1071 /// should map the produced operations. This function is meant to be used as a
1072 /// CustomVectorizationHook.
1073 static VectorizationResult
1074 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1075  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1076  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1077  if (!extractOp)
1078    return VectorizationResult{VectorizationStatus::Failure, nullptr};
1079  auto loc = extractOp.getLoc();
1080 
1081  // Compute the static loop sizes of the extract op.
1082  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1083  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1084  loc,
1085  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1086  /*value=*/true));
1087  auto passThruConstantOp =
1088  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1089 
1090  // Base indices are currently set to 0. We will need to re-visit if more
1091  // generic scenarios are to be supported.
1092  SmallVector<Value> baseIndices(
1093  extractOp.getIndices().size(),
1094  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1095 
1096  VectorMemoryAccessKind memAccessKind =
1097  getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1098 
1099  // 1. Handle gather access
1100  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1101  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1102 
1103  // Generate the gather load
1104  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1105  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1106  maskConstantOp, passThruConstantOp);
1107  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1108 
1109  LDBG("Vectorised as gather load: " << extractOp << "\n");
1110    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1111  }
1112 
1113  // 2. Handle:
1114  // a. scalar loads + broadcast,
1115  // b. contiguous loads.
1116  // Both cases use vector.transfer_read.
1117 
1118  // Collect indices for `vector.transfer_read`. At this point, the indices will
1119  // either be scalars or would have been broadcast to vectors matching the
1120  // result type. For indices that are vectors, there are two options:
1121  // * for non-trailing indices, all elements are identical (contiguous
1122  // loads are identified by looking for non-trailing indices that are
1123  // invariant with respect to the corresponding linalg.generic), or
1124  // * for trailing indices, the index vector will contain values with stride
1125  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1126  // needed.
1127  // This means that
1128  // * for scalar indices - just re-use it,
1129  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1130  // (0th) element and use that.
1131  SmallVector<Value> transferReadIdxs;
1132  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1133  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1134  if (idx.getType().isIndex()) {
1135  transferReadIdxs.push_back(idx);
1136  continue;
1137  }
1138 
1139  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1140  loc,
1141  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1142  resultType.getScalableDims().back()),
1143  idx);
1144  transferReadIdxs.push_back(
1145  rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1146  }
1147 
1148  // `tensor.extract_element` is always in-bounds, hence the following holds.
1149  auto dstRank = resultType.getRank();
1150  auto srcRank = extractOp.getTensor().getType().getRank();
1151  SmallVector<bool> inBounds(dstRank, true);
1152 
1153  // 2a. Handle scalar broadcast access.
1154  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1155  MLIRContext *ctx = rewriter.getContext();
1156  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1157  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1158 
1159  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1160  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1161  permutationMap, inBounds);
1162 
1163  // Mask this broadcasting xfer_read here rather than relying on the generic
1164  // path (the generic path assumes identity masking map, which wouldn't be
1165  // valid here).
1166  SmallVector<int64_t> readMaskShape = {1};
1167  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1168  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
1169  loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1170  auto *maskedReadOp =
1171  mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1172 
1173  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1174  return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1175  }
1176 
1177  // 2b. Handle contiguous access.
1178  auto permutationMap = AffineMap::getMinorIdentityMap(
1179  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1180 
1181  int32_t rankDiff = dstRank - srcRank;
1182  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1183  // dims so that the ranks match. This is done by extending the map with 0s.
1184  // For example, for dstRank = 3, srcRank = 2, the following map created
1185  // above:
1186  // (d0, d1) --> (d0, d1)
1187  // is extended as:
1188  // (d0, d1) --> (0, d0, d1)
1189  while (rankDiff > 0) {
1190  permutationMap = permutationMap.insertResult(
1191  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1192  rankDiff--;
1193  }
1194 
1195  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1196  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1197  inBounds);
1198 
1199  LDBG("Vectorised as contiguous load: " << extractOp);
1200  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1201 }
1202 
1203 /// Emit reduction operations if the shape of the value to reduce is different
1204 /// from the result shape.
1205 // Note: this is a true builder that notifies the OpBuilder listener.
1206 // TODO: Consider moving as a static helper on the ReduceOp.
1207 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1208  Value reduceValue, Value initialValue,
1209  const IRMapping &bvm) {
1210  Value reduceVec = bvm.lookup(reduceValue);
1211  Value outputVec = bvm.lookup(initialValue);
1212  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1213  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1214  // Reduce only if needed as the value may already have been reduced for
1215  // contraction vectorization.
1216  if (!reduceType ||
1217  (outputType && reduceType.getShape() == outputType.getShape()))
1218  return nullptr;
1219  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1220  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1221 }
1222 
1223 /// Generic vectorization for a single operation `op`, given already vectorized
1224 /// operands carried by `bvm`. Vectorization occurs as follows:
1225 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1226 /// result on success.
1227 /// 2. Clone any constant in the current scope without vectorization: each
1228 /// consumer of the constant will later determine the shape to which the
1229 /// constant needs to be broadcast.
1230 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1231 /// of the `customVectorizationHooks` to cover such cases.
1232 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1233 /// operand of maximal rank. Other operands have smaller rank and are
1234 /// broadcast accordingly. It is assumed this broadcast is always legal,
1235 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1236 ///
1237 /// This function assumes all operands of `op` have been vectorized and are in
1238 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1239 /// a topologically-sorted list of ops.
1240 /// This function does not update `bvm` but returns a VectorizationStatus that
1241 /// instructs the caller what `bvm` update needs to occur.
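/// For example (illustrative, not from the original source), a scalar
/// `arith.addf %a, %b : f32` in the body, whose operands were vectorized to
/// `vector<8x16xf32>` values, is rebuilt as:
///
/// ```
///   %0 = arith.addf %va, %vb : vector<8x16xf32>
/// ```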
1242 static VectorizationResult
1243 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1244  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1245  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1246  LDBG("vectorize op " << *op << "\n");
1247 
1248  // 1. Try to apply any CustomVectorizationHook.
1249  if (!customVectorizationHooks.empty()) {
1250  for (auto &customFunc : customVectorizationHooks) {
1251  VectorizationResult result = customFunc(op, bvm);
1252  if (result.status == VectorizationStatus::Failure)
1253  continue;
1254  return result;
1255  }
1256  }
1257 
1258  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1259  // Clone so that the constant is not confined to the linalgOp block.
1260  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1261  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1262 
1263  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1264  if (!OpTrait::hasElementwiseMappableTraits(op))
1265    return VectorizationResult{VectorizationStatus::Failure, nullptr};
1266 
1267  // 4. Check if the operation is a reduction.
1268  SmallVector<std::pair<Value, Value>> reductionOperands;
1269  for (Value operand : op->getOperands()) {
1270  auto blockArg = dyn_cast<BlockArgument>(operand);
1271  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1272  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1273  continue;
1274  SmallVector<Operation *> reductionOps;
1275  Value reduceValue = matchReduction(
1276  linalgOp.getRegionOutputArgs(),
1277  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1278  if (!reduceValue)
1279  continue;
1280  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1281  }
1282  if (!reductionOperands.empty()) {
1283  assert(reductionOperands.size() == 1);
1284  Operation *reduceOp =
1285  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1286  reductionOperands[0].second, bvm);
1287  if (reduceOp)
1288      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1289  }
1290 
1291  // 5. Generic vectorization path for ElementwiseMappable ops.
1292  // a. Get the first max ranked shape.
1293  VectorType firstMaxRankedType;
1294  for (Value operand : op->getOperands()) {
1295  auto vecOperand = bvm.lookup(operand);
1296  assert(vecOperand && "Vector operand couldn't be found");
1297 
1298  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1299  if (vecType && (!firstMaxRankedType ||
1300  firstMaxRankedType.getRank() < vecType.getRank()))
1301  firstMaxRankedType = vecType;
1302  }
1303  // b. Broadcast each op if needed.
1304  SmallVector<Value> vecOperands;
1305  for (Value scalarOperand : op->getOperands()) {
1306  Value vecOperand = bvm.lookup(scalarOperand);
1307  assert(vecOperand && "Vector operand couldn't be found");
1308 
1309  if (firstMaxRankedType) {
1310  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1311  getElementTypeOrSelf(vecOperand.getType()),
1312  firstMaxRankedType.getScalableDims());
1313  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1314  } else {
1315  vecOperands.push_back(vecOperand);
1316  }
1317  }
1318  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1319  SmallVector<Type> resultTypes;
1320  for (Type resultType : op->getResultTypes()) {
1321  resultTypes.push_back(
1322  firstMaxRankedType
1323  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1324  firstMaxRankedType.getScalableDims())
1325  : resultType);
1326  }
1327  // d. Build and return the new op.
1328  return VectorizationResult{
1329      VectorizationStatus::NewOp,
1330  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1331  resultTypes, op->getAttrs())};
1332 }
1333 
1334 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1335 /// vector form. Generic vectorization proceeds as follows:
1336 /// 1. Verify the `linalgOp` has one non-empty region.
1337 /// 2. Values defined above the region are mapped to themselves and will be
1338 /// broadcasted on a per-need basis by their consumers.
1339 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1340 /// load).
1341 /// TODO: Reuse opportunities for RAR dependencies.
1342 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1343 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1344 /// iteration indices.
1345 /// 5. Iteratively call vectorizeOneOp on the region operations.
1346 ///
1347 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1348 /// performed to the maximal common vector size implied by the `linalgOp`
1349 /// iteration space. This eager broadcasting is introduced in the
1350 /// permutation_map of the vector.transfer_read operations. The eager
1351 /// broadcasting makes it trivial to determine where broadcast, transposes and
1352 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1353 /// the absence of good canonicalizations, the amount of work increases.
1354 /// This is not deemed a problem as we expect canonicalizations and foldings to
1355 /// aggressively clean up the useless work.
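/// For example (an illustrative sketch, not from the original source), an
/// elementwise addition expressed as a linalg.generic on tensor<8x16xf32>
/// operands is rewritten roughly as:
///
/// ```
///   %lhs = vector.transfer_read %A[%c0, %c0], %cst
///     : tensor<8x16xf32>, vector<8x16xf32>
///   %rhs = vector.transfer_read %B[%c0, %c0], %cst
///     : tensor<8x16xf32>, vector<8x16xf32>
///   %add = arith.addf %lhs, %rhs : vector<8x16xf32>
///   %res = vector.transfer_write %add, %C[%c0, %c0]
///     : vector<8x16xf32>, tensor<8x16xf32>
/// ```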
1356 static LogicalResult
1357 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1358  LinalgOp linalgOp,
1359  SmallVectorImpl<Value> &newResults) {
1360  LDBG("Vectorizing operation as linalg generic\n");
1361  Block *block = linalgOp.getBlock();
1362 
1363  // 2. Values defined above the region can only be broadcast for now. Make them
1364  // map to themselves.
1365  IRMapping bvm;
1366  SetVector<Value> valuesSet;
1367  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1368  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1369 
1370  if (linalgOp.getNumDpsInits() == 0)
1371  return failure();
1372 
1373  // 3. Turn all BBArgs into vector.transfer_read / load.
1374  Location loc = linalgOp.getLoc();
1375  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1376  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1377  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1378  if (linalgOp.isScalar(opOperand)) {
1379  bvm.map(bbarg, opOperand->get());
1380  continue;
1381  }
1382 
1383  // 3.a. Convert the indexing map for this input/output to a transfer read
1384  // permutation map and masking map.
1385  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1386 
1387  AffineMap readMap;
1388  VectorType readType;
1389  Type elemType = getElementTypeOrSelf(opOperand->get());
1390  if (linalgOp.isDpsInput(opOperand)) {
1391  // 3.a.i. For input reads we use the canonical vector shape.
1392  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1393  readType = state.getCanonicalVecType(elemType);
1394  } else {
1395  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1396  // reductions), the vector shape is computed by mapping the canonical
1397  // vector shape to the output domain and back to the canonical domain.
1398  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1399  readType =
1400  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1401  }
1402 
1403  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1404 
1405  Operation *read = rewriter.create<vector::TransferReadOp>(
1406  loc, readType, opOperand->get(), indices, readMap);
1407  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1408  Value readValue = read->getResult(0);
1409 
1410  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1411  // will be in-bounds.
1412  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1413  SmallVector<bool> inBounds(readType.getRank(), true);
1414  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1415  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1416  }
1417 
1418  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1419  // TODO: remove this.
1420  if (readType.getRank() == 0)
1421  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1422  ArrayRef<int64_t>());
1423 
1424  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1425  << "\n");
1426  bvm.map(bbarg, readValue);
1427  bvm.map(opOperand->get(), readValue);
1428  }
1429 
1430  SmallVector<CustomVectorizationHook> hooks;
1431  // 4a. Register CustomVectorizationHook for yieldOp.
1432  CustomVectorizationHook vectorizeYield =
1433  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1434  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1435  };
1436  hooks.push_back(vectorizeYield);
1437 
1438  // 4b. Register CustomVectorizationHook for indexOp.
1439  CustomVectorizationHook vectorizeIndex =
1440  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1441  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1442  };
1443  hooks.push_back(vectorizeIndex);
1444 
1445  // 4c. Register CustomVectorizationHook for extractOp.
1446  CustomVectorizationHook vectorizeExtract =
1447  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1448  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1449  };
1450  hooks.push_back(vectorizeExtract);
1451 
1452  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1453  for (Operation &op : block->getOperations()) {
1454  VectorizationResult result =
1455  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1456  if (result.status == VectorizationStatus::Failure) {
1457  LDBG("failed to vectorize: " << op << "\n");
1458  return failure();
1459  }
1460  if (result.status == VectorizationStatus::NewOp) {
1461  Operation *maybeMaskedOp =
1462  state.maskOperation(rewriter, result.newOp, linalgOp);
1463  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1464  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1465  }
1466  }
1467 
1468  return success();
1469 }
1470 
1471 /// Given a tensor::PackOp, return the `dest` shape before any packing
1472 /// permutations.
1473 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1474  ArrayRef<int64_t> destShape) {
1475  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1476 }
1477 
1478 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1479 /// create an empty destination tensor and create a TransferWriteOp from the
1480 /// input to the empty tensor. If the destination shape is not the same as the
1481 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1482 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1483 /// inBounds attribute of the transfer write op instead of masking.
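///
/// As a rough sketch (illustrative only; names and shapes are made up), for an
/// input of type vector<8x16xf32>, destSizes (%d0, 16) and inputVectorSizes
/// [8, 16], the generated IR is along the lines of:
///   %empty = tensor.empty(%d0) : tensor<?x16xf32>
///   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
///   %w = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///       : vector<8x16xf32>, tensor<?x16xf32>
///   } : vector<8x16xi1> -> tensor<?x16xf32>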
1484 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1485  Value input,
1486  SmallVector<OpFoldResult> destSizes,
1487  ArrayRef<int64_t> inputVectorSizes,
1488  bool useInBoundsInsteadOfMasking) {
1489 
1490  auto inputType = cast<VectorType>(input.getType());
1491  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1492  inputType.getElementType());
1493  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1494  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1495  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1496  SmallVector<bool> inBoundsVal(rank, true);
1497  if (useInBoundsInsteadOfMasking) {
1498  // Update the inBounds attribute.
1499  for (unsigned i = 0; i < rank; i++)
1500  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1501  !ShapedType::isDynamic(destShape[i]);
1502  }
1503  Operation *write = builder.create<vector::TransferWriteOp>(
1504  loc,
1505  /*vector=*/input,
1506  /*source=*/dest,
1507  /*indices=*/SmallVector<Value>(rank, zero),
1508  /*inBounds=*/inBoundsVal);
1509  assert(llvm::none_of(
1510  destShape.drop_front(inputVectorSizes.size()),
1511  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1512  "Only dims aligned with inputVectorSizes may be dynamic");
1513  if (useInBoundsInsteadOfMasking)
1514  return write;
1515  bool needMaskForWrite = !llvm::equal(
1516  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1517  if (needMaskForWrite) {
1518  SmallVector<int64_t> writeMaskShape;
1519  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1520  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1521  destShape.end());
1522  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1523  Value maskForWrite =
1524  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1525  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1526  }
1527  return write;
1528 }
1529 
1530 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1531 /// padding value and (3) input vector sizes into:
1532 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1533 /// As in the following example:
1534 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1535 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1536 ///
1537 /// This pack would be vectorized to:
1538 ///
1539 /// %load = vector.mask %mask {
1540 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1541 /// {in_bounds = [true, true, true]} :
1542 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1543 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1544 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1545 /// to vector<32x4x2x1x16xf32>
1546 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1547 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1548 /// %write = vector.transfer_write %transpose,
1549 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1550 /// {in_bounds = [true, true, true, true, true]}
1551 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1552 ///
1553 /// If the (3) input vector sizes are not provided, the vector sizes are
1554 /// determined by the result tensor shape. Also, we update the inBounds
1555 /// attribute instead of masking.
1556 static LogicalResult
1557 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1558  ArrayRef<int64_t> inputVectorSizes,
1559  SmallVectorImpl<Value> &newResults) {
1560  OpBuilder::InsertionGuard g(rewriter);
1561  rewriter.setInsertionPoint(packOp);
1562 
1563  Location loc = packOp.getLoc();
1564  auto padValue = packOp.getPaddingValue();
1565  if (!padValue) {
1566  padValue = rewriter.create<arith::ConstantOp>(
1567  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1568  }
1569  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1570  LogicalResult status =
1571  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1572  .reifyResultShapes(rewriter, reifiedReturnShapes);
1573  (void)status; // prevent unused variable warning on non-assert builds.
1574  assert(succeeded(status) && "failed to reify result shapes");
1575 
1576  // If the input vector sizes are not provided, the vector sizes are
1577  // determined by the result tensor shape. In that case we also update the
1578  // inBounds attribute instead of masking.
1579  bool useInBoundsInsteadOfMasking = false;
1580  if (inputVectorSizes.empty()) {
1581  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1582  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1583  useInBoundsInsteadOfMasking = true;
1584  }
1585 
1586  // Create masked TransferReadOp.
1587  SmallVector<int64_t> inputShape(inputVectorSizes);
1588  auto innerTiles = packOp.getStaticInnerTiles();
1589  auto innerDimsPos = packOp.getInnerDimsPos();
1590  auto outerDimsPerm = packOp.getOuterDimsPerm();
1591  if (!outerDimsPerm.empty())
1592  applyPermutationToVector(inputShape,
1593  invertPermutationVector(outerDimsPerm));
1594  for (auto [idx, size] : enumerate(innerTiles))
1595  inputShape[innerDimsPos[idx]] *= size;
1596  auto maskedRead = vector::createReadOrMaskedRead(
1597  rewriter, loc, packOp.getSource(), inputShape, padValue,
1598  useInBoundsInsteadOfMasking);
1599 
1600  // Create ShapeCastOp.
1601  SmallVector<int64_t> destShape(inputVectorSizes);
1602  destShape.append(innerTiles.begin(), innerTiles.end());
1603  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1604  packOp.getDestType().getElementType());
1605  auto shapeCastOp =
1606  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1607 
1608  // Create TransposeOp.
1609  auto destPermutation =
1610  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1611  auto transposeOp = rewriter.create<vector::TransposeOp>(
1612  loc, shapeCastOp.getResult(), destPermutation);
1613 
1614  // Create TransferWriteOp.
1615  Operation *write = createWriteOrMaskedWrite(
1616  rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1617  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1618  newResults.push_back(write->getResult(0));
1619  return success();
1620 }
1621 
1622 /// Vectorize a `tensor::UnPackOp` into the following sequence of 4 ops:
1623 /// vector::TransferReadOp - reads a vector from the source tensor,
1624 /// vector::TransposeOp - transposes the read vector,
1625 /// vector::ShapeCastOp - reshapes the data based on the target,
1626 /// vector::TransferWriteOp - writes the result vector back to the destination
1627 /// tensor.
1628 /// If the vector sizes are not provided:
1629 /// * the vector sizes are determined by the input operand and attributes,
1630 /// * the inBounds attribute is updated instead of masking.
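///
/// As a rough illustration (a hand-written sketch with made-up shapes, not
/// verbatim output), unpacking a tensor<2x4x16x2xf32> (inner_dims_pos = [0, 1],
/// inner_tiles = [16, 2]) back into a tensor<32x8xf32> would produce IR along
/// the lines of:
///   %read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
///     : tensor<2x4x16x2xf32>, vector<2x4x16x2xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///     : vector<2x4x16x2xf32> to vector<2x16x4x2xf32>
///   %sc = vector.shape_cast %tr : vector<2x16x4x2xf32> to vector<32x8xf32>
///   %w = vector.transfer_write %sc, %dest[%c0, %c0]
///     : vector<32x8xf32>, tensor<32x8xf32>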
1631 static LogicalResult
1632 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1633  ArrayRef<int64_t> inputVectorSizes,
1634  SmallVectorImpl<Value> &newResults) {
1635 
1636  OpBuilder::InsertionGuard g(rewriter);
1637  rewriter.setInsertionPoint(unpackOp);
1638 
1639  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1640 
1641  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1642  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1643  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1644  bool useInBoundsInsteadOfMasking = false;
1645  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1646 
1647  auto destSize = unpackOp.getDestRank();
1648 
1649  if (!inputVectorSizes.empty())
1650  assert(inputVectorSizes.size() == destSize &&
1651  "Incorrect number of input vector sizes");
1652 
1653  // vectorSizes is the shape of the vector that will be used for the final
1654  // write on the destination tensor. It is set as follows: let's say the
1655  // source tensor has rank 'M' and the dest tensor rank 'N', where N <= M.
1656  // Thus:
1657  // 1. vectorSizes = sourceShape.take_front(N)
1658  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1659  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1660  // innerTiles attribute value.
1661  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1662  if (vectorSizes.empty()) {
1663  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1664  if (!outerDimsPerm.empty())
1665  applyPermutationToVector(vectorSizes, outerDimsPerm);
1666  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1667  vectorSizes[pos] *= innerTiles[i];
1668 
1669  useInBoundsInsteadOfMasking = true;
1670  }
1671 
1672  // readVectorSizes is the shape of the vector used to read and apply the
1673  // mask. It is set as follows: let's say the vectorSizes (VS) array has size
1674  // 'N', the sourceShape (SS) has size 'M' where M >= N, and the
1675  // InnerTileSizes (IT) have size M-N.
1676  // Thus:
1677  // - initially: readVectorSizes = vectorSizes
1678  // - divide all the readVectorSizes locations pointed to by innerDimPos
1679  // by the corresponding innerTiles value,
1680  // - if outer_dims_perm is present: apply that permutation to readVectorSizes,
1681  // - append the remaining shape from SS.
1682  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1683  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
1684  // 128] and outer_dims_perm = [1, 0]; then the read shape is:
1685  // ReadVectorSizes(initial): [512, 128]
1686  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1687  // = [16, 8]
1688  // After applying outer_dims_perm: [8, 16]
1689  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1690 
1691  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1692 
1693  for (auto [index, size] : enumerate(innerTiles)) {
1694  readVectorSizes[innerDimPos[index]] =
1695  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1696  }
1697  if (!outerDimsPerm.empty()) {
1698  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1699  }
1700  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1701  sourceShape.end());
1702 
1703  ReifiedRankedShapedTypeDims reifiedRetShapes;
1704  LogicalResult status =
1705  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1706  .reifyResultShapes(rewriter, reifiedRetShapes);
1707  if (status.failed()) {
1708  LDBG("Unable to reify result shapes of " << unpackOp);
1709  return failure();
1710  }
1711  Location loc = unpackOp->getLoc();
1712 
1713  auto padValue = rewriter.create<arith::ConstantOp>(
1714  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1715 
1716  // Read the result, masking if necessary: a mask is needed if the
1717  // transfer_read shape does not match the source shape.
1718  Value readResult = vector::createReadOrMaskedRead(
1719  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1720  /*useInBoundsInsteadOfMasking=*/false);
1721 
1722  PackingMetadata packMetadata;
1723  SmallVector<int64_t> lastDimToInsertPosPerm =
1724  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1725  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1726  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1727  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1728  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1729  RankedTensorType stripMineTensorType =
1730  RankedTensorType::get(stripMineShape, stripMineElemType);
1731  // Transpose the appropriate rows to match output.
1732  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1733  loc, readResult, lastDimToInsertPosPerm);
1734 
1735  // Collapse the vector to the size required by result.
1736  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1737  stripMineTensorType, packMetadata.reassociations);
1738  mlir::VectorType vecCollapsedType =
1739  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1740  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1741  loc, vecCollapsedType, transposeOp->getResult(0));
1742 
1743  // writeVectorSizes must match the shape_cast result shape for dynamic sizes,
1744  // otherwise the verifier complains that the mask size is invalid.
1745  SmallVector<int64_t> writeVectorSizes(
1746  unpackOp.getDestType().hasStaticShape()
1747  ? vectorSizes
1748  : shapeCastOp.getResultVectorType().getShape());
1749  Operation *write = createWriteOrMaskedWrite(
1750  rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1751  writeVectorSizes, useInBoundsInsteadOfMasking);
1752  newResults.push_back(write->getResult(0));
1753  return success();
1754 }
1755 
1756 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1757 /// and (3) all-zero lowPad to
1758 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
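///
/// As a rough sketch (illustrative, with made-up shapes), padding a
/// tensor<?x?xf32> source to a static tensor<17x5xf32> result with pad value
/// %cst would become IR along the lines of:
///   %mask = vector.create_mask %d0, %d1 : vector<17x5xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<?x?xf32>, vector<17x5xf32>
///   } : vector<17x5xi1> -> vector<17x5xf32>
///   %w = vector.transfer_write %read, %init[%c0, %c0]
///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<17x5xf32>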
1759 static LogicalResult
1760 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1761  ArrayRef<int64_t> inputVectorSizes,
1762  SmallVectorImpl<Value> &newResults) {
1763  auto padValue = padOp.getConstantPaddingValue();
1764  Location loc = padOp.getLoc();
1765 
1766  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1767  OpBuilder::InsertionGuard g(rewriter);
1768  rewriter.setInsertionPoint(padOp);
1769 
1770  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1771  LogicalResult status =
1772  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1773  .reifyResultShapes(rewriter, reifiedReturnShapes);
1774  (void)status; // prevent unused variable warning on non-assert builds
1775  assert(succeeded(status) && "failed to reify result shapes");
1776  auto maskedRead = vector::createReadOrMaskedRead(
1777  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1778  /*useInBoundsInsteadOfMasking=*/false);
1779  Operation *write = createWriteOrMaskedWrite(
1780  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1781  /*useInBoundsInsteadOfMasking=*/false);
1782  newResults.push_back(write->getResult(0));
1783  return success();
1784 }
1785 
1786 // TODO: probably need some extra checks for reduction followed by consumer
1787 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1788 static LogicalResult reductionPreconditions(LinalgOp op) {
1789  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1790  LDBG("reduction precondition failed: no reduction iterator\n");
1791  return failure();
1792  }
1793  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1794  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1795  if (indexingMap.isPermutation())
1796  continue;
1797 
1798  Operation *reduceOp = matchLinalgReduction(&opOperand);
1799  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1800  LDBG("reduction precondition failed: reduction detection failed\n");
1801  return failure();
1802  }
1803  }
1804  return success();
1805 }
1806 
1807 static LogicalResult
1808 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1809  bool flatten1DDepthwiseConv) {
1810  if (flatten1DDepthwiseConv) {
1811  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1812  "supported\n");
1813  return failure();
1814  }
1815 
1816  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1817  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1818  return failure();
1819  }
1820 
1821  // Support dynamic shapes in 1D depthwise convolution, but only in the
1822  // _channel_ dimension.
1823  Value lhs = conv.getDpsInputOperand(0)->get();
1824  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1825  auto shapeWithoutCh = lhsShape.drop_back(1);
1826  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1827  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1828  "channel dim can be dynamic\n");
1829  return failure();
1830  }
1831 
1832  return success();
1833 }
1834 
1835 static LogicalResult
1836 vectorizeDynamicLinalgOpPrecondition(LinalgOp op,
1837  bool flatten1DDepthwiseConv) {
1838  if (isa<ConvolutionOpInterface>(op.getOperation()))
1839  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1840 
1841  if (hasReductionIterator(op))
1842  return reductionPreconditions(op);
1843 
1844  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1845  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1846  if (!isElementwise(op) &&
1847  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1848  op.getOperation()))
1849  return failure();
1850 
1851  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1852  return success();
1853 }
1854 
1855 /// Check whether the inner tiles are static/constant.
1856 static LogicalResult
1857 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1858  ArrayRef<int64_t> inputVectorSizes) {
1859 
1860  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1861  return !getConstantIntValue(res).has_value();
1862  })) {
1863  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1864  return failure();
1865  }
1866  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1867  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1868  unpackOp.getDestType().hasStaticShape() &&
1869  unpackOp.getSourceType().hasStaticShape();
1870  if (!satisfyEmptyCond &&
1871  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1872  return failure();
1873 
1874  return success();
1875 }
1876 
1877 static LogicalResult vectorizeLinalgOpPrecondition(
1878  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1879  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1880  // A tensor with a dimension of size 0 cannot be vectorized.
1881  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1882  return failure();
1883  // Check API contract for input vector sizes.
1884  if (!inputVectorSizes.empty() &&
1885  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1886  inputVectorSizes)))
1887  return failure();
1888 
1889  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1890  linalgOp, flatten1DDepthwiseConv))) {
1891  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1892  return failure();
1893  }
1894 
1895  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1896 
1897  // Register CustomVectorizationPrecondition for extractOp.
1898  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1899 
1900  // All types in the body should be a supported element type for VectorType.
1901  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1902  // Check if any custom hook can vectorize the inner op.
1903  if (llvm::any_of(
1904  customPreconditions,
1905  [&](const CustomVectorizationPrecondition &customPrecondition) {
1906  return succeeded(
1907  customPrecondition(&innerOp, vectorizeNDExtract));
1908  })) {
1909  continue;
1910  }
1911  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1912  return !VectorType::isValidElementType(type);
1913  })) {
1914  return failure();
1915  }
1916  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1917  return !VectorType::isValidElementType(type);
1918  })) {
1919  return failure();
1920  }
1921  }
1922  if (isElementwise(linalgOp))
1923  return success();
1924 
1925  // TODO: isaConvolutionOpInterface that can also infer from generic
1926  // features. But we will still need stride/dilation attributes that will be
1927  // annoying to reverse-engineer...
1928  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1929  return success();
1930  // TODO: the common vector shape is equal to the static loop sizes only when
1931  // all indexing maps are projected permutations. For convs and stencils the
1932  // logic will need to evolve.
1933  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1934  LDBG("precondition failed: not projected permutations\n");
1935  return failure();
1936  }
1937  if (failed(reductionPreconditions(linalgOp))) {
1938  LDBG("precondition failed: reduction preconditions\n");
1939  return failure();
1940  }
1941  return success();
1942 }
1943 
1944 static LogicalResult
1945 vectorizePackOpPrecondition(tensor::PackOp packOp,
1946  ArrayRef<int64_t> inputVectorSizes) {
1947  auto padValue = packOp.getPaddingValue();
1948  Attribute cstAttr;
1949  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1950  LDBG("pad value is not constant: " << packOp << "\n");
1951  return failure();
1952  }
1953  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1954  bool satisfyEmptyCond = true;
1955  if (inputVectorSizes.empty()) {
1956  if (!packOp.getDestType().hasStaticShape() ||
1957  !packOp.getSourceType().hasStaticShape())
1958  satisfyEmptyCond = false;
1959  }
1960 
1961  if (!satisfyEmptyCond &&
1962  failed(vector::isValidMaskedInputVector(
1963  resultTensorShape.take_front(packOp.getSourceRank()),
1964  inputVectorSizes)))
1965  return failure();
1966 
1967  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1968  return !getConstantIntValue(v).has_value();
1969  })) {
1970  LDBG("inner_tiles must be constant: " << packOp << "\n");
1971  return failure();
1972  }
1973 
1974  return success();
1975 }
1976 
1977 static LogicalResult
1978 vectorizePadOpPrecondition(tensor::PadOp padOp,
1979  ArrayRef<int64_t> inputVectorSizes) {
1980  auto padValue = padOp.getConstantPaddingValue();
1981  if (!padValue) {
1982  LDBG("pad value is not constant: " << padOp << "\n");
1983  return failure();
1984  }
1985 
1986  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1987  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1988  inputVectorSizes)))
1989  return failure();
1990 
1991  if (llvm::any_of(padOp.getLow(), [](Value v) {
1992  std::optional<int64_t> res = getConstantIntValue(v);
1993  return !res.has_value() || res.value() != 0;
1994  })) {
1995  LDBG("low pad must all be zero: " << padOp << "\n");
1996  return failure();
1997  }
1998 
1999  return success();
2000 }
2001 
2002 /// Preconditions for scalable vectors. This is quite restrictive - it models
2003 /// the fact that in practice we would only make selected dimensions scalable.
2004 static LogicalResult
2005 vectorizeScalableVectorPrecondition(Operation *op,
2006  ArrayRef<int64_t> inputVectorSizes,
2007  ArrayRef<bool> inputScalableVecDims) {
2008  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2009  "Number of input vector sizes and scalable dims doesn't match");
2010 
2011  size_t numOfScalableDims =
2012  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2013 
2014  if (numOfScalableDims == 0)
2015  return success();
2016 
2017  auto linalgOp = dyn_cast<LinalgOp>(op);
2018 
2019  // Cond 1: There's been no need for scalable vectorisation of
2020  // non-linalg Ops so far
2021  if (!linalgOp)
2022  return failure();
2023 
2024  // Cond 2: There's been no need for more than 2 scalable dims so far
2025  if (numOfScalableDims > 2)
2026  return failure();
2027 
2028  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2029  // it matches one of the supported cases:
2030  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2031  // (*).
2032  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2033  // parallel dims.
2034  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2035  // dim.
2036  // The 2nd restriction above means that only Matmul-like Ops are supported
2037  // when 2 dims are scalable, e.g. :
2038  // * iterators = [parallel, parallel, reduction]
2039  // * scalable flags = [true, true, false]
2040  //
2041  // (*) Non-unit dims get folded away in practice.
2042  // TODO: Relax these conditions as good motivating examples are identified.
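  // For instance (an illustrative reading of the rules above): for a matmul
  // with iterators = [parallel, parallel, reduction], vector sizes [8, 16, 4]
  // with scalable flags [false, true, false] fall under case 1. and are
  // accepted, whereas flags [false, false, true] (scalable reduction dim) are
  // rejected below because the op is Matmul-like.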
2043 
2044  // Find the first scalable flag.
2045  bool seenNonUnitParallel = false;
2046  auto iterators = linalgOp.getIteratorTypesArray();
2047  SmallVector<bool> scalableFlags(inputScalableVecDims);
2048  int64_t idx = scalableFlags.size() - 1;
2049  while (!scalableFlags[idx]) {
2050  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2051  seenNonUnitParallel |=
2052  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2053 
2054  iterators.pop_back();
2055  scalableFlags.pop_back();
2056  --idx;
2057  }
2058 
2059  // Analyze the iterator corresponding to the first scalable dim.
2060  switch (iterators.back()) {
2061  case utils::IteratorType::reduction: {
2062  // Check 3. above is met.
2063  if (iterators.size() != inputVectorSizes.size()) {
2064  LDBG("Non-trailing reduction dim requested for scalable "
2065  "vectorization\n");
2066  return failure();
2067  }
2068  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2069  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2070  "is not supported\n");
2071  return failure();
2072  }
2073  break;
2074  }
2075  case utils::IteratorType::parallel: {
2076  // Check 1. and 2. above are met.
2077  if (seenNonUnitParallel) {
2078  LDBG("Inner parallel dim not requested for scalable "
2079  "vectorization\n");
2080  return failure();
2081  }
2082  break;
2083  }
2084  }
2085 
2086  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2087  // supported, for which we expect the following config:
2088  // * iterators = [parallel, parallel, reduction]
2089  // * scalable flags = [true, true, false]
2090  if (numOfScalableDims == 2) {
2091  // Disallow the case below, which breaks 3. above:
2092  // * iterators = [..., parallel, reduction]
2093  // * scalable flags = [..., true, true]
2094  if (iterators.back() == utils::IteratorType::reduction) {
2095  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2096  "vectorization\n");
2097  return failure();
2098  }
2099  scalableFlags.pop_back();
2100  iterators.pop_back();
2101 
2102  if (!scalableFlags.back() ||
2103  (iterators.back() != utils::IteratorType::parallel))
2104  return failure();
2105  }
2106 
2107  // Do not let matmul ops with extended semantics (user-defined indexing
2108  // maps) through this transform.
2109  if (linalgOp.hasUserDefinedMaps())
2110  return failure();
2111 
2112  // Cond 4: Only the following ops are supported in the
2113  // presence of scalable vectors
2114  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2115  isa<linalg::MatmulTransposeAOp>(op) ||
2116  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2117  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2118 }
2119 
2120 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2121  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2122  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2123  bool flatten1DDepthwiseConv) {
2124 
2125  if (!hasVectorizationImpl(op))
2126  return failure();
2127 
2128  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2129  inputScalableVecDims)))
2130  return failure();
2131 
2132  return TypeSwitch<Operation *, LogicalResult>(op)
2133  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2134  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2135  vectorizeNDExtract,
2136  flatten1DDepthwiseConv);
2137  })
2138  .Case<tensor::PadOp>([&](auto padOp) {
2139  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2140  })
2141  .Case<tensor::PackOp>([&](auto packOp) {
2142  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2143  })
2144  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2145  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2146  })
2147  .Default([](auto) { return failure(); });
2148 }
2149 
2150 /// Converts affine.apply Ops to arithmetic operations.
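/// For example (an illustrative sketch, not verbatim output), an
///   %0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
/// in the body would be expanded to something along the lines of:
///   %c4 = arith.constant 4 : index
///   %0 = arith.addi %i, %c4 : index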
2151 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2152  OpBuilder::InsertionGuard g(rewriter);
2153  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2154 
2155  for (auto op : make_early_inc_range(toReplace)) {
2156  rewriter.setInsertionPoint(op);
2157  auto expanded = affine::expandAffineExpr(
2158  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2159  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2160  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2161  rewriter.replaceOp(op, expanded);
2162  }
2163 }
2164 
2165 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2166  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2167  op);
2168 }
2169 
2170 /// Emit a suitable vector form for an operation. If provided,
2171 /// `inputVectorSizes` are used to vectorize this operation.
2172 /// `inputVectorSizes` must match the rank of the iteration space of the
2173 /// operation and the input vector sizes must be greater than or equal to
2174 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2175 /// also allows the vectorization of operations with dynamic shapes.
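///
/// A minimal usage sketch (illustrative; the pattern wrapper below is made up,
/// only the `vectorize` call itself reflects this entry point):
/// ```
/// struct VectorizeLinalgOps : OpInterfaceRewritePattern<linalg::LinalgOp> {
///   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
///   LogicalResult matchAndRewrite(linalg::LinalgOp op,
///                                 PatternRewriter &rewriter) const override {
///     // Use the static iteration-space sizes; no scalable dims requested.
///     return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
///                              /*inputScalableVecDims=*/{},
///                              /*vectorizeNDExtract=*/false,
///                              /*flatten1DDepthwiseConv=*/false);
///   }
/// };
/// ```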
2176 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2177  ArrayRef<int64_t> inputVectorSizes,
2178  ArrayRef<bool> inputScalableVecDims,
2179  bool vectorizeNDExtract,
2180  bool flatten1DDepthwiseConv) {
2181  LDBG("Attempting to vectorize:\n" << *op << "\n");
2182  LDBG("Input vector sizes: ");
2183  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2184  LLVM_DEBUG(llvm::dbgs() << "\n");
2185  LDBG("Input scalable vector dims: ");
2186  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2187  LLVM_DEBUG(llvm::dbgs() << "\n");
2188 
2189  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2190  vectorizeNDExtract,
2191  flatten1DDepthwiseConv))) {
2192  LDBG("Vectorization pre-conditions failed\n");
2193  return failure();
2194  }
2195 
2196  // Initialize vectorization state.
2197  VectorizationState state(rewriter);
2198  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2199  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2200  inputScalableVecDims))) {
2201  LDBG("Vectorization state couldn't be initialized\n");
2202  return failure();
2203  }
2204  }
2205 
2206  SmallVector<Value> results;
2207  auto vectorizeResult =
2208  TypeSwitch<Operation *, LogicalResult>(op)
2209  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2210  // TODO: isaConvolutionOpInterface that can also infer from
2211  // generic features. Will require stride/dilation attributes
2212  // inference.
2213  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2214  FailureOr<Operation *> convOr = vectorizeConvolution(
2215  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2216  flatten1DDepthwiseConv);
2217  if (succeeded(convOr)) {
2218  llvm::append_range(results, (*convOr)->getResults());
2219  return success();
2220  }
2221 
2222  LDBG("Unsupported convolution can't be vectorized.\n");
2223  return failure();
2224  }
2225 
2226  LDBG("Vectorize generic by broadcasting to the canonical vector "
2227  "shape\n");
2228 
2229  // Pre-process before proceeding.
2230  convertAffineApply(rewriter, linalgOp);
2231 
2232  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2233  // to 'OpBuilder' when it is passed over to some methods like
2234  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2235  // erase an op within these methods, the actual rewriter won't be
2236  // notified and we will end up with read-after-free issues!
2237  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2238  })
2239  .Case<tensor::PadOp>([&](auto padOp) {
2240  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2241  results);
2242  })
2243  .Case<tensor::PackOp>([&](auto packOp) {
2244  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2245  results);
2246  })
2247  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2248  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2249  inputVectorSizes, results);
2250  })
2251  .Default([](auto) { return failure(); });
2252 
2253  if (failed(vectorizeResult)) {
2254  LDBG("Vectorization failed\n");
2255  return failure();
2256  }
2257 
2258  if (!results.empty())
2259  rewriter.replaceOp(op, results);
2260  else
2261  rewriter.eraseOp(op);
2262 
2263  return success();
2264 }
2265 
2266 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2267  memref::CopyOp copyOp) {
2268  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2269  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2270  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2271  return failure();
2272 
2273  auto srcElementType = getElementTypeOrSelf(srcType);
2274  auto dstElementType = getElementTypeOrSelf(dstType);
2275  if (!VectorType::isValidElementType(srcElementType) ||
2276  !VectorType::isValidElementType(dstElementType))
2277  return failure();
2278 
2279  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2280  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2281 
2282  Location loc = copyOp->getLoc();
2283  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2284  SmallVector<Value> indices(srcType.getRank(), zero);
2285 
2286  Value readValue = rewriter.create<vector::TransferReadOp>(
2287  loc, readType, copyOp.getSource(), indices,
2288  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2289  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2290  readValue =
2291  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2292  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2293  }
2294  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2295  loc, readValue, copyOp.getTarget(), indices,
2296  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2297  rewriter.replaceOp(copyOp, writeValue->getResults());
2298  return success();
2299 }
2300 
2301 //----------------------------------------------------------------------------//
2302 // Misc. vectorization patterns.
2303 //----------------------------------------------------------------------------//
2304 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2305 /// given operation type OpTy.
2306 template <typename OpTy>
2307 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2308  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2309 
2310  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2311  PatternRewriter &rewriter) const final {
2312  bool changed = false;
2313  // Insert users in vector, because some users may be replaced/removed.
2314  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2315  if (auto op = dyn_cast<OpTy>(user))
2316  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2317  return success(changed);
2318  }
2319 
2320 protected:
2321  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2322  tensor::PadOp padOp, OpTy op) const = 0;
2323 };
2324 
2325 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2326 /// ```
2327 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2328 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2329 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2330 /// ```
2331 /// is rewritten to:
2332 /// ```
2333 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2334 /// {in_bounds = [true, true]}
2335 /// : tensor<?x?xf32>, vector<17x5xf32>
2336 /// ```
2337 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2338 /// sure that the original padding value %cst was never used.
2339 ///
2340 /// This rewrite is possible if:
2341 /// - `xferOp` has no out-of-bounds dims or mask.
2342 /// - Low padding is static 0.
2343 /// - Single, scalar padding value.
2344 struct PadOpVectorizationWithTransferReadPattern
2345  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2346  using VectorizePadOpUserPattern<
2347  vector::TransferReadOp>::VectorizePadOpUserPattern;
2348 
2349  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2350  vector::TransferReadOp xferOp) const override {
2351  // Low padding must be static 0.
2352  if (!padOp.hasZeroLowPad())
2353  return failure();
2354  // Pad value must be a constant.
2355  auto padValue = padOp.getConstantPaddingValue();
2356  if (!padValue)
2357  return failure();
2358  // Padding value of existing `xferOp` is unused.
2359  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2360  return failure();
2361 
2362  rewriter.modifyOpInPlace(xferOp, [&]() {
2363  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2364  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2365  rewriter.getBoolArrayAttr(inBounds));
2366  xferOp.getSourceMutable().assign(padOp.getSource());
2367  xferOp.getPaddingMutable().assign(padValue);
2368  });
2369 
2370  return success();
2371  }
2372 };
2373 
2374 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2375 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2376 /// value, where the same amount of padding is immediately removed again after
2377 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2378 /// tensor value and apply out-of-bounds masking. E.g.:
2379 /// ```
2380 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2381 /// : tensor<...> to tensor<?x?xf32>
2382 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2383 /// %2 = vector.transfer_write %vec, %1[...]
2384 /// : vector<17x5xf32>, tensor<17x5xf32>
2385 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2386 /// : tensor<17x5xf32> to tensor<?x?xf32>
2387 /// ```
2388 /// is rewritten to:
2389 /// ```
2390 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2391 /// : tensor<...> to tensor<?x?xf32>
2392 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2393 /// tensor<?x?xf32>
2394 /// ```
2395 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2396 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2397 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2398 /// from %r's old dimensions.
2399 ///
2400 /// This rewrite is possible if:
2401 /// - Low padding is static 0.
2402 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2403 /// ExtractSliceOp trims the same amount of padding that was added
2404 /// beforehand.
2405 /// - Single, scalar padding value.
2406 struct PadOpVectorizationWithTransferWritePattern
2407  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2408  using VectorizePadOpUserPattern<
2409  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2410 
2411  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2412  vector::TransferWriteOp xferOp) const override {
2413  // TODO: support 0-d corner case.
2414  if (xferOp.getTransferRank() == 0)
2415  return failure();
2416 
2417  // Low padding must be static 0.
2418  if (!padOp.hasZeroLowPad())
2419  return failure();
2420  // Pad value must be a constant.
2421  auto padValue = padOp.getConstantPaddingValue();
2422  if (!padValue)
2423  return failure();
2424  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2425  if (!xferOp->hasOneUse())
2426  return failure();
2427  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2428  if (!trimPadding)
2429  return failure();
2430  // Only static zero offsets supported when trimming padding.
2431  if (!trimPadding.hasZeroOffset())
2432  return failure();
2433  // trimPadding must remove the amount of padding that was added earlier.
2434  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2435  return failure();
2436 
2437  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2438  rewriter.setInsertionPoint(xferOp);
2439 
2440  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2441  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2442  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2443  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2444  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2445  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2446 
2447  return success();
2448  }
2449 
2450  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2451  /// i.e., same dimensions.
2452  ///
2453  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2454  /// dimensions, this function tries to infer the (static) tensor size by
2455  /// looking at the defining op and utilizing op-specific knowledge.
2456  ///
2457  /// This is a conservative analysis. In case equal tensor sizes cannot be
2458  /// proven statically, this analysis returns `false` even though the tensor
2459  /// sizes may turn out to be equal at runtime.
2460  bool hasSameTensorSize(Value beforePadding,
2461  tensor::ExtractSliceOp afterTrimming) const {
2462  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2463  // result and CastOp operand.
2464  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2465  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2466  return true;
2467 
2468  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2469  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2470  // Only RankedTensorType supported.
2471  if (!t1 || !t2)
2472  return false;
2473  // Rank of both values must be the same.
2474  if (t1.getRank() != t2.getRank())
2475  return false;
2476 
2477  // All static dimensions must be the same. Mixed cases (e.g., dimension
2478  // static in `t1` but dynamic in `t2`) are not supported.
2479  for (unsigned i = 0; i < t1.getRank(); ++i) {
2480  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2481  return false;
2482  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2483  return false;
2484  }
2485 
2486  // Nothing more to check if all dimensions are static.
2487  if (t1.getNumDynamicDims() == 0)
2488  return true;
2489 
2490  // All dynamic sizes must be the same. The only supported case at the
2491  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2492  // thereof).
2493 
2494  // Apart from CastOp, only ExtractSliceOp is supported.
2495  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2496  if (!beforeSlice)
2497  return false;
2498 
2499  assert(static_cast<size_t>(t1.getRank()) ==
2500  beforeSlice.getMixedSizes().size());
2501  assert(static_cast<size_t>(t2.getRank()) ==
2502  afterTrimming.getMixedSizes().size());
2503 
2504  for (unsigned i = 0; i < t1.getRank(); ++i) {
2505  // Skip static dimensions.
2506  if (!t1.isDynamicDim(i))
2507  continue;
2508  auto size1 = beforeSlice.getMixedSizes()[i];
2509  auto size2 = afterTrimming.getMixedSizes()[i];
2510 
2511  // Case 1: Same value or same constant int.
2512  if (isEqualConstantIntOrValue(size1, size2))
2513  continue;
2514 
2515  // Other cases: Take a deeper look at defining ops of values.
2516  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2517  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2518  if (!v1 || !v2)
2519  return false;
2520 
2521  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2522  // CSE is run.)
2523  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2524  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2525  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2526  minOp1.getOperands() == minOp2.getOperands())
2527  continue;
2528 
2529  // Add additional cases as needed.
2530  }
2531 
2532  // All tests passed.
2533  return true;
2534  }
2535 };
2536 
2537 /// Returns the effective Pad value for the input op, provided it's a scalar.
2538 ///
2539 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2540 /// this Op performs padding, retrieve the padding value provided that it's
2541 /// a scalar and static/fixed for all the padded values. Returns an empty value
2542 /// otherwise.
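///
/// For example (illustrative), for
///   %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<8x8xf32>)
///             -> tensor<8x8xf32>
/// this helper returns %pad, whereas for a tensor.generate it bails out and
/// returns an empty Value.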
2543 static Value getStaticPadVal(Operation *op) {
2544  if (!op)
2545  return {};
2546 
2547  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2548  // being broadcast, provided that it's a scalar.
2549  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2550  auto source = bcast.getSource();
2551  if (llvm::dyn_cast<VectorType>(source.getType()))
2552  return {};
2553 
2554  return source;
2555  }
2556 
2557  // 2. linalg.fill - use the scalar input value that used to fill the output
2558  // tensor.
2559  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2560  return fill.getInputs()[0];
2561  }
2562 
2563  // 3. tensor.generate - can't guarantee that the value is fixed without
2564  // analysing it, so bail out.
2565  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2566  return {};
2567  }
2568 
2569  // 4. vector.transfer_write - inspect the input vector that's written from.
2570  // If it contains a single value that has been broadcast (e.g. via
2571  // vector.broadcast), extract it, fail otherwise.
2572  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2573  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2574 
2575  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2576  // than the input tensor, then, provided it's constant, we'll extract the
2577  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2578  // TODO: Clarify the semantics when the input tensor is larger than the
2579  // destination.
2580  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2581  return getStaticPadVal(slice.getDest().getDefiningOp());
2582 
2583  return {};
2584 }
2585 
2586 /// Rewrite tensor.insert.slice as a vector.transfer_read +
2587 /// vector.transfer_write pair. The vector size is inferred from the static
2588 /// dims in the input and output tensors. If a dim is dynamic in both the input
2589 /// and output tensors, bails out.
2590 ///
2591 /// Before:
2592 /// !t_in_type = tensor<1x2x3xf32>
2593 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
2594 /// !v_type = vector<1x2x3xf32>
2595 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2596 /// into !t_out_type
2597 /// After:
2598 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2599 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2600 ///
2601 /// TODO: Support masking
2602 struct InsertSliceVectorizePattern
2603  : public OpRewritePattern<tensor::InsertSliceOp> {
2604  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2605 
2606  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
2607  PatternRewriter &rewriter) const final {
2608  auto sourceType = sliceOp.getSource().getType();
2609  if (!VectorType::isValidElementType(sourceType.getElementType()))
2610  return failure();
2611 
2612  auto resultType = sliceOp.getResultType();
2613 
2614  // 1. Get the pad value.
2615  // TransferReadOp requires a scalar padding value. Note that:
2616  // * for in-bounds access, the value is actually irrelevant.
2617  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2618  // 1. The source shape is static (output vector sizes would be based on
2619  // the source shape and hence all memory accesses would be in-bounds),
2620  // 2. Masking is used (output vector sizes would be user-provided, in which
2621  // case it is assumed that all memory accesses are in-bounds). This
2622  // remains a TODO.
2623  //
2624  // When the value is not known and not needed, use 0. Otherwise, bail out.
2625  Value padValue = getStaticPadVal(sliceOp);
2626  bool isOutOfBoundsRead = !sourceType.hasStaticShape();
2627 
2628  if (!padValue && isOutOfBoundsRead) {
2629  LDBG("Failed to get a pad value for out-of-bounds read access\n");
2630  return failure();
2631  }
2632 
2633  if (!padValue) {
2634  auto elemType = sourceType.getElementType();
2635  padValue = rewriter.create<arith::ConstantOp>(
2636  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2637  }
2638 
2639  // 2. Get the vector shape and in-bounds attributes
2640  SmallVector<int64_t> vecShape;
2641  SmallVector<bool> readInBounds;
2642  SmallVector<bool> writeInBounds;
2643  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2644  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2645  if (!sourceType.isDynamicDim(i)) {
2646  vecShape.push_back(sourceType.getDimSize(i));
2647  // Source shape is statically known: Neither read nor write are
2648  // out-of-bounds.
2649  readInBounds.push_back(true);
2650  writeInBounds.push_back(true);
2651  } else if (!resultType.isDynamicDim(i)) {
2652  // Source shape is not statically known, but result shape is.
2653  // Vectorize with size of result shape. This may be larger than the
2654  // source size.
2655  // FIXME: Using rankDiff implies that the source tensor is inserted at
2656  // the end of the destination tensor. However, that's not required.
2657  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2658  // Read may be out-of-bounds because the result size could be larger
2659  // than the source size.
2660  readInBounds.push_back(false);
2661  // Write will be in-bounds provided that the corresponding write idx is 0.
2662  // To keep this logic simple, conservatively mark as out-of-bounds.
2663  writeInBounds.push_back(false);
2664  } else {
2665  // Neither source nor result dim of padOp is static. Cannot vectorize
2666  // the copy.
2667  // TODO: Add support for masking
2668  return failure();
2669  }
2670  }
2671  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2672 
2673  // 3. Generate TransferReadOp.
2674  SmallVector<Value> readIndices(
2675  vecType.getRank(),
2676  rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2677  auto read = rewriter.create<vector::TransferReadOp>(
2678  sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
2679  ArrayRef<bool>{readInBounds});
2680 
2681  // 4. Generate TransferWriteOp.
2682  auto writeIndices = getValueOrCreateConstantIndexOp(
2683  rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2684 
2685  // 5. Finalize
2686  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2687  sliceOp, read, sliceOp.getDest(), writeIndices,
2688  ArrayRef<bool>{writeInBounds});
2689 
2690  return success();
2691  }
2692 };
2693 
2694 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2695 /// ```
2696 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2697 /// %r = tensor.insert_slice %0
2698 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2699 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2700 /// ```
2701 /// is rewritten to:
2702 /// ```
2703 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2704 /// : tensor<?x?xf32>, vector<17x5xf32>
2705 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2706 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2707 /// ```
2708 ///
2709 /// This rewrite is possible if:
2710 /// - Low padding is static 0.
2711 /// - `padOp` result shape is static.
2712 /// - The entire padded tensor is inserted.
2713 /// (Implies that sizes of `insertOp` are all static.)
2714 /// - Only unit strides in `insertOp`.
2715 /// - Single, scalar padding value.
2716 /// - `padOp` result not used as destination.
2717 struct PadOpVectorizationWithInsertSlicePattern
2718  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2719  using VectorizePadOpUserPattern<
2720  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2721 
2722  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2723  tensor::InsertSliceOp insertOp) const override {
2724  // Low padding must be static 0.
2725  if (!padOp.hasZeroLowPad())
2726  return failure();
2727  // Only unit stride supported.
2728  if (!insertOp.hasUnitStride())
2729  return failure();
2730  // Pad value must be a constant.
2731  auto padValue = padOp.getConstantPaddingValue();
2732  if (!padValue)
2733  return failure();
2734  // Dynamic shapes not supported.
2735  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2736  return failure();
2737  // Pad result not used as destination.
2738  if (insertOp.getDest() == padOp.getResult())
2739  return failure();
2740 
2741  auto vecType = VectorType::get(padOp.getType().getShape(),
2742  padOp.getType().getElementType());
2743  unsigned vecRank = vecType.getRank();
2744  unsigned tensorRank = insertOp.getType().getRank();
2745 
2746  // Check if sizes match: Insert the entire tensor into most minor dims.
2747  // (No permutations allowed.)
2748  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2749  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2750  if (!llvm::all_of(
2751  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2752  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2753  }))
2754  return failure();
2755 
2756  // Insert the TransferReadOp and TransferWriteOp at the position of the
2757  // InsertSliceOp.
2758  rewriter.setInsertionPoint(insertOp);
2759 
2760  // Generate TransferReadOp: Read entire source tensor and add high
2761  // padding.
2762  SmallVector<Value> readIndices(
2763  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2764  auto read = rewriter.create<vector::TransferReadOp>(
2765  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2766 
2767  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2768  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2769  // source must fit into the destination at the specified offsets.
2770  auto writeIndices = getValueOrCreateConstantIndexOp(
2771  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2772  SmallVector<bool> inBounds(vecRank, true);
2773  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2774  insertOp, read, insertOp.getDest(), writeIndices,
2775  ArrayRef<bool>{inBounds});
2776 
2777  return success();
2778  }
2779 };
2780 
2781 void mlir::linalg::populateInsertSliceVectorizationPatterns(
2782  RewritePatternSet &patterns) {
2783  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
2784 }
2785 
2786 void mlir::linalg::populatePadOpVectorizationPatterns(
2787  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2788  patterns.add<PadOpVectorizationWithTransferReadPattern,
2789  PadOpVectorizationWithTransferWritePattern,
2790  PadOpVectorizationWithInsertSlicePattern>(
2791  patterns.getContext(), baseBenefit.getBenefit() + 1);
2792 }
2793 
2794 //----------------------------------------------------------------------------//
2795 // Forwarding patterns
2796 //----------------------------------------------------------------------------//
2797 
2798 /// Check whether there is any interleaved use of any `values` between
2799 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2800 /// is in a different block.
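/// For illustration, a hypothetical sketch (names are made up, types elided):
/// with `values = {%buf}`, the store below is an interleaved use because it
/// sits between `firstOp` and `secondOp` in the same block:
/// ```
/// memref.copy %src, %buf ...                          // firstOp
/// memref.store %f0, %buf[%c0, %c0] ...                // interleaved use of %buf
/// %v = vector.transfer_read %buf[%c0, %c0], %pad ...  // secondOp
/// ```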
2801 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2802  ValueRange values) {
2803  if (firstOp->getBlock() != secondOp->getBlock() ||
2804  !firstOp->isBeforeInBlock(secondOp)) {
2805  LDBG("interleavedUses precondition failed, firstOp: "
2806  << *firstOp << ", second op: " << *secondOp << "\n");
2807  return true;
2808  }
2809  for (auto v : values) {
2810  for (auto &u : v.getUses()) {
2811  Operation *owner = u.getOwner();
2812  if (owner == firstOp || owner == secondOp)
2813  continue;
2814  // TODO: this is too conservative, use dominance info in the future.
2815  if (owner->getBlock() == firstOp->getBlock() &&
2816  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2817  continue;
2818  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2819  << ", second op: " << *secondOp << "\n");
2820  return true;
2821  }
2822  }
2823  return false;
2824 }
2825 
2826 /// Return the unique subview use of `v` if it is indeed unique, null
2827 /// otherwise.
2828 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2829  memref::SubViewOp subViewOp;
2830  for (auto &u : v.getUses()) {
2831  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2832  if (subViewOp)
2833  return memref::SubViewOp();
2834  subViewOp = newSubViewOp;
2835  }
2836  }
2837  return subViewOp;
2838 }
2839 
2840 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2841 /// when available.
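///
/// As an informal sketch (hypothetical IR, types elided), the rewrite below
/// forwards a vector.transfer_read of a padded local buffer directly to the
/// original source of the copy:
/// ```
/// linalg.fill ins(%pad : f32) outs(%alloc : memref<...>)
/// %sv = memref.subview %alloc [...]
/// memref.copy %in, %sv ...
/// %v = vector.transfer_read %alloc[...], %pad ...
/// ```
/// becomes (fill and copy are erased; in_bounds is conservatively reset):
/// ```
/// %v = vector.transfer_read %in[...], %pad {in_bounds = [false, false]} ...
/// ```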
2842 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2843  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2844 
2845  // TODO: support mask.
2846  if (xferOp.getMask())
2847  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2848 
2849  // Transfer into `view`.
2850  Value viewOrAlloc = xferOp.getSource();
2851  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2852  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2853  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2854 
2855  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2856  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2857  if (!subViewOp)
2858  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2859  Value subView = subViewOp.getResult();
2860 
2861  // Find the copy into `subView` without interleaved uses.
2862  memref::CopyOp copyOp;
2863  for (auto &u : subView.getUses()) {
2864  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2865  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2866  if (newCopyOp.getTarget() != subView)
2867  continue;
2868  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2869  continue;
2870  copyOp = newCopyOp;
2871  break;
2872  }
2873  }
2874  if (!copyOp)
2875  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2876 
2877  // Find the fill into `viewOrAlloc` without interleaved uses before the
2878  // copy.
2879  FillOp maybeFillOp;
2880  for (auto &u : viewOrAlloc.getUses()) {
2881  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2882  assert(isa<MemRefType>(newFillOp.output().getType()));
2883  if (newFillOp.output() != viewOrAlloc)
2884  continue;
2885  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2886  continue;
2887  maybeFillOp = newFillOp;
2888  break;
2889  }
2890  }
2891  // Ensure padding matches.
2892  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2893  return rewriter.notifyMatchFailure(xferOp,
2894  "padding value does not match fill");
2895 
2896  // `in` is the subview that memref.copy reads. Replace it.
2897  Value in = copyOp.getSource();
2898 
2899  // memref.copy + linalg.fill can be used to create a padded local buffer.
2900  // The `masked` attribute is only valid on this padded buffer.
2901  // When forwarding to vector.transfer_read, the attribute must be reset
2902  // conservatively.
2903  auto vectorType = xferOp.getVectorType();
2904  Value res = rewriter.create<vector::TransferReadOp>(
2905  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2906  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2907  rewriter.getBoolArrayAttr(
2908  SmallVector<bool>(vectorType.getRank(), false)));
2909 
2910  if (maybeFillOp)
2911  rewriter.eraseOp(maybeFillOp);
2912  rewriter.eraseOp(copyOp);
2913  rewriter.replaceOp(xferOp, res);
2914 
2915  return success();
2916 }
2917 
2918 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2919 /// when available.
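///
/// Mirror of the read-forwarding case above, as an informal sketch
/// (hypothetical IR): a vector.transfer_write into the local buffer followed
/// by `memref.copy %sv, %out` is forwarded to a single
/// `vector.transfer_write %vec, %out[...]`, and the copy is erased.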
2920 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2921  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2922  // TODO: support mask.
2923  if (xferOp.getMask())
2924  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2925 
2926  // Transfer into `viewOrAlloc`.
2927  Value viewOrAlloc = xferOp.getSource();
2928  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2929  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2930  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2931 
2932  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2933  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2934  if (!subViewOp)
2935  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2936  Value subView = subViewOp.getResult();
2937 
2938  // Find the copy from `subView` without interleaved uses.
2939  memref::CopyOp copyOp;
2940  for (auto &u : subViewOp.getResult().getUses()) {
2941  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2942  if (newCopyOp.getSource() != subView)
2943  continue;
2944  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2945  continue;
2946  copyOp = newCopyOp;
2947  break;
2948  }
2949  }
2950  if (!copyOp)
2951  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2952 
2953  // `out` is the subview copied into that we replace.
2954  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2955  Value out = copyOp.getTarget();
2956 
2957  // Forward vector.transfer into copy.
2958  // memref.copy + linalg.fill can be used to create a padded local buffer.
2959  // The `masked` attribute is only valid on this padded buffer.
2960  // When forwarding to vector.transfer_write, the attribute must be reset
2961  // conservatively.
2962  auto vector = xferOp.getVector();
2963  rewriter.create<vector::TransferWriteOp>(
2964  xferOp.getLoc(), vector, out, xferOp.getIndices(),
2965  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2966  rewriter.getBoolArrayAttr(
2967  SmallVector<bool>(vector.getType().getRank(), false)));
2968 
2969  rewriter.eraseOp(copyOp);
2970  rewriter.eraseOp(xferOp);
2971 
2972  return success();
2973 }
2974 
2975 //===----------------------------------------------------------------------===//
2976 // Convolution vectorization patterns
2977 //===----------------------------------------------------------------------===//
2978 
2979 template <int N>
2980 static void bindShapeDims(ShapedType shapedType) {}
2981 
2982 template <int N, typename IntTy, typename... IntTy2>
2983 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2984  val = shapedType.getShape()[N];
2985  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2986 }
2987 
2988 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2989 template <typename... IntTy>
2990 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2991  bindShapeDims<0>(shapedType, vals...);
2992 }
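// For example (as used further below in this file), the leading dimensions of
// a ShapedType can be bound to integers in one call:
//
//   int64_t nSize, wSize, fSize;
//   // For a tensor<2x16x8xf32> result type: nSize = 2, wSize = 16, fSize = 8.
//   bindShapeDims(resShapedType, nSize, wSize, fSize);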
2993 
2994 namespace {
2995 bool isCastOfBlockArgument(Operation *op) {
2996  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2997  isa<BlockArgument>(op->getOperand(0));
2998 }
2999 
3000 bool isSupportedPoolKind(vector::CombiningKind kind) {
3001  switch (kind) {
3002  case vector::CombiningKind::ADD:
3003  case vector::CombiningKind::MAXNUMF:
3004  case vector::CombiningKind::MAXIMUMF:
3005  case vector::CombiningKind::MAXSI:
3006  case vector::CombiningKind::MAXUI:
3007  case vector::CombiningKind::MINNUMF:
3008  case vector::CombiningKind::MINIMUMF:
3009  case vector::CombiningKind::MINSI:
3010  case vector::CombiningKind::MINUI:
3011  return true;
3012  default:
3013  return false;
3014  }
3015 }
3016 
3017 /// Generate a vector implementation for either:
3018 /// ```
3019 /// Op def: ( w, kw )
3020 /// Iters: ({Par(), Red()})
3021 /// Layout: {{w + kw}, {kw}, {w}}
3022 /// ```
3023 /// kw is unrolled.
3024 ///
3025 /// or
3026 ///
3027 /// ```
3028 /// Op def: ( n, w, c, kw, f )
3029 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3030 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3031 /// ```
3032 /// kw is unrolled, w is unrolled iff dilationW > 1.
3033 ///
3034 /// or
3035 ///
3036 /// ```
3037 /// Op def: ( n, c, w, f, kw )
3038 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3039 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3040 /// ```
3041 /// kw is unrolled, w is unrolled iff dilationW > 1.
3042 ///
3043 /// or
3044 ///
3045 /// ```
3046 /// Op def: ( n, w, c, kw )
3047 /// Iters: ({Par(), Par(), Par(), Red()})
3048 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3049 /// ```
3050 /// kw is unrolled, w is unrolled iff dilationW > 1.
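///
/// For instance, the second case above corresponds to a named op such as the
/// following (sizes here are illustrative only):
/// ```
/// %res = linalg.conv_1d_nwc_wcf
///   {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
///   ins(%input, %filter : tensor<1x14x3xf32>, tensor<3x3x8xf32>)
///   outs(%init : tensor<1x12x8xf32>) -> tensor<1x12x8xf32>
/// ```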
3051 struct Conv1DGenerator
3052  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3053  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3054  int dilationW)
3055  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3056  strideW(strideW), dilationW(dilationW) {
3057  // Determine whether `linalgOp` can be generated with this generator
3058  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
3059  return;
3060  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3061  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3062  resShaped = linalgOp.getDpsInitOperand(0)->get();
3063  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3064  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3065  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3066  if (!lhsShapedType || !rhsShapedType || !resShapedType)
3067  return;
3068  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
3069  // (non-channeled convolution -> LHS and RHS both have single dimensions).
3070  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
3071  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
3072  return;
3073 
3074  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3075  if (!reduceOp)
3076  return;
3077  redOp = reduceOp->getName().getIdentifier();
3078 
3079  if (!setOperKind(reduceOp))
3080  return;
3081  auto maybeKind = getCombinerOpKind(reduceOp);
3082  // Typically convolution will have an `Add` CombiningKind, but for i1 types
3083  // it can get strength-reduced to `OR`, which is also supported. This
3084  // strength-reduction logic lives in the `buildBinaryFn` helper in the Linalg dialect.
3085  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
3086  *maybeKind != vector::CombiningKind::OR) &&
3087  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
3088  return;
3089  }
3090  reductionKind = maybeKind.value();
3091 
3092  auto rhsRank = rhsShapedType.getRank();
3093  switch (oper) {
3094  case Conv:
3095  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3096  return;
3097  break;
3098  case Pool:
3099  if (rhsRank != 1)
3100  return;
3101  break;
3102  }
3103  // The op is now known to be valid.
3104  valid = true;
3105  }
3106 
3107  /// Generate a vector implementation for:
3108  /// ```
3109  /// Op def: ( w, kw )
3110  /// Iters: ({Par(), Red()})
3111  /// Layout: {{w + kw}, {kw}, {w}}
3112  /// ```
3113  /// kw is always unrolled.
3114  ///
3115  /// or
3116  ///
3117  /// ```
3118  /// Op def: ( n, w, c, kw, f )
3119  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3120  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3121  /// ```
3122  /// kw is always unrolled.
3123  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3124  /// > 1.
3125  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3126  if (!valid)
3127  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3128 
3129  int64_t nSize, wSize, cSize, kwSize, fSize;
3130  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3131  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3132  switch (conv1DOpOrder) {
3133  case Conv1DOpOrder::W:
3134  // Initialize unused dimensions
3135  nSize = fSize = cSize = 0;
3136  // out{W}
3137  bindShapeDims(resShapedType, wSize);
3138  // kernel{kw}
3139  bindShapeDims(rhsShapedType, kwSize);
3140  lhsShape = {// iw = ow + kw - 1
3141  // (i.e. 16 convolved with 3 -> 14)
3142  (wSize + kwSize - 1)};
3143  rhsShape = {kwSize};
3144  resShape = {wSize};
3145  break;
3146  case Conv1DOpOrder::Nwc:
3147  // out{n, w, f}
3148  bindShapeDims(resShapedType, nSize, wSize, fSize);
3149  switch (oper) {
3150  case Conv:
3151  // kernel{kw, c, f}
3152  bindShapeDims(rhsShapedType, kwSize, cSize);
3153  break;
3154  case Pool:
3155  // kernel{kw}
3156  bindShapeDims(rhsShapedType, kwSize);
3157  cSize = fSize;
3158  break;
3159  }
3160  lhsShape = {nSize,
3161  // iw = ow * sw + kw * dw - 1
3162  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3163  // Perform the proper inclusive -> exclusive -> inclusive.
3164  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3165  1,
3166  cSize};
3167  switch (oper) {
3168  case Conv:
3169  rhsShape = {kwSize, cSize, fSize};
3170  break;
3171  case Pool:
3172  rhsShape = {kwSize};
3173  break;
3174  }
3175  resShape = {nSize, wSize, fSize};
3176  break;
3177  case Conv1DOpOrder::Ncw:
3178  // out{n, f, w}
3179  bindShapeDims(resShapedType, nSize, fSize, wSize);
3180  switch (oper) {
3181  case Conv:
3182  // kernel{f, c, kw}
3183  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3184  break;
3185  case Pool:
3186  // kernel{kw}
3187  bindShapeDims(rhsShapedType, kwSize);
3188  cSize = fSize;
3189  break;
3190  }
3191  lhsShape = {nSize, cSize,
3192  // iw = ow * sw + kw * dw - 1
3193  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3194  // Perform the proper inclusive -> exclusive -> inclusive.
3195  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3196  1};
3197  switch (oper) {
3198  case Conv:
3199  rhsShape = {fSize, cSize, kwSize};
3200  break;
3201  case Pool:
3202  rhsShape = {kwSize};
3203  break;
3204  }
3205  resShape = {nSize, fSize, wSize};
3206  break;
3207  }
3208 
3209  vector::TransferWriteOp write;
3210  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3211 
3212  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3213  // When strideW == 1, we can batch the contiguous loads and avoid
3214  // unrolling
3215  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3216 
3217  Type lhsEltType = lhsShapedType.getElementType();
3218  Type rhsEltType = rhsShapedType.getElementType();
3219  Type resEltType = resShapedType.getElementType();
3220  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3221  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3222  auto resType = VectorType::get(resShape, resEltType);
3223  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3224  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3225  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3226  SmallVector<Value> resPadding(resShape.size(), zero);
3227 
3228  // Read the whole lhs, rhs and res in one shot (with zero padding).
3229  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3230  lhsPadding);
3231  // This is needed only for Conv.
3232  Value rhs = nullptr;
3233  if (oper == Conv)
3234  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3235  rhsPadding);
3236  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3237  resPadding);
3238 
3239  // The base vectorization case for channeled convolution is input:
3240  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3241  // vectorization case, we pre-transpose the input, weight, and output.
3242  switch (conv1DOpOrder) {
3243  case Conv1DOpOrder::W:
3244  case Conv1DOpOrder::Nwc:
3245  // Base case, so no transposes necessary.
3246  break;
3247  case Conv1DOpOrder::Ncw: {
3248  // To match base vectorization case, we pre-transpose current case.
3249  // ncw -> nwc
3250  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3251  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3252  // fcw -> wcf
3253  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3254 
3255  // This is needed only for Conv.
3256  if (oper == Conv)
3257  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3258  // nfw -> nwf
3259  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3260  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3261  break;
3262  }
3263  }
3264 
3265  //===------------------------------------------------------------------===//
3266  // Begin vector-only rewrite part
3267  //===------------------------------------------------------------------===//
3268  // Unroll along kw and read slices of lhs and rhs.
3269  SmallVector<Value> lhsVals, rhsVals, resVals;
3270  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3271  kwSize, strideW, dilationW, wSizeStep,
3272  isSingleChanneled);
3273  // Do not do for pooling.
3274  if (oper == Conv)
3275  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3276  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3277  wSizeStep, isSingleChanneled);
3278 
3279  auto linearIndex = [&](int64_t kw, int64_t w) {
3280  return kw * (wSize / wSizeStep) + w;
3281  };
3282 
3283  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3284  // or perform outerproduct for non-channeled convolution or perform simple
3285  // arith operation for pooling
3286  for (int64_t kw = 0; kw < kwSize; ++kw) {
3287  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3288  switch (oper) {
3289  case Conv:
3290  if (isSingleChanneled) {
3291  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3292  lhsVals[linearIndex(kw, w)],
3293  rhsVals[kw], resVals[w]);
3294  } else {
3295  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3296  lhsVals[linearIndex(kw, w)],
3297  rhsVals[kw], resVals[w]);
3298  }
3299  break;
3300  case Pool:
3301  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3302  resVals[w]);
3303  break;
3304  }
3305  }
3306  }
3307 
3308  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3309  isSingleChanneled);
3310  //===------------------------------------------------------------------===//
3311  // End vector-only rewrite part
3312  //===------------------------------------------------------------------===//
3313 
3314  // The base vectorization case for channeled convolution is output:
3315  // {n,w,f}. To reuse the result from the base pattern vectorization case,
3316  // we post-transpose the base case result.
3317  switch (conv1DOpOrder) {
3318  case Conv1DOpOrder::W:
3319  case Conv1DOpOrder::Nwc:
3320  // Base case, so no transposes necessary.
3321  break;
3322  case Conv1DOpOrder::Ncw: {
3323  // nwf -> nfw
3324  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3325  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3326  break;
3327  }
3328  }
3329 
3330  return rewriter
3331  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3332  .getOperation();
3333  }
3334 
3335  // Take a value and widen it to have the same element type as `ty`.
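  // E.g. (a descriptive note, not exhaustive): promoting a vector<4xi8> value
  // to match an i32-element accumulator emits arith.extsi to vector<4xi32>,
  // whereas an i8 source with an f32 destination goes through arith.sitofp.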
3336  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3337  const Type srcElementType = getElementTypeOrSelf(val.getType());
3338  const Type dstElementType = getElementTypeOrSelf(ty);
3339  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3340  if (srcElementType == dstElementType)
3341  return val;
3342 
3343  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3344  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3345  const Type dstType =
3346  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3347 
3348  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3349  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3350  }
3351 
3352  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3353  srcWidth < dstWidth)
3354  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3355 
3356  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3357  srcWidth < dstWidth)
3358  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3359 
3360  assert(false && "unhandled promotion case");
3361  return nullptr;
3362  }
3363 
3364  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
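  // Rough sketch of the op built below (with d0 = n, d1 = w, d2 = f, d3 = c;
  // shapes are illustrative only and the kind is shown for the common ADD
  // case, the actual kind comes from `reductionKind`):
  //   %o = vector.contract {
  //          indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
  //                           affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
  //                           affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
  //          iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  //          kind = #vector.kind<add>}
  //        %lhs, %rhs, %acc
  //        : vector<1x12x3xf32>, vector<3x8xf32> into vector<1x12x8xf32>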
3365  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3366  Value lhs, Value rhs, Value res) {
3367  vector::IteratorType par = vector::IteratorType::parallel;
3368  vector::IteratorType red = vector::IteratorType::reduction;
3369  AffineExpr n, w, f, c;
3370  bindDims(ctx, n, w, f, c);
3371  lhs = promote(rewriter, loc, lhs, res.getType());
3372  rhs = promote(rewriter, loc, rhs, res.getType());
3373  auto contrationOp = rewriter.create<vector::ContractionOp>(
3374  loc, lhs, rhs, res,
3375  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3376  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3377  contrationOp.setKind(reductionKind);
3378  return contrationOp;
3379  }
3380 
3381  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3382  // convolution.
3383  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3384  Value lhs, Value rhs, Value res) {
3385  return rewriter.create<vector::OuterProductOp>(
3386  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3387  }
3388 
3389  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3390  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3391  Value res) {
3392  if (isPoolExt)
3393  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3394  return rewriter
3395  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3396  ->getResult(0);
3397  }
3398 
3399  /// Generate a vector implementation for:
3400  /// ```
3401  /// Op def: ( n, w, c, kw)
3402  /// Iters: ({Par(), Par(), Par(), Red()})
3403  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3404  /// ```
3405  /// kw is always unrolled.
3406  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3407  /// > 1.
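  ///
  /// A typical target is a named op such as (sizes illustrative only):
  /// ```
  /// %res = linalg.depthwise_conv_1d_nwc_wc
  ///   {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  ///   ins(%input, %filter : tensor<1x10x8xf32>, tensor<3x8xf32>)
  ///   outs(%init : tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  /// ```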
3408  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3409  bool channelDimScalableFlag,
3410  bool flatten) {
3411  if (!valid)
3412  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3413 
3414  bool scalableChDim = false;
3415  bool useMasking = false;
3416  int64_t nSize, wSize, cSize, kwSize;
3417  // kernel{kw, c}
3418  bindShapeDims(rhsShapedType, kwSize, cSize);
3419  if (ShapedType::isDynamic(cSize)) {
3420  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3421  cSize = channelDimVecSize;
3422  // Scalable vectors are only used when both conditions are met:
3423  // 1. channel dim is dynamic
3424  // 2. channelDimScalableFlag is set
3425  scalableChDim = channelDimScalableFlag;
3426  useMasking = true;
3427  }
3428 
3429  assert(!(useMasking && flatten) &&
3430  "Unsupported flattened conv with dynamic shapes");
3431 
3432  // out{n, w, c}
3433  bindShapeDims(resShapedType, nSize, wSize);
3434 
3435  vector::TransferWriteOp write;
3436  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3437 
3438  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3439  // When strideW == 1, we can batch the contiguous loads and avoid
3440  // unrolling
3441  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3442 
3443  Type lhsEltType = lhsShapedType.getElementType();
3444  Type rhsEltType = rhsShapedType.getElementType();
3445  Type resEltType = resShapedType.getElementType();
3446  VectorType lhsType = VectorType::get(
3447  {nSize,
3448  // iw = ow * sw + kw * dw - 1
3449  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3450  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3451  cSize},
3452  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3453  VectorType rhsType =
3454  VectorType::get({kwSize, cSize}, rhsEltType,
3455  /*scalableDims=*/{false, scalableChDim});
3456  VectorType resType =
3457  VectorType::get({nSize, wSize, cSize}, resEltType,
3458  /*scalableDims=*/{false, false, scalableChDim});
3459 
3460  // Masks the input xfer Op along the channel dim, iff the corresponding
3461  // scalable flag is set.
3462  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3463  ArrayRef<bool> scalableDims,
3464  Operation *opToMask) {
3465  if (!useMasking)
3466  return opToMask;
3467  auto maskType =
3468  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3469 
3470  SmallVector<bool> inBounds(maskShape.size(), true);
3471  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3472  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3473  rewriter.getBoolArrayAttr(inBounds));
3474 
3475  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3476  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3477 
3478  Value maskOp =
3479  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3480 
3481  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3482  };
3483 
3484  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3485  // 0].
3486  Value lhs = rewriter.create<vector::TransferReadOp>(
3487  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3488  auto maybeMaskedLhs = maybeMaskXferOp(
3489  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3490 
3491  // Read rhs slice of size {kw, c} @ [0, 0].
3492  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3493  ValueRange{zero, zero});
3494  auto maybeMaskedRhs = maybeMaskXferOp(
3495  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3496 
3497  // Read res slice of size {n, w, c} @ [0, 0, 0].
3498  Value res = rewriter.create<vector::TransferReadOp>(
3499  loc, resType, resShaped, ValueRange{zero, zero, zero});
3500  auto maybeMaskedRes = maybeMaskXferOp(
3501  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3502 
3503  //===------------------------------------------------------------------===//
3504  // Begin vector-only rewrite part
3505  //===------------------------------------------------------------------===//
3506  // Unroll along kw and read slices of lhs and rhs.
3507  SmallVector<Value> lhsVals, rhsVals, resVals;
3508  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3509  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3510 
3511  // Extract lhs slice of size {n, wSizeStep, c}
3512  // @ [0, sw * w + dw * kw, 0].
3513  for (int64_t kw = 0; kw < kwSize; ++kw) {
3514  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3515  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3516  loc, maybeMaskedLhs->getResult(0),
3517  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3518  inOutSliceSizes, inOutStrides));
3519  }
3520  }
3521  // Extract rhs slice of size {c} @ [kw].
3522  for (int64_t kw = 0; kw < kwSize; ++kw) {
3523  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3524  loc, maybeMaskedRhs->getResult(0),
3525  /*offsets=*/ArrayRef<int64_t>{kw}));
3526  }
3527  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3528  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3529  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3530  loc, maybeMaskedRes->getResult(0),
3531  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3532  inOutStrides));
3533  }
3534 
3535  auto linearIndex = [&](int64_t kw, int64_t w) {
3536  return kw * (wSize / wSizeStep) + w;
3537  };
3538 
3539  // Note - the scalable flags are ignored as flattening combined with
3540  // scalable vectorization is not supported.
3541  auto inOutFlattenSliceSizes =
3542  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3543  auto lhsTypeAfterFlattening =
3544  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3545  auto resTypeAfterFlattening =
3546  VectorType::get(inOutFlattenSliceSizes, resEltType);
3547 
3548  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3549  for (int64_t kw = 0; kw < kwSize; ++kw) {
3550  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3551  Value lhsVal = lhsVals[linearIndex(kw, w)];
3552  Value resVal = resVals[w];
3553  if (flatten) {
3554  // Flatten the input and output vectors (collapse the channel
3555  // dimension)
3556  lhsVal = rewriter.create<vector::ShapeCastOp>(
3557  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3558  resVal = rewriter.create<vector::ShapeCastOp>(
3559  loc, resTypeAfterFlattening, resVals[w]);
3560  }
3561  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3562  rhsVals[kw], resVal, flatten);
3563  if (flatten) {
3564  // Un-flatten the output vector (restore the channel dimension)
3565  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3566  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3567  }
3568  }
3569  }
3570 
3571  // It's possible we failed to create the FMA.
3572  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3573  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3574  for (auto &collection :
3575  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3576  for (Value v : collection)
3577  rewriter.eraseOp(v.getDefiningOp());
3578  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3579  }
3580 
3581  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3582  // This does not depend on kw.
3583  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3584  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3585  loc, resVals[w], maybeMaskedRes->getResult(0),
3586  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3587  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3588  }
3589  //===------------------------------------------------------------------===//
3590  // End vector-only rewrite part
3591  //===------------------------------------------------------------------===//
3592 
3593  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3594  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3595  loc, maybeMaskedRes->getResult(0), resShaped,
3596  ValueRange{zero, zero, zero});
3597  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3598  resOut);
3599  }
3600 
3601  /// Lower:
3602  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3603  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3604  /// to MulAcc.
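  ///
  /// For the flatten = false, f32 case this boils down to roughly the
  /// following (an informal sketch, shapes illustrative only):
  /// ```
  /// %rhs_b = vector.broadcast %rhs : vector<8xf32> to vector<1x8x8xf32>
  /// %res = vector.fma %lhs, %rhs_b, %acc : vector<1x8x8xf32>
  /// ```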
3605  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3606  Value lhs, Value rhs, Value res,
3607  bool flatten) {
3608  auto rhsTy = cast<ShapedType>(rhs.getType());
3609  auto resTy = cast<ShapedType>(res.getType());
3610 
3611  // TODO(suderman): Change this to use a vector.ima intrinsic.
3612  lhs = promote(rewriter, loc, lhs, resTy);
3613 
3614  if (flatten) {
3615  // NOTE: The following logic won't work for scalable vectors. For this
3616  // reason, "flattening" is not supported when shapes are dynamic (this
3617  // should be captured by one of the pre-conditions).
3618 
3619  // There are two options for handling the filter:
3620  // * shape_cast(broadcast(filter))
3621  // * broadcast(shuffle(filter))
3622  // Opt for the option without shape_cast to simplify the codegen.
3623  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3624  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3625 
3626  SmallVector<int64_t, 16> indices;
3627  for (int i = 0; i < resSize / rhsSize; ++i) {
3628  for (int j = 0; j < rhsSize; ++j)
3629  indices.push_back(j);
3630  }
3631 
3632  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3633  }
3634  // Broadcast the filter to match the output vector
3635  rhs = rewriter.create<vector::BroadcastOp>(
3636  loc, resTy.clone(rhsTy.getElementType()), rhs);
3637 
3638  rhs = promote(rewriter, loc, rhs, resTy);
3639 
3640  if (!lhs || !rhs)
3641  return nullptr;
3642 
3643  if (isa<FloatType>(resTy.getElementType()))
3644  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3645 
3646  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3647  return rewriter.create<arith::AddIOp>(loc, mul, res);
3648  }
3649 
3650  /// Entry point for non-channeled convolution:
3651  /// {{w + kw}, {kw}, {w}}
3652  FailureOr<Operation *> generateNonChanneledConv() {
3653  AffineExpr w, kw;
3654  bindDims(ctx, w, kw);
3655  if (!iters({Par(), Red()}))
3656  return rewriter.notifyMatchFailure(op,
3657  "failed to match conv::W 1-par 1-red");
3658 
3659  // No transposition needed.
3660  if (layout({/*lhsIndex*/ {w + kw},
3661  /*rhsIndex*/ {kw},
3662  /*resIndex*/ {w}}))
3663  return conv(Conv1DOpOrder::W);
3664 
3665  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3666  }
3667 
3668  /// Entry point that transposes into the common form:
3669  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3670  FailureOr<Operation *> generateNwcConv() {
3671  AffineExpr n, w, f, kw, c;
3672  bindDims(ctx, n, w, f, kw, c);
3673  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3674  return rewriter.notifyMatchFailure(
3675  op, "failed to match conv::Nwc 3-par 2-red");
3676 
3677  // No transposition needed.
3678  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3679  /*rhsIndex*/ {kw, c, f},
3680  /*resIndex*/ {n, w, f}}))
3681  return conv(Conv1DOpOrder::Nwc);
3682 
3683  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3684  }
3685 
3686  /// Entry point that transposes into the common form:
3687  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3688  FailureOr<Operation *> generateNcwConv() {
3689  AffineExpr n, w, f, kw, c;
3690  bindDims(ctx, n, f, w, c, kw);
3691  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3692  return rewriter.notifyMatchFailure(
3693  op, "failed to match conv::Ncw 3-par 2-red");
3694 
3695  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3696  /*rhsIndex*/ {f, c, kw},
3697  /*resIndex*/ {n, f, w}}))
3698  return conv(Conv1DOpOrder::Ncw);
3699 
3700  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3701  }
3702 
3703  /// Entry point that transposes into the common form:
3704  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3705  FailureOr<Operation *> generateNwcPooling() {
3706  AffineExpr n, w, c, kw;
3707  bindDims(ctx, n, w, c, kw);
3708  if (!iters({Par(), Par(), Par(), Red()}))
3709  return rewriter.notifyMatchFailure(op,
3710  "failed to match pooling 3-par 1-red");
3711 
3712  // No transposition needed.
3713  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3714  /*rhsIndex*/ {kw},
3715  /*resIndex*/ {n, w, c}}))
3716  return conv(Conv1DOpOrder::Nwc);
3717 
3718  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3719  }
3720 
3721  /// Entry point that transposes into the common form:
3722  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3723  FailureOr<Operation *> generateNcwPooling() {
3724  AffineExpr n, w, c, kw;
3725  bindDims(ctx, n, c, w, kw);
3726  if (!iters({Par(), Par(), Par(), Red()}))
3727  return rewriter.notifyMatchFailure(op,
3728  "failed to match pooling 3-par 1-red");
3729 
3730  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3731  /*rhsIndex*/ {kw},
3732  /*resIndex*/ {n, c, w}}))
3733  return conv(Conv1DOpOrder::Ncw);
3734 
3735  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3736  }
3737 
3738  /// Entry point that transposes into the common form:
3739  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3740  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3741  bool vecChDimScalableFlag = false,
3742  bool flatten = false) {
3743  AffineExpr n, w, c, kw;
3744  bindDims(ctx, n, w, c, kw);
3745  if (!iters({Par(), Par(), Par(), Red()}))
3746  return rewriter.notifyMatchFailure(
3747  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3748 
3749  // No transposition needed.
3750  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3751  /*rhsIndex*/ {kw, c},
3752  /*resIndex*/ {n, w, c}}))
3753  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3754 
3755  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3756  }
3757 
3758 private:
3759  enum OperKind { Conv, Pool };
3760  bool valid = false;
3761  OperKind oper = Conv;
3762  StringAttr redOp;
3763  StringAttr poolExtOp;
3764  bool isPoolExt = false;
3765  int strideW, dilationW;
3766  Value lhsShaped, rhsShaped, resShaped;
3767  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3768  vector::CombiningKind reductionKind;
3769 
3770  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3771  // Returns true iff it is a valid conv/pooling op.
3772  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3773  // reduction + yield), and rhs is not used, then it is the body of a pooling.
3774  // If conv, check for a single `mul` predecessor. The `mul` operands must be
3775  // block arguments or extensions of block arguments.
3776  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
3777  // must be block arguments or extensions of block arguments.
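  // As a rough illustration (block-argument names are made up), a pooling
  // body looks like:
  //   ^bb0(%in: f32, %shape: f32, %acc: f32):   // %shape (rhs) is unused
  //     %0 = arith.addf %acc, %in : f32
  //     linalg.yield %0 : f32
  // whereas a convolution body has a `mul` feeding the reduction:
  //   ^bb0(%in: f32, %flt: f32, %acc: f32):
  //     %0 = arith.mulf %in, %flt : f32
  //     %1 = arith.addf %acc, %0 : f32
  //     linalg.yield %1 : f32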
3778  bool setOperKind(Operation *reduceOp) {
3779  int numBlockArguments =
3780  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3781  switch (numBlockArguments) {
3782  case 1: {
3783  // Will be convolution if feeder is a MulOp.
3784  // A strength reduced version of MulOp for i1 type is AndOp which is also
3785  // supported. Otherwise, it can be pooling. This strength reduction logic
3786  // is in `buildBinaryFn` helper in the Linalg dialect.
3787  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3788  llvm::IsaPred<BlockArgument>);
3789  Operation *feedOp = (*feedValIt).getDefiningOp();
3790  if (isCastOfBlockArgument(feedOp)) {
3791  oper = Pool;
3792  isPoolExt = true;
3793  poolExtOp = feedOp->getName().getIdentifier();
3794  } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3795  (isa<arith::AndIOp>(feedOp) &&
3796  feedOp->getResultTypes()[0].isInteger(1))) &&
3797  llvm::all_of(feedOp->getOperands(), [](Value v) {
3798  if (isa<BlockArgument>(v))
3799  return true;
3800  if (Operation *op = v.getDefiningOp())
3801  return isCastOfBlockArgument(op);
3802  return false;
3803  }))) {
3804  return false;
3805  }
3806  return true;
3807  }
3808  case 2:
3809  // Must be pooling
3810  oper = Pool;
3811  isPoolExt = false;
3812  return true;
3813  default:
3814  return false;
3815  }
3816  }
3817 };
3818 } // namespace
3819 
3820 /// Helper function to vectorize a LinalgOp with convolution semantics.
3821 // TODO: extend the generic vectorization to support windows and drop this.
3822 static FailureOr<Operation *> vectorizeConvolution(
3823  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3824  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3825  // The ConvolutionOpInterface gives us guarantees of existence for
3826  // strides/dilations. However, we do not need to rely on those, we can
3827  // simply use them if present, otherwise use the default and let the generic
3828  // conv. matcher in the ConvGenerator succeed or fail.
3829  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3830  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3831  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3832  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3833  Conv1DGenerator e(rewriter, op, stride, dilation);
3834  auto res = e.generateNonChanneledConv();
3835  if (succeeded(res))
3836  return res;
3837  res = e.generateNwcConv();
3838  if (succeeded(res))
3839  return res;
3840  res = e.generateNcwConv();
3841  if (succeeded(res))
3842  return res;
3843  res = e.generateNwcPooling();
3844  if (succeeded(res))
3845  return res;
3846  res = e.generateNcwPooling();
3847  if (succeeded(res))
3848  return res;
3849 
3850  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3851  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3852  // masked/scalable) is the channel dim (i.e. the trailing dim).
3853  uint64_t vecChDimSize = ShapedType::kDynamic;
3854  bool vecChDimScalableFlag = false;
3855  if (!inputVecSizes.empty()) {
3856  // Only use the input vector size corresponding to the channel dim. Other
3857  // vector dims will be inferred from the Ops.
3858  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3859  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3860  "Not a 1D depthwise conv!");
3861  size_t chDimIdx =
3862  TypeSwitch<Operation *, size_t>(op)
3863  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3864  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3865 
3866  vecChDimSize = inputVecSizes[chDimIdx];
3867  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3868  }
3869  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3870  flatten1DDepthwiseConv);
3871 }
3872 
3873 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3874  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3875 
3876  LogicalResult matchAndRewrite(LinalgOp op,
3877  PatternRewriter &rewriter) const override {
3878  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3879  if (failed(resultOrFail))
3880  return failure();
3881  Operation *newOp = *resultOrFail;
3882  if (newOp->getNumResults() == 0) {
3883  rewriter.eraseOp(op.getOperation());
3884  return success();
3885  }
3886  assert(newOp->getNumResults() == 1 && "expected single result");
3887  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3888  return success();
3889  }
3890 };
3891 
3892 void mlir::linalg::populateConvolutionVectorizationPatterns(
3893  RewritePatternSet &patterns, PatternBenefit benefit) {
3894  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3895 }
static std::optional< VectorShape > vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:606
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:454
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
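A minimal sketch tying IRMapping to OpBuilder::clone, assuming an OpBuilder `b`, an Operation `*op`, and Values `oldOperand`/`newOperand` already exist (all placeholder names); operands of the copy that used `oldOperand` are rewritten to `newOperand`.
  IRMapping mapper;
  mapper.map(oldOperand, newOperand);       // record the substitution
  Operation *copy = b.clone(*op, mapper);   // deep copy with operands remapped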
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:216
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
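A minimal sketch of the main entry point, assuming a RewriterBase `rewriter` and a candidate Operation `*op` (placeholder names) and that the declarations listed on this page are in scope; empty vector sizes let the vectorizer derive them from the static iteration space.
  if (linalg::hasVectorizationImpl(op) &&
      succeeded(linalg::vectorizeOpPrecondition(op))) {
    // Emits the vector form (transfer_read / compute / transfer_write) and
    // replaces `op` on success.
    if (failed(linalg::vectorize(rewriter, op)))
      return failure(); // preconditions held, but vectorization still bailed out
  }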
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
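A minimal sketch of combiner detection for reduction vectorization, assuming `combinerOp` is the scalar combining op found in a linalg body (placeholder name) and that getCombinerOpKind lives in the linalg namespace alongside the neighbouring entry points (an assumption).
  std::optional<vector::CombiningKind> kind = linalg::getCombinerOpKind(combinerOp);
  if (kind && *kind == vector::CombiningKind::ADD) {
    // e.g. an arith.addf/arith.addi combiner: the reduction can be emitted as
    // an additive vector reduction.
  }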
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorization patterns for tensor.insert_slice.
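A minimal sketch of assembling the vectorization-related pattern populators listed above into one set, assuming an MLIRContext `ctx` (placeholder name).
  RewritePatternSet patterns(&ctx);
  linalg::populatePadOpVectorizationPatterns(patterns);
  linalg::populateConvolutionVectorizationPatterns(patterns);
  linalg::populateInsertSliceVectorizationPatterns(patterns);
  // Hand the populated set to a rewrite driver (e.g. the greedy driver); the
  // exact driver entry point depends on the MLIR revision, so it is left out here.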
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2443
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
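A minimal sketch of the masking helpers above; `builder`, `loc`, `source`, `padValue`, `maskableOp`, and `mask` are placeholder names for values/ops assumed to exist, and the `vector::` qualification is an assumption based on how this file uses these utilities.
  // Read an 8x16 vector from `source`, masking (rather than marking in_bounds)
  // wherever the requested shape exceeds the static source shape.
  SmallVector<int64_t> readShape = {8, 16};
  Value read = vector::createReadOrMaskedRead(builder, loc, source, readShape,
                                              padValue,
                                              /*useInBoundsInsteadOfMasking=*/false);
  // Wrap an existing maskable op in a vector.mask region governed by an i1
  // vector `mask` (e.g. produced by vector.create_mask).
  Operation *masked = vector::maskOperation(builder, maskableOp, mask);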
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:815
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Rewrite tensor.insert.slice as a vector.transfer_read + vector.transfer_write pair.
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp, PatternRewriter &rewriter) const final
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.