Vectorization.cpp
//===- Vectorization.cpp - Implementation of linalg Vectorization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Helper function to extract the input slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes{wSizeStep};
    SmallVector<int64_t> strides{1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}
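
// For example (illustrative), in the single-channeled case with wSize = 4,
// wSizeStep = 2 and kwSize = 2, the loops above extract one {wSizeStep}-sized
// slice per (kw, w) pair, i.e. slices at offsets [0], [2], [1] and [3].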

/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}

/// Helper function to extract the result slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> sizes{wSizeStep};
    SmallVector<int64_t> strides{1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
    }
  }
  return result;
}

/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {

  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
    // This does not depend on kw.
    SmallVector<int64_t> strides{1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution. This does not depend on kw.
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          strides);
    }
  }
  return res;
}

/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector shape used to vectorize the iteration space.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

  /// Returns the vector dimensions that are scalable in the canonical vector
  /// shape.
  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
  /// accordingly.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if masking
  /// is not needed. If provided, the canonical mask for this operation is
  /// permuted using `maybeIndexingMap`.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information. This may become more complicated in the future.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
  /// all the static and dynamic dimensions of the iteration space to be
  /// vectorized and store them in `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
  /// mask is permuted using that permutation map. If a new mask is created, it
  /// will be cached for future users.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Check whether this permutation map can be used for masking. At the
  /// moment we only make sure that there are no broadcast dimensions, but this
  /// might change if indexing maps evolve.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.getBroadcastDims().size() == 0;
  }

  /// Turn the input indexing map into a valid masking map.
  ///
  /// The input indexing map may contain "zero" results, e.g.:
  ///    (d0, d1, d2, d3) -> (d2, d1, d0, 0)
  /// Applying such maps to canonical vector shapes like this one:
  ///    (1, 16, 16, 4)
  /// would yield an invalid vector shape like this:
  ///    (16, 16, 1, 0)
  /// Instead, drop the broadcasting dims that make no sense for masking perm.
  /// maps:
  ///    (d0, d1, d2, d3) -> (d2, d1, d0)
  /// This way, the corresponding vector/mask type will be:
  ///    vector<16x16x1xty>
  /// rather than this invalid Vector type:
  ///    vector<16x16x1x0xty>
  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }

  // Holds the compile-time static sizes of the iteration space to vectorize.
  // Dynamic dimensions are represented using ShapedType::kDynamic.
  SmallVector<int64_t> iterSpaceStaticSizes;

  /// Holds the value sizes of the iteration space to vectorize. Static
  /// dimensions are represented by 'arith.constant' and dynamic
  /// dimensions by 'tensor/memref.dim'.
  SmallVector<Value> iterSpaceValueSizes;

  /// Holds the canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;

  /// Holds the vector dimensions that are scalable in the canonical vector
  /// shape.
  SmallVector<bool> scalableVecDims;

  /// Holds the active masks for permutations of the canonical vector iteration
  /// space.
  DenseMap<AffineMap, Value> activeMaskCache;

  /// Global vectorization guard for the incoming rewriter. It's initialized
  /// when the vectorization state is initialized.
  OpBuilder::InsertionGuard rewriterGuard;
};

LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}

/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided. This
    // path should be taken to vectorize code with dynamic shapes and when using
    // vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there are
    // dynamic shapes, the operation won't be vectorized. We assume all the
    // vector dimensions are fixed.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
  // all the static and dynamic dimensions of the iteration space, needed to
  // compute a mask during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}

/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
/// is permuted using that permutation map. If a new mask is created, it will be
/// cached for future users.
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  // TODO: Improve this check. Only projected permutation indexing maps are
  // supported.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
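
// For example (illustrative), for an identity masking map over a canonical
// vector shape (8, 16) whose loop bounds are the dynamic values %d0 and %d1,
// the mask created and cached above would print as:
//   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>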

Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}

/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %0 = linalg.generic {
///          indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///          indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
///        }
///     ins(%0 : tensor<2x3x4xf32>)
///     outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}

/// Helper enum to represent conv1d input traversal order.
enum class Conv1DOpOrder {
  W,   // Corresponds to non-channeled 1D convolution operation.
  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
  Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
};

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to
/// propagate/replace results.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
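
// For example, a floating-point sum reduction whose combiner is arith.addf
// maps to CombiningKind::ADD and arith.minimumf maps to
// CombiningKind::MINIMUMF; unrecognised combiners yield std::nullopt.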

/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction. Return
/// nullptr otherwise. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X)).
// TODO: use in LinalgOp verification, there is a circular dependency atm.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

/// Broadcast `value` to a vector of `dstType` shape if possible. Return `value`
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
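
// For example (illustrative), broadcasting an f32 scalar to a vector<4x8xf32>
// destination emits `vector.broadcast %value : f32 to vector<4x8xf32>`; a
// value that is not broadcastable to `dstType` is returned unchanged.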

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build a memref.store.
/// Return the produced value or null if no value is produced.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any permutation
  // applied, given that any permutation in a transfer write happens as part of
  // the write itself.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in-bounds to true. Masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
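
// For example (illustrative), writing a vector<8x16xf32> value into a
// tensor<8x16xf32> init operand with an identity indexing map produces:
//   %w = vector.transfer_write %value, %init[%c0, %c0]
//       : vector<8x16xf32>, tensor<8x16xf32>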

// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook. Returns success if the corresponding custom
// hook can vectorize the op.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
    std::function<VectorizationResult(Operation *, const IRMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the vectorization
/// algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  // Compute a one-dimensional index vector for the index op dimension.
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space since the vectorization algorithm in this
  // case can handle the broadcast.
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
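
// For example, for a canonical vector shape (4, 8), `linalg.index 0` is
// vectorized as a step vector over dim 0, broadcast to the permuted shape
// vector<8x4xindex> and transposed back to vector<4x8xindex>, so that element
// (i, j) of the result equals i.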

/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors (for which we do need
  // access indices).
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}
800 
801 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
802 /// generated from `tensor.extract`. The offset is calculated as follows
803 /// (example using scalar values):
804 ///
805 /// offset = extractOp.indices[0]
806 /// for (i = 1; i < numIndices; i++)
807 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
808 ///
809 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
810 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
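///  i.e. offset = (1 * 80 + 2) * 15 + 3 = 1233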
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
///
/// Note that when calling this hook, it is assumed that the output vector is
/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
/// labelled as a gather load before entering this method.
///
/// Following on from the above, it is assumed that:
///  * for statically shaped loops, when no masks are used, only one dim is !=
///    1 (that's what the shape of the output vector is based on).
///  * for dynamically shaped loops, there might be more non-unit dims
///    as the output vector type is user-specified.
///
/// TODO: Statically shaped loops + vector masking
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}
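
// For example, for static loop ranges (1, 32, 1) the trailing non-unit loop
// dim is the middle one, so this returns 1.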

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations. Note that for dynamic shapes, the corresponding dim will also
  // be conservatively treated as != 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}

/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
///
/// There are currently 3 types of values that are allowed here:
///   1. loop-invariant values,
///   2. values that increment by 1 with every loop iteration,
///   3. results of basic arithmetic operations (linear and continuous)
///      involving 1., 2. and 3.
/// This method returns true if indeed only such values are used in calculating
/// `val`.
///
/// Additionally, the trailing index for a contiguous load operation should
/// increment by 1 with every loop iteration, i.e. be based on:
///   * `linalg.index <dim>`,
/// where <dim> is the trailing non-unit dim of the iteration space (this way,
/// `linalg.index <dim>` increments by 1 with every loop iteration).
/// `foundIndexOp` is updated to `true` when such an Op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);

    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}

/// Infer the memory access pattern for the input ExtractOp
///
/// Based on the ExtractOp result shape and the access indices, decides whether
/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
/// or a gather load. When analysing the ExtractOp indices (to identify
/// contiguous loads), this method looks for "loop" invariant indices (e.g.
/// block arguments) and indices that change linearly (e.g. via `linalg.index`
/// Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`.
  // Look at the way each index is calculated and decide whether it is suitable
  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
  // gather load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index for `extractOp`.
  // At this point we know that the leading indices are loop invariant. This
  // means that this is potentially a scalar or a contiguous load. We can decide
  // based on the trailing idx.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load
  // If the trailing index is loop invariant then this is a scalar load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);

    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads
  // The trailing `extractOp` index should increment with every loop iteration.
  // This effectively means that it must be based on the trailing loop index.
  // This is what the following bool captures.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read Ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
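
// For example (illustrative), with static loop ranges (1, 128), an extract
// from tensor<8x128xf32> whose leading index is loop invariant and whose
// trailing index is `linalg.index 1` is classified as a contiguous load;
// making the trailing index loop invariant as well would classify it as a
// scalar broadcast load instead.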

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle:
  //  a. scalar loads + broadcast,
  //  b. contiguous loads.
  // Both cases use vector.transfer_read.

  assert(llvm::count_if(resultType.getShape(),
                        [](uint64_t dim) { return dim != 1; }) &&
         "Contiguous loads and scalar loads + broadcast only support 1-D "
         "vectors ATM!");

  // Collect indices for `vector.transfer_read`. At this point, the indices will
  // either be scalars or would have been broadcast to vectors matching the
  // result type. For indices that are vectors, there are two options:
  //    * for non-trailing indices, all elements are identical (contiguous
  //      loads are identified by looking for non-trailing indices that are
  //      invariant with respect to the corresponding linalg.generic), or
  //    * for trailing indices, the index vector will contain values with stride
  //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
  //      needed.
  // This means that
  //    * for scalar indices - just re-use it,
  //    * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
  //      (0th) element and use that.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the generic
    // path (the generic path assumes identity masking map, which wouldn't be
    // valid here).
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  // For example, for dstRank = 3, srcRank = 2, the following map created
  // above:
  //    (d0, d1) --> (d0, d1)
  // is extended as:
  //    (d0, d1) --> (0, d0, d1)
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}

/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///      result on success.
///   2. Clone any constant in the current scope without vectorization: each
///      consumer of the constant will later determine the shape to which the
///      constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///      of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///      operand of maximal rank. Other operands have smaller rank and are
///      broadcast accordingly. It is assumed this broadcast is always legal,
///      otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each op if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. for elementwise, the result is the vector with the firstMaxRankedShape
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
1338 
1339 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1340 /// vector form. Generic vectorization proceeds as follows:
1341 /// 1. Verify the `linalgOp` has one non-empty region.
1342 /// 2. Values defined above the region are mapped to themselves and will be
1343 /// broadcasted on a per-need basis by their consumers.
1344 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1345 /// load).
1346 /// TODO: Reuse opportunities for RAR dependencies.
1347 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1348 /// 4rewriter. Register CustomVectorizationHook for IndexOp to access the
1349 /// iteration indices.
1350 /// 5. Iteratively call vectorizeOneOp on the region operations.
1351 ///
1352 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1353 /// performed to the maximal common vector size implied by the `linalgOp`
1354 /// iteration space. This eager broadcasting is introduced in the
1355 /// permutation_map of the vector.transfer_read operations. The eager
1356 /// broadcasting makes it trivial to detrmine where broadcast, transposes and
1357 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1358 /// the absence of good canonicalizations, the amount of work increases.
1359 /// This is not deemed a problem as we expect canonicalizations and foldings to
1360 /// aggressively clean up the useless work.
1361 static LogicalResult
1363  LinalgOp linalgOp,
1364  SmallVectorImpl<Value> &newResults) {
1365  LDBG("Vectorizing operation as linalg generic\n");
1366  Block *block = linalgOp.getBlock();
1367 
1368  // 2. Values defined above the region can only be broadcast for now. Make them
1369  // map to themselves.
1370  IRMapping bvm;
1371  SetVector<Value> valuesSet;
1372  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1373  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1374 
1375  if (linalgOp.getNumDpsInits() == 0)
1376  return failure();
1377 
1378  // 3. Turn all BBArgs into vector.transfer_read / load.
1379  Location loc = linalgOp.getLoc();
1380  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1381  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1382  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1383  if (linalgOp.isScalar(opOperand)) {
1384  bvm.map(bbarg, opOperand->get());
1385  continue;
1386  }
1387 
1388  // 3.a. Convert the indexing map for this input/output to a transfer read
1389  // permutation map and masking map.
1390  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1391 
1392  AffineMap readMap;
1393  VectorType readType;
1394  Type elemType = getElementTypeOrSelf(opOperand->get());
1395  if (linalgOp.isDpsInput(opOperand)) {
1396  // 3.a.i. For input reads we use the canonical vector shape.
1397  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1398  readType = state.getCanonicalVecType(elemType);
1399  } else {
1400  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1401  // reductions), the vector shape is computed by mapping the canonical
1402  // vector shape to the output domain and back to the canonical domain.
1403  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1404  readType =
1405  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1406  }
1407 
1408  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1409 
1410  Operation *read = rewriter.create<vector::TransferReadOp>(
1411  loc, readType, opOperand->get(), indices, readMap);
1412  read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1413  Value readValue = read->getResult(0);
1414 
1415  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1416  // will be in-bounds.
1417  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1418  SmallVector<bool> inBounds(readType.getRank(), true);
1419  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1420  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1421  }
1422 
1423  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1424  // TODO: remove this.
1425  if (readType.getRank() == 0)
1426  readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1427  ArrayRef<int64_t>());
1428 
1429  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1430  << "\n");
1431  bvm.map(bbarg, readValue);
1432  bvm.map(opOperand->get(), readValue);
1433  }
1434 
1435  SmallVector<CustomVectorizationHook> hooks;
1436  // 4a. Register CustomVectorizationHook for yieldOp.
1437  CustomVectorizationHook vectorizeYield =
1438  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1439  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1440  };
1441  hooks.push_back(vectorizeYield);
1442 
1443  // 4b. Register CustomVectorizationHook for indexOp.
1444  CustomVectorizationHook vectorizeIndex =
1445  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1446  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1447  };
1448  hooks.push_back(vectorizeIndex);
1449 
1450  // 4c. Register CustomVectorizationHook for extractOp.
1451  CustomVectorizationHook vectorizeExtract =
1452  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1453  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1454  };
1455  hooks.push_back(vectorizeExtract);
1456 
1457  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1458  for (Operation &op : block->getOperations()) {
1459  VectorizationResult result =
1460  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1461  if (result.status == VectorizationStatus::Failure) {
1462  LDBG("failed to vectorize: " << op << "\n");
1463  return failure();
1464  }
1465  if (result.status == VectorizationStatus::NewOp) {
1466  Operation *maybeMaskedOp =
1467  state.maskOperation(rewriter, result.newOp, linalgOp);
1468  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1469  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1470  }
1471  }
1472 
1473  return success();
1474 }
1475 
1476 /// Given a tensor::PackOp, return the `dest` shape before any packing
1477 /// permutations.
1478 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1479  ArrayRef<int64_t> destShape) {
1480  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1481 }
1482 
1483 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1484 /// create an empty destination tensor and create a TransferWriteOp from the
1485 /// input to the empty tensor. If the destination shape is not the same as the
1486 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1487 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1488 /// inBounds attribute of the transfer write op instead of masking.
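///
/// For example (an illustrative sketch; the shapes and SSA names are
/// assumptions), with a vector<8x16xf32> `input`, destSizes (%d0, 16) and
/// inputVectorSizes [8, 16], this emits roughly:
///   %empty = tensor.empty(%d0) : tensor<?x16xf32>
///   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///       : vector<8x16xf32>, tensor<?x16xf32>
///   } : vector<8x16xi1> -> tensor<?x16xf32>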
1489 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1490  Value input,
1491  SmallVector<OpFoldResult> destSizes,
1492  ArrayRef<int64_t> inputVectorSizes,
1493  bool useInBoundsInsteadOfMasking) {
1494 
1495  auto inputType = cast<VectorType>(input.getType());
1496  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1497  inputType.getElementType());
1498  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1499  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1500  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1501  SmallVector<bool> inBoundsVal(rank, true);
1502  if (useInBoundsInsteadOfMasking) {
1503  // Update the inBounds attribute.
1504  for (unsigned i = 0; i < rank; i++)
1505  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1506  !ShapedType::isDynamic(destShape[i]);
1507  }
1508  Operation *write = builder.create<vector::TransferWriteOp>(
1509  loc,
1510  /*vector=*/input,
1511  /*source=*/dest,
1512  /*indices=*/SmallVector<Value>(rank, zero),
1513  /*inBounds=*/inBoundsVal);
1514  assert(llvm::none_of(
1515  destShape.drop_front(inputVectorSizes.size()),
1516  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1517  "Only dims aligned with inputVectorSizes may be dynamic");
1518  if (useInBoundsInsteadOfMasking)
1519  return write;
1520  bool needMaskForWrite = !llvm::equal(
1521  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1522  if (needMaskForWrite) {
1523  SmallVector<int64_t> writeMaskShape;
1524  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1525  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1526  destShape.end());
1527  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1528  Value maskForWrite =
1529  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1530  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1531  }
1532  return write;
1533 }
1534 
1535 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1536 /// padding value and (3) input vector sizes into:
1537 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1538 /// As in the following example:
1539 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1540 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1541 ///
1542 /// This pack would be vectorized to:
1543 ///
1544 /// %load = vector.mask %mask {
1545 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1546 /// {in_bounds = [true, true, true]} :
1547 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1548 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1549 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1550 /// to vector<32x4x2x1x16xf32>
1551 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1552 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1553 /// %write = vector.transfer_write %transpose,
1554 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1555 /// {in_bounds = [true, true, true, true, true]}
1556 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1557 ///
1558 /// If the (3) input vector sizes are not provided, the vector sizes are
1559 /// determined by the result tensor shape. Also, we update the inBounds
1560 /// attribute instead of masking.
1561 static LogicalResult
1562 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1563  ArrayRef<int64_t> inputVectorSizes,
1564  SmallVectorImpl<Value> &newResults) {
1565  OpBuilder::InsertionGuard g(rewriter);
1566  rewriter.setInsertionPoint(packOp);
1567 
1568  Location loc = packOp.getLoc();
1569  auto padValue = packOp.getPaddingValue();
1570  if (!padValue) {
1571  padValue = rewriter.create<arith::ConstantOp>(
1572  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1573  }
1574  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1575  LogicalResult status =
1576  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1577  .reifyResultShapes(rewriter, reifiedReturnShapes);
1578  (void)status; // prevent unused variable warning on non-assert builds.
1579  assert(succeeded(status) && "failed to reify result shapes");
1580 
1581  // If the input vector sizes are not provided, then the vector sizes are
1582  // determined by the result tensor shape and, in that case, we update the
1583  // inBounds attribute instead of masking.
1584  bool useInBoundsInsteadOfMasking = false;
1585  if (inputVectorSizes.empty()) {
1586  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1587  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1588  useInBoundsInsteadOfMasking = true;
1589  }
1590 
1591  // Create masked TransferReadOp.
1592  SmallVector<int64_t> inputShape(inputVectorSizes);
1593  auto innerTiles = packOp.getStaticInnerTiles();
1594  auto innerDimsPos = packOp.getInnerDimsPos();
1595  auto outerDimsPerm = packOp.getOuterDimsPerm();
1596  if (!outerDimsPerm.empty())
1597  applyPermutationToVector(inputShape,
1598  invertPermutationVector(outerDimsPerm));
1599  for (auto [idx, size] : enumerate(innerTiles))
1600  inputShape[innerDimsPos[idx]] *= size;
1601  auto maskedRead = vector::createReadOrMaskedRead(
1602  rewriter, loc, packOp.getSource(), inputShape, padValue,
1603  useInBoundsInsteadOfMasking);
1604 
1605  // Create ShapeCastOp.
1606  SmallVector<int64_t> destShape(inputVectorSizes);
1607  destShape.append(innerTiles.begin(), innerTiles.end());
1608  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1609  packOp.getDestType().getElementType());
1610  auto shapeCastOp =
1611  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1612 
1613  // Create TransposeOp.
1614  auto destPermutation =
1615  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1616  auto transposeOp = rewriter.create<vector::TransposeOp>(
1617  loc, shapeCastOp.getResult(), destPermutation);
1618 
1619  // Create TransferWriteOp.
1620  Operation *write = createWriteOrMaskedWrite(
1621  rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1622  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1623  newResults.push_back(write->getResult(0));
1624  return success();
1625 }
1626 
1627 /// Vectorize a `tensor::UnPackOp` to these 4 ops:
1628 ///   vector::TransferReadOp - reads a vector from the source tensor
1629 ///   vector::TransposeOp - transposes the read vector
1630 ///   vector::ShapeCastOp - reshapes the data based on the target
1631 ///   vector::TransferWriteOp - writes the result vector back to the destination
1632 ///   tensor.
1633 /// If the vector sizes are not provided:
1634 /// * the vector sizes are determined by the input operand and attributes,
1635 /// * update the inBounds attribute instead of masking.
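///
/// As a rough sketch (types are indicative only; static shapes and no
/// outer_dims_perm are assumed), an op like:
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
///     into %dest : tensor<2x2x8x8xf32> -> tensor<16x16xf32>
/// would be rewritten along these lines:
///   %read = vector.transfer_read %src[...] : tensor<2x2x8x8xf32>, vector<2x2x8x8xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///     : vector<2x2x8x8xf32> to vector<2x8x2x8xf32>
///   %sc = vector.shape_cast %tr : vector<2x8x2x8xf32> to vector<16x16xf32>
///   %write = vector.transfer_write %sc, %dest[...] : vector<16x16xf32>, tensor<16x16xf32>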
1636 static LogicalResult
1637 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1638  ArrayRef<int64_t> inputVectorSizes,
1639  SmallVectorImpl<Value> &newResults) {
1640 
1641  OpBuilder::InsertionGuard g(rewriter);
1642  rewriter.setInsertionPoint(unpackOp);
1643 
1644  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1645 
1646  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1647  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1648  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1649  bool useInBoundsInsteadOfMasking = false;
1650  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1651 
1652  auto destSize = unpackOp.getDestRank();
1653 
1654  if (!inputVectorSizes.empty())
1655  assert(inputVectorSizes.size() == destSize &&
1656  "Incorrect number of input vector sizes");
1657 
1658  // vectorSizes is the shape of the vector that will be used for the final
1659  // write on the destination tensor. It is set as follows: let the source
1660  // tensor have rank 'M' and the dest tensor rank 'N', where N <= M.
1661  // Thus:
1662  // 1. vectorSizes = sourceShape.take_front(N)
1663  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1664  // 3. multiply the vectorSizes entries pointed to by innerDimPos by the
1665  // corresponding innerTiles attribute value.
1666  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1667  if (vectorSizes.empty()) {
1668  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1669  if (!outerDimsPerm.empty())
1670  applyPermutationToVector(vectorSizes, outerDimsPerm);
1671  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1672  vectorSizes[pos] *= innerTiles[i];
1673 
1674  useInBoundsInsteadOfMasking = true;
1675  }
1676 
1677  // readVectorSizes is the shape of the vector used to read and apply the
1678  // mask. It is set as follows: let the vectorSizes (VS) array have size 'N',
1679  // the sourceShape (SS) have size 'M' with M >= N, and the innerTiles (IT)
1680  // have size M-N.
1681  // Thus:
1682  // - initially: readVectorSizes = vectorSizes
1683  // - divide the readVectorSizes entries pointed to by innerDimPos by the
1684  // corresponding innerTiles attribute value.
1685  // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1686  // - append the remaining shape from SS.
1687  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1688  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1689  // 128] and outer_dims_perm is [1, 0] then read shape is:
1690  // ReadVectorSizes(initial): [512, 128]
1691  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1692  // = [16, 8]
1693  // After applying outer_dims_perm: [8, 16]
1694  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1695 
1696  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1697 
1698  for (auto [index, size] : enumerate(innerTiles)) {
1699  readVectorSizes[innerDimPos[index]] =
1700  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1701  }
1702  if (!outerDimsPerm.empty()) {
1703  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1704  }
1705  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1706  sourceShape.end());
1707 
1708  ReifiedRankedShapedTypeDims reifiedRetShapes;
1709  LogicalResult status =
1710  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1711  .reifyResultShapes(rewriter, reifiedRetShapes);
1712  if (status.failed()) {
1713  LDBG("Unable to reify result shapes of " << unpackOp);
1714  return failure();
1715  }
1716  Location loc = unpackOp->getLoc();
1717 
1718  auto padValue = rewriter.create<arith::ConstantOp>(
1719  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1720 
1721  // Read result, mask if necessary. If the transferReadOp shape is not equal
1722  // to the shape of the source, then a mask is necessary.
1723  Value readResult = vector::createReadOrMaskedRead(
1724  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1725  /*useInBoundsInsteadOfMasking=*/false);
1726 
1727  PackingMetadata packMetadata;
1728  SmallVector<int64_t> lastDimToInsertPosPerm =
1729  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1730  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1731  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1732  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1733  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1734  RankedTensorType stripMineTensorType =
1735  RankedTensorType::get(stripMineShape, stripMineElemType);
1736  // Transpose the appropriate rows to match output.
1737  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1738  loc, readResult, lastDimToInsertPosPerm);
1739 
1740  // Collapse the vector to the size required by result.
1741  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1742  stripMineTensorType, packMetadata.reassociations);
1743  mlir::VectorType vecCollapsedType =
1744  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1745  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1746  loc, vecCollapsedType, transposeOp->getResult(0));
1747 
1748  // writeVectorSizes has to match the shape_cast result shape for dynamic
1749  // sizes; otherwise the verifier complains that the mask size is invalid.
1750  SmallVector<int64_t> writeVectorSizes(
1751  unpackOp.getDestType().hasStaticShape()
1752  ? vectorSizes
1753  : shapeCastOp.getResultVectorType().getShape());
1754  Operation *write = createWriteOrMaskedWrite(
1755  rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1756  writeVectorSizes, useInBoundsInsteadOfMasking);
1757  newResults.push_back(write->getResult(0));
1758  return success();
1759 }
1760 
1761 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1762 /// and (3) all-zero lowPad to
1763 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
1764 static LogicalResult
1765 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1766  ArrayRef<int64_t> inputVectorSizes,
1767  SmallVectorImpl<Value> &newResults) {
1768  auto padValue = padOp.getConstantPaddingValue();
1769  Location loc = padOp.getLoc();
1770 
1771  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1772  OpBuilder::InsertionGuard g(rewriter);
1773  rewriter.setInsertionPoint(padOp);
1774 
1775  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1776  LogicalResult status =
1777  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1778  .reifyResultShapes(rewriter, reifiedReturnShapes);
1779  (void)status; // prevent unused variable warning on non-assert builds
1780  assert(succeeded(status) && "failed to reify result shapes");
1781  auto maskedRead = vector::createReadOrMaskedRead(
1782  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1783  /*useInBoundsInsteadOfMasking=*/false);
1784  Operation *write = createWriteOrMaskedWrite(
1785  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1786  /*useInBoundsInsteadOfMasking=*/false);
1787  newResults.push_back(write->getResult(0));
1788  return success();
1789 }
1790 
1791 // TODO: probably need some extra checks for reduction followed by consumer
1792 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1793 static LogicalResult reductionPreconditions(LinalgOp op) {
1794  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1795  LDBG("reduction precondition failed: no reduction iterator\n");
1796  return failure();
1797  }
1798  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1799  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1800  if (indexingMap.isPermutation())
1801  continue;
1802 
1803  Operation *reduceOp = matchLinalgReduction(&opOperand);
1804  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1805  LDBG("reduction precondition failed: reduction detection failed\n");
1806  return failure();
1807  }
1808  }
1809  return success();
1810 }
1811 
1812 static LogicalResult
1813 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1814  bool flatten1DDepthwiseConv) {
1815  if (flatten1DDepthwiseConv) {
1816  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1817  "supported\n");
1818  return failure();
1819  }
1820 
1821  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1822  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1823  return failure();
1824  }
1825 
1826  // Support dynamic shapes in 1D depthwise convolution, but only in the
1827  // _channel_ dimension.
1828  Value lhs = conv.getDpsInputOperand(0)->get();
1829  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1830  auto shapeWithoutCh = lhsShape.drop_back(1);
1831  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1832  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1833  "channel dim can be dynamic\n");
1834  return failure();
1835  }
1836 
1837  return success();
1838 }
1839 
1840 static LogicalResult
1841 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1842  bool flatten1DDepthwiseConv) {
1843  if (isa<ConvolutionOpInterface>(op.getOperation()))
1844  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1845 
1846  if (hasReductionIterator(op))
1847  return reductionPreconditions(op);
1848 
1849  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1850  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1851  if (!isElementwise(op) &&
1852  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1853  op.getOperation()))
1854  return failure();
1855 
1856  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1857  return success();
1858 }
1859 
1860 /// Need to check if the inner-tiles are static/constant.
1861 static LogicalResult
1862 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1863  ArrayRef<int64_t> inputVectorSizes) {
1864 
1865  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1866  return !getConstantIntValue(res).has_value();
1867  })) {
1868  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1869  return failure();
1870  }
1871  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1872  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1873  unpackOp.getDestType().hasStaticShape() &&
1874  unpackOp.getSourceType().hasStaticShape();
1875  if (!satisfyEmptyCond &&
1876  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1877  return failure();
1878 
1879  return success();
1880 }
1881 
1882 static LogicalResult vectorizeLinalgOpPrecondition(
1883  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1884  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1885  // Tensors with a dimension of size 0 cannot be vectorized.
1886  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1887  return failure();
1888  // Check API contract for input vector sizes.
1889  if (!inputVectorSizes.empty() &&
1890  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1891  inputVectorSizes)))
1892  return failure();
1893 
1894  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1895  linalgOp, flatten1DDepthwiseConv))) {
1896  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1897  return failure();
1898  }
1899  }
1900  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1901 
1902  // Register CustomVectorizationPrecondition for extractOp.
1903  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1904 
1905  // All types in the body should be a supported element type for VectorType.
1906  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1907  // Check if any custom hook can vectorize the inner op.
1908  if (llvm::any_of(
1909  customPreconditions,
1910  [&](const CustomVectorizationPrecondition &customPrecondition) {
1911  return succeeded(
1912  customPrecondition(&innerOp, vectorizeNDExtract));
1913  })) {
1914  continue;
1915  }
1916  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1917  return !VectorType::isValidElementType(type);
1918  })) {
1919  return failure();
1920  }
1921  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1922  return !VectorType::isValidElementType(type);
1923  })) {
1924  return failure();
1925  }
1926  }
1927  if (isElementwise(linalgOp))
1928  return success();
1929 
1930  // TODO: isaConvolutionOpInterface that can also infer from generic
1931  // features. But we will still need stride/dilation attributes that will be
1932  // annoying to reverse-engineer...
1933  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1934  return success();
1935  // TODO: the common vector shape is equal to the static loop sizes only when
1936  // all indexing maps are projected permutations. For convs and stencils the
1937  // logic will need to evolve.
1938  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1939  LDBG("precondition failed: not projected permutations\n");
1940  return failure();
1941  }
1942  if (failed(reductionPreconditions(linalgOp))) {
1943  LDBG("precondition failed: reduction preconditions\n");
1944  return failure();
1945  }
1946  return success();
1947 }
1948 
1949 static LogicalResult
1950 vectorizePackOpPrecondition(tensor::PackOp packOp,
1951  ArrayRef<int64_t> inputVectorSizes) {
1952  auto padValue = packOp.getPaddingValue();
1953  Attribute cstAttr;
1954  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1955  LDBG("pad value is not constant: " << packOp << "\n");
1956  return failure();
1957  }
1958  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1959  bool satisfyEmptyCond = true;
1960  if (inputVectorSizes.empty()) {
1961  if (!packOp.getDestType().hasStaticShape() ||
1962  !packOp.getSourceType().hasStaticShape())
1963  satisfyEmptyCond = false;
1964  }
1965 
1966  if (!satisfyEmptyCond &&
1967  failed(vector::isValidMaskedInputVector(
1968  resultTensorShape.take_front(packOp.getSourceRank()),
1969  inputVectorSizes)))
1970  return failure();
1971 
1972  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1973  return !getConstantIntValue(v).has_value();
1974  })) {
1975  LDBG("inner_tiles must be constant: " << packOp << "\n");
1976  return failure();
1977  }
1978 
1979  return success();
1980 }
1981 
1982 static LogicalResult
1983 vectorizePadOpPrecondition(tensor::PadOp padOp,
1984  ArrayRef<int64_t> inputVectorSizes) {
1985  auto padValue = padOp.getConstantPaddingValue();
1986  if (!padValue) {
1987  LDBG("pad value is not constant: " << padOp << "\n");
1988  return failure();
1989  }
1990 
1991  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1992  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1993  inputVectorSizes)))
1994  return failure();
1995 
1996  if (llvm::any_of(padOp.getLow(), [](Value v) {
1997  std::optional<int64_t> res = getConstantIntValue(v);
1998  return !res.has_value() || res.value() != 0;
1999  })) {
2000  LDBG("low pad must all be zero: " << padOp << "\n");
2001  return failure();
2002  }
2003 
2004  return success();
2005 }
2006 
2007 /// Preconditions for scalable vectors. This is quite restrictive - it models
2008 /// the fact that in practice we would only make selected dimensions scalable.
2009 static LogicalResult
2010 vectorizeScalableVectorPrecondition(Operation *op,
2011  ArrayRef<int64_t> inputVectorSizes,
2012  ArrayRef<bool> inputScalableVecDims) {
2013  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2014  "Number of input vector sizes and scalable dims doesn't match");
2015 
2016  size_t numOfScalableDims =
2017  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2018 
2019  if (numOfScalableDims == 0)
2020  return success();
2021 
2022  auto linalgOp = dyn_cast<LinalgOp>(op);
2023 
2024  // Cond 1: There's been no need for scalable vectorisation of
2025  // non-linalg Ops so far
2026  if (!linalgOp)
2027  return failure();
2028 
2029  // Cond 2: There's been no need for more than 2 scalable dims so far
2030  if (numOfScalableDims > 2)
2031  return failure();
2032 
2033  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2034  // it matches one of the supported cases:
2035  // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2036  // (*).
2037  // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2038  // parallel dims.
2039  // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2040  // dim.
2041  // The 2nd restriction above means that only Matmul-like Ops are supported
2042  // when 2 dims are scalable, e.g. :
2043  // * iterators = [parallel, parallel, reduction]
2044  // * scalable flags = [true, true, false]
2045  //
2046  // (*) Non-unit dims get folded away in practice.
2047  // TODO: Relax these conditions as good motivating examples are identified.
2048 
2049  // Find the first scalable flag.
2050  bool seenNonUnitParallel = false;
2051  auto iterators = linalgOp.getIteratorTypesArray();
2052  SmallVector<bool> scalableFlags(inputScalableVecDims);
2053  int64_t idx = scalableFlags.size() - 1;
2054  while (!scalableFlags[idx]) {
2055  bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2056  seenNonUnitParallel |=
2057  (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2058 
2059  iterators.pop_back();
2060  scalableFlags.pop_back();
2061  --idx;
2062  }
2063 
2064  // Analyze the iterator corresponding to the first scalable dim.
2065  switch (iterators.back()) {
2066  case utils::IteratorType::reduction: {
2067  // Check 3. above is met.
2068  if (iterators.size() != inputVectorSizes.size()) {
2069  LDBG("Non-trailing reduction dim requested for scalable "
2070  "vectorization\n");
2071  return failure();
2072  }
2073  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2074  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2075  "is not supported\n");
2076  return failure();
2077  }
2078  break;
2079  }
2080  case utils::IteratorType::parallel: {
2081  // Check 1. and 2. above are met.
2082  if (seenNonUnitParallel) {
2083  LDBG("Inner parallel dim not requested for scalable "
2084  "vectorization\n");
2085  return failure();
2086  }
2087  break;
2088  }
2089  }
2090 
2091  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2092  // supported, for which we expect the following config:
2093  // * iterators = [parallel, parallel, reduction]
2094  // * scalable flags = [true, true, false]
2095  if (numOfScalableDims == 2) {
2096  // Disallow below case which breaks 3. above:
2097  // * iterators = [..., parallel, reduction]
2098  // * scalable flags = [..., true, true]
2099  if (iterators.back() == utils::IteratorType::reduction) {
2100  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2101  "vectorization\n");
2102  return failure();
2103  }
2104  scalableFlags.pop_back();
2105  iterators.pop_back();
2106 
2107  if (!scalableFlags.back() ||
2108  (iterators.back() != utils::IteratorType::parallel))
2109  return failure();
2110  }
2111 
2112  // Do not let matmul ops with extended semantics (user-defined indexing
2113  // maps) through this transform.
2114  if (linalgOp.hasUserDefinedMaps())
2115  return failure();
2116 
2117  // Cond 4: Only the following ops are supported in the
2118  // presence of scalable vectors
2119  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2120  isa<linalg::MatmulTransposeAOp>(op) ||
2121  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2122  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2123 }
2124 
2125 static LogicalResult vectorizeOpPrecondition(
2126  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2127  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2128  bool flatten1DDepthwiseConv) {
2129 
2130  if (!hasVectorizationImpl(op))
2131  return failure();
2132 
2133  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2134  inputScalableVecDims)))
2135  return failure();
2136 
2137  return TypeSwitch<Operation *, LogicalResult>(op)
2138  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2139  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2140  vectorizeNDExtract,
2141  flatten1DDepthwiseConv);
2142  })
2143  .Case<tensor::PadOp>([&](auto padOp) {
2144  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2145  })
2146  .Case<tensor::PackOp>([&](auto packOp) {
2147  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2148  })
2149  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2150  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2151  })
2152  .Default([](auto) { return failure(); });
2153 }
2154 
2155 /// Converts affine.apply Ops to arithmetic operations.
2156 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2157  OpBuilder::InsertionGuard g(rewriter);
2158  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2159 
2160  for (auto op : make_early_inc_range(toReplace)) {
2161  rewriter.setInsertionPoint(op);
2162  auto expanded = affine::expandAffineExpr(
2163  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2164  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2165  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2166  rewriter.replaceOp(op, expanded);
2167  }
2168 }
2169 
2170 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2171  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2172  op);
2173 }
2174 
2175 /// Emit a suitable vector form for an operation. If provided,
2176 /// `inputVectorSizes` are used to vectorize this operation.
2177 /// `inputVectorSizes` must match the rank of the iteration space of the
2178 /// operation and the input vector sizes must be greater than or equal to
2179 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2180 /// also allows the vectorization of operations with dynamic shapes.
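///
/// A minimal usage sketch (illustrative only; the vector sizes and the
/// surrounding pattern are assumptions made for the example):
///   SmallVector<int64_t> vectorSizes = {8, 16};  // >= each static loop range
///   SmallVector<bool> scalableDims = {false, false};
///   if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes, scalableDims,
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();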
2181 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2182  ArrayRef<int64_t> inputVectorSizes,
2183  ArrayRef<bool> inputScalableVecDims,
2184  bool vectorizeNDExtract,
2185  bool flatten1DDepthwiseConv) {
2186  LDBG("Attempting to vectorize:\n" << *op << "\n");
2187  LDBG("Input vector sizes: ");
2188  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2189  LLVM_DEBUG(llvm::dbgs() << "\n");
2190  LDBG("Input scalable vector dims: ");
2191  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2192  LLVM_DEBUG(llvm::dbgs() << "\n");
2193 
2194  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2195  vectorizeNDExtract,
2196  flatten1DDepthwiseConv))) {
2197  LDBG("Vectorization pre-conditions failed\n");
2198  return failure();
2199  }
2200 
2201  // Initialize vectorization state.
2202  VectorizationState state(rewriter);
2203  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2204  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2205  inputScalableVecDims))) {
2206  LDBG("Vectorization state couldn't be initialized\n");
2207  return failure();
2208  }
2209  }
2210 
2211  SmallVector<Value> results;
2212  auto vectorizeResult =
2213  TypeSwitch<Operation *, LogicalResult>(op)
2214  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2215  // TODO: isaConvolutionOpInterface that can also infer from
2216  // generic features. Will require stride/dilation attributes
2217  // inference.
2218  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2219  FailureOr<Operation *> convOr = vectorizeConvolution(
2220  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2221  flatten1DDepthwiseConv);
2222  if (succeeded(convOr)) {
2223  llvm::append_range(results, (*convOr)->getResults());
2224  return success();
2225  }
2226 
2227  LDBG("Unsupported convolution can't be vectorized.\n");
2228  return failure();
2229  }
2230 
2231  LDBG("Vectorize generic by broadcasting to the canonical vector "
2232  "shape\n");
2233 
2234  // Pre-process before proceeding.
2235  convertAffineApply(rewriter, linalgOp);
2236 
2237  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2238  // to 'OpBuilder' when it is passed over to some methods like
2239  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2240  // erase an op within these methods, the actual rewriter won't be
2241  // notified and we will end up with read-after-free issues!
2242  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2243  })
2244  .Case<tensor::PadOp>([&](auto padOp) {
2245  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2246  results);
2247  })
2248  .Case<tensor::PackOp>([&](auto packOp) {
2249  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2250  results);
2251  })
2252  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2253  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2254  inputVectorSizes, results);
2255  })
2256  .Default([](auto) { return failure(); });
2257 
2258  if (failed(vectorizeResult)) {
2259  LDBG("Vectorization failed\n");
2260  return failure();
2261  }
2262 
2263  if (!results.empty())
2264  rewriter.replaceOp(op, results);
2265  else
2266  rewriter.eraseOp(op);
2267 
2268  return success();
2269 }
2270 
2271 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2272  memref::CopyOp copyOp) {
2273  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2274  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2275  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2276  return failure();
2277 
2278  auto srcElementType = getElementTypeOrSelf(srcType);
2279  auto dstElementType = getElementTypeOrSelf(dstType);
2280  if (!VectorType::isValidElementType(srcElementType) ||
2281  !VectorType::isValidElementType(dstElementType))
2282  return failure();
2283 
2284  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2285  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2286 
2287  Location loc = copyOp->getLoc();
2288  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2289  SmallVector<Value> indices(srcType.getRank(), zero);
2290 
2291  Value readValue = rewriter.create<vector::TransferReadOp>(
2292  loc, readType, copyOp.getSource(), indices,
2293  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2294  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2295  readValue =
2296  rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2297  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2298  }
2299  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2300  loc, readValue, copyOp.getTarget(), indices,
2301  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2302  rewriter.replaceOp(copyOp, writeValue->getResults());
2303  return success();
2304 }
2305 
2306 //----------------------------------------------------------------------------//
2307 // Misc. vectorization patterns.
2308 //----------------------------------------------------------------------------//
2309 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2310 /// given operation type OpTy.
2311 template <typename OpTy>
2312 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2313  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2314 
2315  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2316  PatternRewriter &rewriter) const final {
2317  bool changed = false;
2318  // Insert users in vector, because some users may be replaced/removed.
2319  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2320  if (auto op = dyn_cast<OpTy>(user))
2321  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2322  return success(changed);
2323  }
2324 
2325 protected:
2326  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2327  tensor::PadOp padOp, OpTy op) const = 0;
2328 };
2329 
2330 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2331 /// ```
2332 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2333 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2334 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2335 /// ```
2336 /// is rewritten to:
2337 /// ```
2338 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2339 /// {in_bounds = [true, true]}
2340 /// : tensor<?x?xf32>, vector<17x5xf32>
2341 /// ```
2342 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2343 /// sure that the original padding value %cst was never used.
2344 ///
2345 /// This rewrite is possible if:
2346 /// - `xferOp` has no out-of-bounds dims or mask.
2347 /// - Low padding is static 0.
2348 /// - Single, scalar padding value.
2349 struct PadOpVectorizationWithTransferReadPattern
2350  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2351  using VectorizePadOpUserPattern<
2352  vector::TransferReadOp>::VectorizePadOpUserPattern;
2353 
2354  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2355  vector::TransferReadOp xferOp) const override {
2356  // Low padding must be static 0.
2357  if (!padOp.hasZeroLowPad())
2358  return failure();
2359  // Pad value must be a constant.
2360  auto padValue = padOp.getConstantPaddingValue();
2361  if (!padValue)
2362  return failure();
2363  // Padding value of existing `xferOp` is unused.
2364  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2365  return failure();
2366 
2367  rewriter.modifyOpInPlace(xferOp, [&]() {
2368  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2369  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2370  rewriter.getBoolArrayAttr(inBounds));
2371  xferOp.getSourceMutable().assign(padOp.getSource());
2372  xferOp.getPaddingMutable().assign(padValue);
2373  });
2374 
2375  return success();
2376  }
2377 };
2378 
2379 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2380 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2381 /// value, where the same amount of padding is immediately removed again after
2382 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2383 /// tensor value and apply out-of-bounds masking. E.g.:
2384 /// ```
2385 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2386 /// : tensor<...> to tensor<?x?xf32>
2387 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2388 /// %2 = vector.transfer_write %vec, %1[...]
2389 /// : vector<17x5xf32>, tensor<17x5xf32>
2390 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2391 /// : tensor<17x5xf32> to tensor<?x?xf32>
2392 /// ```
2393 /// is rewritten to:
2394 /// ```
2395 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2396 /// : tensor<...> to tensor<?x?xf32>
2397 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2398 /// tensor<?x?xf32>
2399 /// ```
2400 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2401 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2402 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2403 /// from %r's old dimensions.
2404 ///
2405 /// This rewrite is possible if:
2406 /// - Low padding is static 0.
2407 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2408 /// ExtractSliceOp trims the same amount of padding that was added
2409 /// beforehand.
2410 /// - Single, scalar padding value.
2411 struct PadOpVectorizationWithTransferWritePattern
2412  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2413  using VectorizePadOpUserPattern<
2414  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2415 
2416  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2417  vector::TransferWriteOp xferOp) const override {
2418  // TODO: support 0-d corner case.
2419  if (xferOp.getTransferRank() == 0)
2420  return failure();
2421 
2422  // Low padding must be static 0.
2423  if (!padOp.hasZeroLowPad())
2424  return failure();
2425  // Pad value must be a constant.
2426  auto padValue = padOp.getConstantPaddingValue();
2427  if (!padValue)
2428  return failure();
2429  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2430  if (!xferOp->hasOneUse())
2431  return failure();
2432  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2433  if (!trimPadding)
2434  return failure();
2435  // Only static zero offsets supported when trimming padding.
2436  if (!trimPadding.hasZeroOffset())
2437  return failure();
2438  // trimPadding must remove the amount of padding that was added earlier.
2439  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2440  return failure();
2441 
2442  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2443  rewriter.setInsertionPoint(xferOp);
2444 
2445  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2446  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2447  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2448  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2449  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2450  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2451 
2452  return success();
2453  }
2454 
2455  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2456  /// i.e., same dimensions.
2457  ///
2458  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2459  /// dimensions, this function tries to infer the (static) tensor size by
2460  /// looking at the defining op and utilizing op-specific knowledge.
2461  ///
2462  /// This is a conservative analysis. In case equal tensor sizes cannot be
2463  /// proven statically, this analysis returns `false` even though the tensor
2464  /// sizes may turn out to be equal at runtime.
2465  bool hasSameTensorSize(Value beforePadding,
2466  tensor::ExtractSliceOp afterTrimming) const {
2467  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2468  // result and CastOp operand.
2469  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2470  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2471  return true;
2472 
2473  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2474  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2475  // Only RankedTensorType supported.
2476  if (!t1 || !t2)
2477  return false;
2478  // Rank of both values must be the same.
2479  if (t1.getRank() != t2.getRank())
2480  return false;
2481 
2482  // All static dimensions must be the same. Mixed cases (e.g., dimension
2483  // static in `t1` but dynamic in `t2`) are not supported.
2484  for (unsigned i = 0; i < t1.getRank(); ++i) {
2485  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2486  return false;
2487  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2488  return false;
2489  }
2490 
2491  // Nothing more to check if all dimensions are static.
2492  if (t1.getNumDynamicDims() == 0)
2493  return true;
2494 
2495  // All dynamic sizes must be the same. The only supported case at the
2496  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2497  // thereof).
2498 
2499  // Apart from CastOp, only ExtractSliceOp is supported.
2500  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2501  if (!beforeSlice)
2502  return false;
2503 
2504  assert(static_cast<size_t>(t1.getRank()) ==
2505  beforeSlice.getMixedSizes().size());
2506  assert(static_cast<size_t>(t2.getRank()) ==
2507  afterTrimming.getMixedSizes().size());
2508 
2509  for (unsigned i = 0; i < t1.getRank(); ++i) {
2510  // Skip static dimensions.
2511  if (!t1.isDynamicDim(i))
2512  continue;
2513  auto size1 = beforeSlice.getMixedSizes()[i];
2514  auto size2 = afterTrimming.getMixedSizes()[i];
2515 
2516  // Case 1: Same value or same constant int.
2517  if (isEqualConstantIntOrValue(size1, size2))
2518  continue;
2519 
2520  // Other cases: Take a deeper look at defining ops of values.
2521  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2522  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2523  if (!v1 || !v2)
2524  return false;
2525 
2526  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2527  // CSE is run.)
2528  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2529  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2530  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2531  minOp1.getOperands() == minOp2.getOperands())
2532  continue;
2533 
2534  // Add additional cases as needed.
2535  }
2536 
2537  // All tests passed.
2538  return true;
2539  }
2540 };
2541 
2542 /// Returns the effective Pad value for the input op, provided it's a scalar.
2543 ///
2544 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2545 /// this Op performs padding, retrieve the padding value provided that it's
2546 /// a scalar and static/fixed for all the padded values. Returns an empty value
2547 /// otherwise.
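///
/// For example (an illustrative sketch; the SSA names are assumptions), given:
///   %cst = arith.constant 0.0 : f32
///   %bcast = vector.broadcast %cst : f32 to vector<4x8xf32>
/// calling this on the defining op of %bcast returns %cst, whereas calling it
/// on a tensor.generate op returns an empty Value.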
2548 static Value getStaticPadVal(Operation *op) {
2549  if (!op)
2550  return {};
2551 
2552  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2553  // being broadcast, provided that it's a scalar.
2554  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2555  auto source = bcast.getSource();
2556  if (llvm::dyn_cast<VectorType>(source.getType()))
2557  return {};
2558 
2559  return source;
2560  }
2561 
2562  // 2. linalg.fill - use the scalar input value that used to fill the output
2563  // tensor.
2564  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2565  return fill.getInputs()[0];
2566  }
2567 
2568  // 3. tensor.generate - can't guarantee that the value is fixed without
2569  // analysing the op's body, bail out.
2570  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2571  return {};
2572  }
2573 
2574  // 4. vector.transfer_write - inspect the vector that is being written. If
2575  // it contains a single value that has been broadcast (e.g. via
2576  // vector.broadcast), extract it, fail otherwise.
2577  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2578  return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2579 
2580  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2581  // than the input tensor, then, provided it's constant, we'll extract the
2582  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2583  // TODO: Clarify the semantics when the input tensor is larger than the
2584  // destination.
2585  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2586  return getStaticPadVal(slice.getDest().getDefiningOp());
2587 
2588  return {};
2589 }
2590 
2591 /// Rewrite tensor.insert.slice as a vector.transfer_read +
2592 /// vector.transfer_write pair. The vector size is inferred from the static
2593 /// dims in the input and output tensors. If a dim is dynamic in both the input
2594 /// and output tensors, bails out.
2595 ///
2596 /// Before:
2597 /// !t_in_type = tensor<1x2x3xf32>
2598 /// !t_out_type = tensor<9x8x7x1x2x3xf32>
2599 /// !v_type = vector<1x2x3xf32>
2600 /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2601 /// into !t_out_type
2602 /// After:
2603 /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2604 /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2605 ///
2606 /// TODO: Support masking
2607 struct InsertSliceVectorizePattern
2608  : public OpRewritePattern<tensor::InsertSliceOp> {
2609  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2610 
2611  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
2612  PatternRewriter &rewriter) const final {
2613  auto sourceType = sliceOp.getSource().getType();
2614  if (!VectorType::isValidElementType(sourceType.getElementType()))
2615  return failure();
2616 
2617  auto resultType = sliceOp.getResultType();
2618 
2619  // 1. Get the pad value.
2620  // TransferReadOp requires a scalar padding value. Note that:
2621  // * for in-bounds access, the value is actually irrelevant.
2622  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2623  // 1. The source shape is static (output vector sizes would be based on
2624  // the source shape and hence all memory accesses would be in-bounds),
2625  // 2. Masking is used (output vector sizes would be user-provided, in which
2626  // case it is assumed that all memory accesses are in-bounds). This
2627  // remains a TODO.
2628  //
2629  // When the value is not known and not needed, use 0. Otherwise, bail out.
2630  Value padValue = getStaticPadVal(sliceOp);
2631  bool isOutOfBoundsRead = !sourceType.hasStaticShape();
2632 
2633  if (!padValue && isOutOfBoundsRead) {
2634  LDBG("Failed to get a pad value for out-of-bounds read access\n");
2635  return failure();
2636  }
2637 
2638  if (!padValue) {
2639  auto elemType = sourceType.getElementType();
2640  padValue = rewriter.create<arith::ConstantOp>(
2641  sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2642  }
2643 
2644  // 2. Get the vector shape and in-bounds attributes
2645  SmallVector<int64_t> vecShape;
2646  SmallVector<bool> readInBounds;
2647  SmallVector<bool> writeInBounds;
2648  size_t rankDiff = resultType.getRank() - sourceType.getRank();
2649  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2650  if (!sourceType.isDynamicDim(i)) {
2651  vecShape.push_back(sourceType.getDimSize(i));
2652  // Source shape is statically known: Neither read nor write are
2653  // out-of-bounds.
2654  readInBounds.push_back(true);
2655  writeInBounds.push_back(true);
2656  } else if (!resultType.isDynamicDim(i)) {
2657  // Source shape is not statically known, but result shape is.
2658  // Vectorize with size of result shape. This may be larger than the
2659  // source size.
2660  // FIXME: Using rankDiff implies that the source tensor is inserted at
2661  // the end of the destination tensor. However, that's not required.
2662  vecShape.push_back(resultType.getDimSize(rankDiff + i));
2663  // Read may be out-of-bounds because the result size could be larger
2664  // than the source size.
2665  readInBounds.push_back(false);
2666  // Write will be in-bounds provided that the corresponding write index is 0.
2667  // To keep this logic simple, conservatively mark as out-of-bounds.
2668  writeInBounds.push_back(false);
2669  } else {
2670  // Neither source nor result dim of padOp is static. Cannot vectorize
2671  // the copy.
2672  // TODO: Add support for masking
2673  return failure();
2674  }
2675  }
2676  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2677 
2678  // 3. Generate TransferReadOp.
2679  SmallVector<Value> readIndices(
2680  vecType.getRank(),
2681  rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2682  auto read = rewriter.create<vector::TransferReadOp>(
2683  sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
2684  ArrayRef<bool>{readInBounds});
2685 
2686  // 4. Generate TransferWriteOp.
2687  auto writeIndices = getValueOrCreateConstantIndexOp(
2688  rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2689 
2690  // 5. Finalize
2691  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2692  sliceOp, read, sliceOp.getDest(), writeIndices,
2693  ArrayRef<bool>{writeInBounds});
2694 
2695  return success();
2696  }
2697 };
2698 
2699 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2700 /// ```
2701 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2702 /// %r = tensor.insert_slice %0
2703 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2704 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2705 /// ```
2706 /// is rewritten to:
2707 /// ```
2708 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2709 /// : tensor<?x?xf32>, vector<17x5xf32>
2710 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2711 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2712 /// ```
2713 ///
2714 /// This rewrite is possible if:
2715 /// - Low padding is static 0.
2716 /// - `padOp` result shape is static.
2717 /// - The entire padded tensor is inserted.
2718 /// (Implies that sizes of `insertOp` are all static.)
2719 /// - Only unit strides in `insertOp`.
2720 /// - Single, scalar padding value.
2721 /// - `padOp` result not used as destination.
2722 struct PadOpVectorizationWithInsertSlicePattern
2723  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2724  using VectorizePadOpUserPattern<
2725  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2726 
2727  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2728  tensor::InsertSliceOp insertOp) const override {
2729  // Low padding must be static 0.
2730  if (!padOp.hasZeroLowPad())
2731  return failure();
2732  // Only unit stride supported.
2733  if (!insertOp.hasUnitStride())
2734  return failure();
2735  // Pad value must be a constant.
2736  auto padValue = padOp.getConstantPaddingValue();
2737  if (!padValue)
2738  return failure();
2739  // Dynamic shapes not supported.
2740  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2741  return failure();
2742  // Pad result not used as destination.
2743  if (insertOp.getDest() == padOp.getResult())
2744  return failure();
2745 
2746  auto vecType = VectorType::get(padOp.getType().getShape(),
2747  padOp.getType().getElementType());
2748  unsigned vecRank = vecType.getRank();
2749  unsigned tensorRank = insertOp.getType().getRank();
2750 
2751  // Check if sizes match: Insert the entire tensor into most minor dims.
2752  // (No permutations allowed.)
2753  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2754  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2755  if (!llvm::all_of(
2756  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2757  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2758  }))
2759  return failure();
2760 
2761  // Insert the TransferReadOp and TransferWriteOp at the position of the
2762  // InsertSliceOp.
2763  rewriter.setInsertionPoint(insertOp);
2764 
2765  // Generate TransferReadOp: Read entire source tensor and add high
2766  // padding.
2767  SmallVector<Value> readIndices(
2768  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2769  auto read = rewriter.create<vector::TransferReadOp>(
2770  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2771 
2772  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2773  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2774  // source must fit into the destination at the specified offsets.
2775  auto writeIndices = getValueOrCreateConstantIndexOp(
2776  rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2777  SmallVector<bool> inBounds(vecRank, true);
2778  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2779  insertOp, read, insertOp.getDest(), writeIndices,
2780  ArrayRef<bool>{inBounds});
2781 
2782  return success();
2783  }
2784 };
2785 
2786 void mlir::linalg::populateInsertSliceVectorizationPatterns(
2787  RewritePatternSet &patterns) {
2788  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
2789 }
2790 
2791 void mlir::linalg::populatePadOpVectorizationPatterns(
2792  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2793  patterns.add<PadOpVectorizationWithTransferReadPattern,
2794  PadOpVectorizationWithTransferWritePattern,
2795  PadOpVectorizationWithInsertSlicePattern>(
2796  patterns.getContext(), baseBenefit.getBenefit() + 1);
2797 }
2798 
2799 //----------------------------------------------------------------------------//
2800 // Forwarding patterns
2801 //----------------------------------------------------------------------------//
2802 
2803 /// Check whether there is any interleaved use of any `values` between
2804 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2805 /// is in a different block.
2806 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2807  ValueRange values) {
2808  if (firstOp->getBlock() != secondOp->getBlock() ||
2809  !firstOp->isBeforeInBlock(secondOp)) {
2810  LDBG("interleavedUses precondition failed, firstOp: "
2811  << *firstOp << ", second op: " << *secondOp << "\n");
2812  return true;
2813  }
2814  for (auto v : values) {
2815  for (auto &u : v.getUses()) {
2816  Operation *owner = u.getOwner();
2817  if (owner == firstOp || owner == secondOp)
2818  continue;
2819  // TODO: this is too conservative, use dominance info in the future.
2820  if (owner->getBlock() == firstOp->getBlock() &&
2821  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2822  continue;
2823  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2824  << ", second op: " << *secondOp << "\n");
2825  return true;
2826  }
2827  }
2828  return false;
2829 }
2830 
2831 /// Return the unique subview use of `v` if it is indeed unique, null
2832 /// otherwise.
2833 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2834  memref::SubViewOp subViewOp;
2835  for (auto &u : v.getUses()) {
2836  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2837  if (subViewOp)
2838  return memref::SubViewOp();
2839  subViewOp = newSubViewOp;
2840  }
2841  }
2842  return subViewOp;
2843 }
2844 
2845 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2846 /// when available.
2848  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2849 
2850  // TODO: support mask.
2851  if (xferOp.getMask())
2852  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2853 
2854  // Transfer into `view`.
2855  Value viewOrAlloc = xferOp.getSource();
2856  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2857  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2858  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2859 
2860  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2861  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2862  if (!subViewOp)
2863  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2864  Value subView = subViewOp.getResult();
2865 
2866  // Find the copy into `subView` without interleaved uses.
2867  memref::CopyOp copyOp;
2868  for (auto &u : subView.getUses()) {
2869  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2870  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2871  if (newCopyOp.getTarget() != subView)
2872  continue;
2873  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2874  continue;
2875  copyOp = newCopyOp;
2876  break;
2877  }
2878  }
2879  if (!copyOp)
2880  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2881 
2882  // Find the fill into `viewOrAlloc` without interleaved uses before the
2883  // copy.
2884  FillOp maybeFillOp;
2885  for (auto &u : viewOrAlloc.getUses()) {
2886  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2887  assert(isa<MemRefType>(newFillOp.output().getType()));
2888  if (newFillOp.output() != viewOrAlloc)
2889  continue;
2890  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2891  continue;
2892  maybeFillOp = newFillOp;
2893  break;
2894  }
2895  }
2896  // Ensure padding matches.
2897  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2898  return rewriter.notifyMatchFailure(xferOp,
2899  "padding value does not match fill");
2900 
2901  // `in` is the subview that memref.copy reads. Replace it.
2902  Value in = copyOp.getSource();
2903 
2904  // memref.copy + linalg.fill can be used to create a padded local buffer.
2905  // The `in_bounds` attribute is only valid on this padded buffer.
2906  // When forwarding to vector.transfer_read, the attribute must be reset
2907  // conservatively.
2908  auto vectorType = xferOp.getVectorType();
2909  Value res = rewriter.create<vector::TransferReadOp>(
2910  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2911  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2912  rewriter.getBoolArrayAttr(
2913  SmallVector<bool>(vectorType.getRank(), false)));
2914 
2915  if (maybeFillOp)
2916  rewriter.eraseOp(maybeFillOp);
2917  rewriter.eraseOp(copyOp);
2918  rewriter.replaceOp(xferOp, res);
2919 
2920  return success();
2921 }
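// Sketch of the rewrite above on illustrative IR (names chosen for exposition,
// not taken from this file or its tests):
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<...>)
//   %sv = memref.subview %alloc [...]
//   memref.copy %in, %sv
//   %v = vector.transfer_read %alloc[...], %pad ...
// forwards to a single read from the copy source, with in_bounds reset:
//   %v = vector.transfer_read %in[...], %pad ...
// and the now-dead fill (if any) and copy are erased.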
2922 
2923 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2924 /// when available.
2926  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2927  // TODO: support mask.
2928  if (xferOp.getMask())
2929  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2930 
2931  // Transfer into `viewOrAlloc`.
2932  Value viewOrAlloc = xferOp.getSource();
2933  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2934  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2935  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2936 
2937  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2938  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2939  if (!subViewOp)
2940  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2941  Value subView = subViewOp.getResult();
2942 
2943  // Find the copy from `subView` without interleaved uses.
2944  memref::CopyOp copyOp;
2945  for (auto &u : subViewOp.getResult().getUses()) {
2946  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2947  if (newCopyOp.getSource() != subView)
2948  continue;
2949  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2950  continue;
2951  copyOp = newCopyOp;
2952  break;
2953  }
2954  }
2955  if (!copyOp)
2956  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2957 
2958  // `out` is the subview that the copy writes into and that we replace.
2959  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2960  Value out = copyOp.getTarget();
2961 
2962  // Forward vector.transfer into copy.
2963  // memref.copy + linalg.fill can be used to create a padded local buffer.
2964  // The `in_bounds` attribute is only valid on this padded buffer.
2965  // When forwarding to vector.transfer_write, the attribute must be reset
2966  // conservatively.
2967  auto vector = xferOp.getVector();
2968  rewriter.create<vector::TransferWriteOp>(
2969  xferOp.getLoc(), vector, out, xferOp.getIndices(),
2970  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2971  rewriter.getBoolArrayAttr(
2972  SmallVector<bool>(vector.getType().getRank(), false)));
2973 
2974  rewriter.eraseOp(copyOp);
2975  rewriter.eraseOp(xferOp);
2976 
2977  return success();
2978 }
2979 
2980 //===----------------------------------------------------------------------===//
2981 // Convolution vectorization patterns
2982 //===----------------------------------------------------------------------===//
2983 
2984 template <int N>
2985 static void bindShapeDims(ShapedType shapedType) {}
2986 
2987 template <int N, typename IntTy, typename... IntTy2>
2988 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2989  val = shapedType.getShape()[N];
2990  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2991 }
2992 
2993 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2994 template <typename... IntTy>
2995 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2996  bindShapeDims<0>(shapedType, vals...);
2997 }
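// For example, with `resShapedType` of shape {n, w, f}, the call
// `bindShapeDims(resShapedType, nSize, wSize, fSize)` binds nSize, wSize and
// fSize to the first three shape dimensions, in that order.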
2998 
2999 namespace {
3000 bool isCastOfBlockArgument(Operation *op) {
3001  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
3002  isa<BlockArgument>(op->getOperand(0));
3003 }
3004 
3005 bool isSupportedPoolKind(vector::CombiningKind kind) {
3006  switch (kind) {
3007  case vector::CombiningKind::ADD:
3008  case vector::CombiningKind::MAXNUMF:
3009  case vector::CombiningKind::MAXIMUMF:
3010  case vector::CombiningKind::MAXSI:
3011  case vector::CombiningKind::MAXUI:
3012  case vector::CombiningKind::MINNUMF:
3013  case vector::CombiningKind::MINIMUMF:
3014  case vector::CombiningKind::MINSI:
3015  case vector::CombiningKind::MINUI:
3016  return true;
3017  default:
3018  return false;
3019  }
3020 }
3021 
3022 /// Generate a vector implementation for either:
3023 /// ```
3024 /// Op def: ( w, kw )
3025 /// Iters: ({Par(), Red()})
3026 /// Layout: {{w + kw}, {kw}, {w}}
3027 /// ```
3028 /// kw is unrolled.
3029 ///
3030 /// or
3031 ///
3032 /// ```
3033 /// Op def: ( n, w, c, kw, f )
3034 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3035 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3036 /// ```
3037 /// kw is unrolled, w is unrolled iff dilationW > 1.
3038 ///
3039 /// or
3040 ///
3041 /// ```
3042 /// Op def: ( n, c, w, f, kw )
3043 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3044 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3045 /// ```
3046 /// kw is unrolled, w is unrolled iff dilationW > 1.
3047 ///
3048 /// or
3049 ///
3050 /// ```
3051 /// Op def: ( n, w, c, kw )
3052 /// Iters: ({Par(), Par(), Par(), Red()})
3053 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3054 /// ```
3055 /// kw is unrolled, w is unrolled iff dilationW > 1.
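///
/// For example, the second case above corresponds to ops such as
/// `linalg.conv_1d_nwc_wcf`; a sketch with illustrative static sizes:
/// ```
/// %res = linalg.conv_1d_nwc_wcf
///          {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
///          ins(%input, %filter : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
///          outs(%init : tensor<4x5x8xf32>) -> tensor<4x5x8xf32>
/// ```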
3056 struct Conv1DGenerator
3057  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3058  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3059  int dilationW)
3060  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3061  strideW(strideW), dilationW(dilationW) {
3062  // Determine whether `linalgOp` can be generated with this generator
3063  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
3064  return;
3065  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3066  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3067  resShaped = linalgOp.getDpsInitOperand(0)->get();
3068  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3069  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3070  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3071  if (!lhsShapedType || !rhsShapedType || !resShapedType)
3072  return;
3073  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
3074  // (non-channeled convolution -> LHS and RHS both have single dimensions).
3075  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
3076  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
3077  return;
3078 
3079  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3080  if (!reduceOp)
3081  return;
3082  redOp = reduceOp->getName().getIdentifier();
3083 
3084  if (!setOperKind(reduceOp))
3085  return;
3086  auto maybeKind = getCombinerOpKind(reduceOp);
3087  // Typically convolution will have a `Add` CombiningKind but for i1 type it
3088  // can get strength reduced to `OR` which is also supported. This strength
3089  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
3090  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
3091  *maybeKind != vector::CombiningKind::OR) &&
3092  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
3093  return;
3094  }
3095  reductionKind = maybeKind.value();
3096 
3097  auto rhsRank = rhsShapedType.getRank();
3098  switch (oper) {
3099  case Conv:
3100  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3101  return;
3102  break;
3103  case Pool:
3104  if (rhsRank != 1)
3105  return;
3106  break;
3107  }
3108  // The op is now known to be valid.
3109  valid = true;
3110  }
3111 
3112  /// Generate a vector implementation for:
3113  /// ```
3114  /// Op def: ( w, kw )
3115  /// Iters: ({Par(), Red()})
3116  /// Layout: {{w + kw}, {kw}, {w}}
3117  /// ```
3118  /// kw is always unrolled.
3119  ///
3120  /// or
3121  ///
3122  /// ```
3123  /// Op def: ( n, w, c, kw, f )
3124  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3125  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3126  /// ```
3127  /// kw is always unrolled.
3128  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3129  /// > 1.
3130  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3131  if (!valid)
3132  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3133 
3134  int64_t nSize, wSize, cSize, kwSize, fSize;
3135  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3136  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3137  switch (conv1DOpOrder) {
3138  case Conv1DOpOrder::W:
3139  // Initialize unused dimensions
3140  nSize = fSize = cSize = 0;
3141  // out{W}
3142  bindShapeDims(resShapedType, wSize);
3143  // kernel{kw}
3144  bindShapeDims(rhsShapedType, kwSize);
3145  lhsShape = {// iw = ow + kw - 1
3146  // (i.e. 16 convolved with 3 -> 14)
3147  (wSize + kwSize - 1)};
3148  rhsShape = {kwSize};
3149  resShape = {wSize};
3150  break;
3151  case Conv1DOpOrder::Nwc:
3152  // out{n, w, f}
3153  bindShapeDims(resShapedType, nSize, wSize, fSize);
3154  switch (oper) {
3155  case Conv:
3156  // kernel{kw, c, f}
3157  bindShapeDims(rhsShapedType, kwSize, cSize);
3158  break;
3159  case Pool:
3160  // kernel{kw}
3161  bindShapeDims(rhsShapedType, kwSize);
3162  cSize = fSize;
3163  break;
3164  }
3165  lhsShape = {nSize,
3166  // iw = ow * sw + kw * dw - 1
3167  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3168  // Perform the proper inclusive -> exclusive -> inclusive.
3169  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3170  1,
3171  cSize};
3172  switch (oper) {
3173  case Conv:
3174  rhsShape = {kwSize, cSize, fSize};
3175  break;
3176  case Pool:
3177  rhsShape = {kwSize};
3178  break;
3179  }
3180  resShape = {nSize, wSize, fSize};
3181  break;
3182  case Conv1DOpOrder::Ncw:
3183  // out{n, f, w}
3184  bindShapeDims(resShapedType, nSize, fSize, wSize);
3185  switch (oper) {
3186  case Conv:
3187  // kernel{f, c, kw}
3188  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3189  break;
3190  case Pool:
3191  // kernel{kw}
3192  bindShapeDims(rhsShapedType, kwSize);
3193  cSize = fSize;
3194  break;
3195  }
3196  lhsShape = {nSize, cSize,
3197  // iw = ow * sw + kw * dw - 1
3198  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3199  // Perform the proper inclusive -> exclusive -> inclusive.
3200  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3201  1};
3202  switch (oper) {
3203  case Conv:
3204  rhsShape = {fSize, cSize, kwSize};
3205  break;
3206  case Pool:
3207  rhsShape = {kwSize};
3208  break;
3209  }
3210  resShape = {nSize, fSize, wSize};
3211  break;
3212  }
3213 
3214  vector::TransferWriteOp write;
3215  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3216 
3217  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3218  // When strideW == 1, we can batch the contiguous loads and avoid
3219  // unrolling
3220  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3221 
3222  Type lhsEltType = lhsShapedType.getElementType();
3223  Type rhsEltType = rhsShapedType.getElementType();
3224  Type resEltType = resShapedType.getElementType();
3225  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3226  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3227  auto resType = VectorType::get(resShape, resEltType);
3228  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3229  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3230  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3231  SmallVector<Value> resPadding(resShape.size(), zero);
3232 
3233  // Read the whole lhs, rhs and res in one shot (with zero padding).
3234  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3235  lhsPadding);
3236  // This is needed only for Conv.
3237  Value rhs = nullptr;
3238  if (oper == Conv)
3239  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3240  rhsPadding);
3241  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3242  resPadding);
3243 
3244  // The base vectorization case for channeled convolution is input:
3245  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3246  // vectorization case, we pre-transpose the input, weight, and output.
3247  switch (conv1DOpOrder) {
3248  case Conv1DOpOrder::W:
3249  case Conv1DOpOrder::Nwc:
3250  // Base case, so no transposes necessary.
3251  break;
3252  case Conv1DOpOrder::Ncw: {
3253  // To match base vectorization case, we pre-transpose current case.
3254  // ncw -> nwc
3255  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3256  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3257  // fcw -> wcf
3258  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3259 
3260  // This is needed only for Conv.
3261  if (oper == Conv)
3262  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3263  // nfw -> nwf
3264  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3265  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3266  break;
3267  }
3268  }
3269 
3270  //===------------------------------------------------------------------===//
3271  // Begin vector-only rewrite part
3272  //===------------------------------------------------------------------===//
3273  // Unroll along kw and read slices of lhs and rhs.
3274  SmallVector<Value> lhsVals, rhsVals, resVals;
3275  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3276  kwSize, strideW, dilationW, wSizeStep,
3277  isSingleChanneled);
3278  // Do not do for pooling.
3279  if (oper == Conv)
3280  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3281  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3282  wSizeStep, isSingleChanneled);
3283 
3284  auto linearIndex = [&](int64_t kw, int64_t w) {
3285  return kw * (wSize / wSizeStep) + w;
3286  };
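// E.g. when strideW == 1, wSizeStep == wSize, so there is a single w-slice
// per kw and linearIndex(kw, 0) == kw; when strideW > 1, wSizeStep == 1 and
// the slices are laid out kw-major as kw * wSize + w.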
3287 
3288  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3289  // or perform an outer product for non-channeled convolution, or a simple
3290  // arith operation for pooling.
3291  for (int64_t kw = 0; kw < kwSize; ++kw) {
3292  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3293  switch (oper) {
3294  case Conv:
3295  if (isSingleChanneled) {
3296  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3297  lhsVals[linearIndex(kw, w)],
3298  rhsVals[kw], resVals[w]);
3299  } else {
3300  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3301  lhsVals[linearIndex(kw, w)],
3302  rhsVals[kw], resVals[w]);
3303  }
3304  break;
3305  case Pool:
3306  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3307  resVals[w]);
3308  break;
3309  }
3310  }
3311  }
3312 
3313  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3314  isSingleChanneled);
3315  //===------------------------------------------------------------------===//
3316  // End vector-only rewrite part
3317  //===------------------------------------------------------------------===//
3318 
3319  // The base vectorization case for channeled convolution is output:
3320  // {n,w,f}. To reuse the result from the base pattern vectorization case, we
3321  // post-transpose the base case result.
3322  switch (conv1DOpOrder) {
3323  case Conv1DOpOrder::W:
3324  case Conv1DOpOrder::Nwc:
3325  // Base case, so no transposes necessary.
3326  break;
3327  case Conv1DOpOrder::Ncw: {
3328  // nwf -> nfw
3329  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3330  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3331  break;
3332  }
3333  }
3334 
3335  return rewriter
3336  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3337  .getOperation();
3338  }
3339 
3340  // Take a value and widen to have the same element type as `ty`.
3341  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3342  const Type srcElementType = getElementTypeOrSelf(val.getType());
3343  const Type dstElementType = getElementTypeOrSelf(ty);
3344  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3345  if (srcElementType == dstElementType)
3346  return val;
3347 
3348  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3349  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3350  const Type dstType =
3351  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3352 
3353  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3354  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3355  }
3356 
3357  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3358  srcWidth < dstWidth)
3359  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3360 
3361  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3362  srcWidth < dstWidth)
3363  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3364 
3365  assert(false && "unhandled promotion case");
3366  return nullptr;
3367  }
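// E.g. promoting a vector<4xi8> value to match an f32-typed accumulator emits
// arith.sitofp to vector<4xf32>, while an i8 -> i32 promotion emits
// arith.extsi instead.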
3368 
3369  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3370  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3371  Value lhs, Value rhs, Value res) {
3372  vector::IteratorType par = vector::IteratorType::parallel;
3373  vector::IteratorType red = vector::IteratorType::reduction;
3374  AffineExpr n, w, f, c;
3375  bindDims(ctx, n, w, f, c);
3376  lhs = promote(rewriter, loc, lhs, res.getType());
3377  rhs = promote(rewriter, loc, rhs, res.getType());
3378  auto contractionOp = rewriter.create<vector::ContractionOp>(
3379  loc, lhs, rhs, res,
3380  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3381  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3382  contractionOp.setKind(reductionKind);
3383  return contractionOp;
3384  }
3385 
3386  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3387  // convolution.
3388  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3389  Value lhs, Value rhs, Value res) {
3390  return rewriter.create<vector::OuterProductOp>(
3391  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3392  }
3393 
3394  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3395  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3396  Value res) {
3397  if (isPoolExt)
3398  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3399  return rewriter
3400  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3401  ->getResult(0);
3402  }
3403 
3404  /// Generate a vector implementation for:
3405  /// ```
3406  /// Op def: ( n, w, c, kw)
3407  /// Iters: ({Par(), Par(), Par(), Red()})
3408  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3409  /// ```
3410  /// kw is always unrolled.
3411  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3412  /// > 1.
3413  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3414  bool channelDimScalableFlag,
3415  bool flatten) {
3416  if (!valid)
3417  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3418 
3419  bool scalableChDim = false;
3420  bool useMasking = false;
3421  int64_t nSize, wSize, cSize, kwSize;
3422  // kernel{kw, c}
3423  bindShapeDims(rhsShapedType, kwSize, cSize);
3424  if (ShapedType::isDynamic(cSize)) {
3425  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3426  cSize = channelDimVecSize;
3427  // Scalable vectors are only used when both conditions are met:
3428  // 1. channel dim is dynamic
3429  // 2. channelDimScalableFlag is set
3430  scalableChDim = channelDimScalableFlag;
3431  useMasking = true;
3432  }
3433 
3434  assert(!(useMasking && flatten) &&
3435  "Unsupported flattened conv with dynamic shapes");
3436 
3437  // out{n, w, c}
3438  bindShapeDims(resShapedType, nSize, wSize);
3439 
3440  vector::TransferWriteOp write;
3441  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3442 
3443  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3444  // When strideW == 1, we can batch the contiguous loads and avoid
3445  // unrolling
3446  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3447 
3448  Type lhsEltType = lhsShapedType.getElementType();
3449  Type rhsEltType = rhsShapedType.getElementType();
3450  Type resEltType = resShapedType.getElementType();
3451  VectorType lhsType = VectorType::get(
3452  {nSize,
3453  // iw = ow * sw + kw * dw - 1
3454  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3455  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3456  cSize},
3457  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3458  VectorType rhsType =
3459  VectorType::get({kwSize, cSize}, rhsEltType,
3460  /*scalableDims=*/{false, scalableChDim});
3461  VectorType resType =
3462  VectorType::get({nSize, wSize, cSize}, resEltType,
3463  /*scalableDims=*/{false, false, scalableChDim});
3464 
3465  // Masks the input xfer Op along the channel dim, iff that dim is dynamic
3466  // (the scalable flag only controls whether the mask's channel dim is scalable).
3467  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3468  ArrayRef<bool> scalableDims,
3469  Operation *opToMask) {
3470  if (!useMasking)
3471  return opToMask;
3472  auto maskType =
3473  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3474 
3475  SmallVector<bool> inBounds(maskShape.size(), true);
3476  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3477  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3478  rewriter.getBoolArrayAttr(inBounds));
3479 
3481  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3482 
3483  Value maskOp =
3484  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3485 
3486  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3487  };
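// E.g. for an lhs read of type vector<1x6x[4]xf32> (dynamic, scalable channel
// dim), this builds a vector<1x6x[4]xi1> mask with vector.create_mask and
// wraps the transfer in vector.mask; with a static channel dim no mask is
// created.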
3488 
3489  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3490  // 0].
3491  Value lhs = rewriter.create<vector::TransferReadOp>(
3492  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3493  auto maybeMaskedLhs = maybeMaskXferOp(
3494  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3495 
3496  // Read rhs slice of size {kw, c} @ [0, 0].
3497  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3498  ValueRange{zero, zero});
3499  auto maybeMaskedRhs = maybeMaskXferOp(
3500  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3501 
3502  // Read res slice of size {n, w, c} @ [0, 0, 0].
3503  Value res = rewriter.create<vector::TransferReadOp>(
3504  loc, resType, resShaped, ValueRange{zero, zero, zero});
3505  auto maybeMaskedRes = maybeMaskXferOp(
3506  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3507 
3508  //===------------------------------------------------------------------===//
3509  // Begin vector-only rewrite part
3510  //===------------------------------------------------------------------===//
3511  // Unroll along kw and read slices of lhs and rhs.
3512  SmallVector<Value> lhsVals, rhsVals, resVals;
3513  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3514  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3515 
3516  // Extract lhs slice of size {n, wSizeStep, c}
3517  // @ [0, sw * w + dw * kw, 0].
3518  for (int64_t kw = 0; kw < kwSize; ++kw) {
3519  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3520  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3521  loc, maybeMaskedLhs->getResult(0),
3522  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3523  inOutSliceSizes, inOutStrides));
3524  }
3525  }
3526  // Extract rhs slice of size {c} @ [kw].
3527  for (int64_t kw = 0; kw < kwSize; ++kw) {
3528  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3529  loc, maybeMaskedRhs->getResult(0),
3530  /*offsets=*/ArrayRef<int64_t>{kw}));
3531  }
3532  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3533  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3534  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3535  loc, maybeMaskedRes->getResult(0),
3536  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3537  inOutStrides));
3538  }
3539 
3540  auto linearIndex = [&](int64_t kw, int64_t w) {
3541  return kw * (wSize / wSizeStep) + w;
3542  };
3543 
3544  // Note - the scalable flags are ignored as flattening combined with
3545  // scalable vectorization is not supported.
3546  auto inOutFlattenSliceSizes =
3547  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3548  auto lhsTypeAfterFlattening =
3549  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3550  auto resTypeAfterFlattening =
3551  VectorType::get(inOutFlattenSliceSizes, resEltType);
3552 
3553  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3554  for (int64_t kw = 0; kw < kwSize; ++kw) {
3555  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3556  Value lhsVal = lhsVals[linearIndex(kw, w)];
3557  Value resVal = resVals[w];
3558  if (flatten) {
3559  // Flatten the input and output vectors (collapse the channel
3560  // dimension)
3561  lhsVal = rewriter.create<vector::ShapeCastOp>(
3562  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3563  resVal = rewriter.create<vector::ShapeCastOp>(
3564  loc, resTypeAfterFlattening, resVals[w]);
3565  }
3566  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3567  rhsVals[kw], resVal, flatten);
3568  if (flatten) {
3569  // Un-flatten the output vector (restore the channel dimension)
3570  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3571  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3572  }
3573  }
3574  }
3575 
3576  // It's possible we failed to create the FMA.
3577  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3578  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3579  for (auto &collection :
3580  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3581  for (Value v : collection)
3582  rewriter.eraseOp(v.getDefiningOp());
3583  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3584  }
3585 
3586  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3587  // This does not depend on kw.
3588  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3589  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3590  loc, resVals[w], maybeMaskedRes->getResult(0),
3591  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3592  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3593  }
3594  //===------------------------------------------------------------------===//
3595  // End vector-only rewrite part
3596  //===------------------------------------------------------------------===//
3597 
3598  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3599  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3600  loc, maybeMaskedRes->getResult(0), resShaped,
3601  ValueRange{zero, zero, zero});
3602  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3603  resOut);
3604  }
3605 
3606  /// Lower:
3607  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3608  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3609  /// to MulAcc.
3610  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3611  Value lhs, Value rhs, Value res,
3612  bool flatten) {
3613  auto rhsTy = cast<ShapedType>(rhs.getType());
3614  auto resTy = cast<ShapedType>(res.getType());
3615 
3616  // TODO(suderman): Change this to use a vector.ima intrinsic.
3617  lhs = promote(rewriter, loc, lhs, resTy);
3618 
3619  if (flatten) {
3620  // NOTE: This following logic won't work for scalable vectors. For this
3621  // reason, "flattening" is not supported when shapes are dynamic (this
3622  // should be captured by one of the pre-conditions).
3623 
3624  // There are two options for handling the filter:
3625  // * shape_cast(broadcast(filter))
3626  // * broadcast(shuffle(filter))
3627  // Opt for the option without shape_cast to simplify the codegen.
3628  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3629  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3630 
3631  SmallVector<int64_t, 16> indices;
3632  for (int i = 0; i < resSize / rhsSize; ++i) {
3633  for (int j = 0; j < rhsSize; ++j)
3634  indices.push_back(j);
3635  }
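// E.g. with rhsSize == 3 and resSize == 6 this yields [0, 1, 2, 0, 1, 2],
// i.e. the filter is repeated once per output position folded into the
// flattened w * c dimension.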
3636 
3637  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3638  }
3639  // Broadcast the filter to match the output vector
3640  rhs = rewriter.create<vector::BroadcastOp>(
3641  loc, resTy.clone(rhsTy.getElementType()), rhs);
3642 
3643  rhs = promote(rewriter, loc, rhs, resTy);
3644 
3645  if (!lhs || !rhs)
3646  return nullptr;
3647 
3648  if (isa<FloatType>(resTy.getElementType()))
3649  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3650 
3651  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3652  return rewriter.create<arith::AddIOp>(loc, mul, res);
3653  }
3654 
3655  /// Entry point for non-channeled convolution:
3656  /// {{w + kw}, {kw}, {w}}
3657  FailureOr<Operation *> generateNonChanneledConv() {
3658  AffineExpr w, kw;
3659  bindDims(ctx, w, kw);
3660  if (!iters({Par(), Red()}))
3661  return rewriter.notifyMatchFailure(op,
3662  "failed to match conv::W 1-par 1-red");
3663 
3664  // No transposition needed.
3665  if (layout({/*lhsIndex*/ {w + kw},
3666  /*rhsIndex*/ {kw},
3667  /*resIndex*/ {w}}))
3668  return conv(Conv1DOpOrder::W);
3669 
3670  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3671  }
3672 
3673  /// Entry point that transposes into the common form:
3674  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3675  FailureOr<Operation *> generateNwcConv() {
3676  AffineExpr n, w, f, kw, c;
3677  bindDims(ctx, n, w, f, kw, c);
3678  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3679  return rewriter.notifyMatchFailure(
3680  op, "failed to match conv::Nwc 3-par 2-red");
3681 
3682  // No transposition needed.
3683  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3684  /*rhsIndex*/ {kw, c, f},
3685  /*resIndex*/ {n, w, f}}))
3686  return conv(Conv1DOpOrder::Nwc);
3687 
3688  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3689  }
3690 
3691  /// Entry point that transposes into the common form:
3692  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3693  FailureOr<Operation *> generateNcwConv() {
3694  AffineExpr n, w, f, kw, c;
3695  bindDims(ctx, n, f, w, c, kw);
3696  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3697  return rewriter.notifyMatchFailure(
3698  op, "failed to match conv::Ncw 3-par 2-red");
3699 
3700  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3701  /*rhsIndex*/ {f, c, kw},
3702  /*resIndex*/ {n, f, w}}))
3703  return conv(Conv1DOpOrder::Ncw);
3704 
3705  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3706  }
3707 
3708  /// Entry point that transposes into the common form:
3709  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3710  FailureOr<Operation *> generateNwcPooling() {
3711  AffineExpr n, w, c, kw;
3712  bindDims(ctx, n, w, c, kw);
3713  if (!iters({Par(), Par(), Par(), Red()}))
3714  return rewriter.notifyMatchFailure(op,
3715  "failed to match pooling 3-par 1-red");
3716 
3717  // No transposition needed.
3718  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3719  /*rhsIndex*/ {kw},
3720  /*resIndex*/ {n, w, c}}))
3721  return conv(Conv1DOpOrder::Nwc);
3722 
3723  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3724  }
3725 
3726  /// Entry point that transposes into the common form:
3727  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3728  FailureOr<Operation *> generateNcwPooling() {
3729  AffineExpr n, w, c, kw;
3730  bindDims(ctx, n, c, w, kw);
3731  if (!iters({Par(), Par(), Par(), Red()}))
3732  return rewriter.notifyMatchFailure(op,
3733  "failed to match pooling 3-par 1-red");
3734 
3735  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3736  /*rhsIndex*/ {kw},
3737  /*resIndex*/ {n, c, w}}))
3738  return conv(Conv1DOpOrder::Ncw);
3739 
3740  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3741  }
3742 
3743  /// Entry point that transposes into the common form:
3744  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3745  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3746  bool vecChDimScalableFlag = false,
3747  bool flatten = false) {
3748  AffineExpr n, w, c, kw;
3749  bindDims(ctx, n, w, c, kw);
3750  if (!iters({Par(), Par(), Par(), Red()}))
3751  return rewriter.notifyMatchFailure(
3752  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3753 
3754  // No transposition needed.
3755  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3756  /*rhsIndex*/ {kw, c},
3757  /*resIndex*/ {n, w, c}}))
3758  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3759 
3760  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3761  }
3762 
3763 private:
3764  enum OperKind { Conv, Pool };
3765  bool valid = false;
3766  OperKind oper = Conv;
3767  StringAttr redOp;
3768  StringAttr poolExtOp;
3769  bool isPoolExt = false;
3770  int strideW, dilationW;
3771  Value lhsShaped, rhsShaped, resShaped;
3772  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3773  vector::CombiningKind reductionKind;
3774 
3775  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3776  // Returns true iff it is a valid conv/pooling op.
3777  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
3778  // + yield) and rhs is not used) then it is the body of a pooling op.
3779  // If conv, check for single `mul` predecessor. The `mul` operands must be
3780  // block arguments or extension of block arguments.
3781  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
3782  // must be block arguments or extension of block arguments.
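// For example, a max-pooling body reduces two block arguments directly
// (e.g. `arith.maximumf %in, %out` followed by `linalg.yield`), which is the
// 2-block-argument case below, whereas a convolution body feeds the reduction
// through a `mul` of block arguments and is handled by the 1-argument case.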
3783  bool setOperKind(Operation *reduceOp) {
3784  int numBlockArguments =
3785  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3786  switch (numBlockArguments) {
3787  case 1: {
3788  // Will be convolution if feeder is a MulOp.
3789  // A strength reduced version of MulOp for i1 type is AndOp which is also
3790  // supported. Otherwise, it can be pooling. This strength reduction logic
3791  // is in `buildBinaryFn` helper in the Linalg dialect.
3792  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3793  llvm::IsaPred<BlockArgument>);
3794  Operation *feedOp = (*feedValIt).getDefiningOp();
3795  if (isCastOfBlockArgument(feedOp)) {
3796  oper = Pool;
3797  isPoolExt = true;
3798  poolExtOp = feedOp->getName().getIdentifier();
3799  } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3800  (isa<arith::AndIOp>(feedOp) &&
3801  feedOp->getResultTypes()[0].isInteger(1))) &&
3802  llvm::all_of(feedOp->getOperands(), [](Value v) {
3803  if (isa<BlockArgument>(v))
3804  return true;
3805  if (Operation *op = v.getDefiningOp())
3806  return isCastOfBlockArgument(op);
3807  return false;
3808  }))) {
3809  return false;
3810  }
3811  return true;
3812  }
3813  case 2:
3814  // Must be pooling
3815  oper = Pool;
3816  isPoolExt = false;
3817  return true;
3818  default:
3819  return false;
3820  }
3821  }
3822 };
3823 } // namespace
3824 
3825 /// Helper function to vectorize a LinalgOp with convolution semantics.
3826 // TODO: extend the generic vectorization to support windows and drop this.
3827 static FailureOr<Operation *> vectorizeConvolution(
3828  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3829  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3830  // The ConvolutionOpInterface gives us guarantees of existence for
3831  // strides/dilations. However, we do not need to rely on those; we simply
3832  // use them if present, and otherwise use the defaults and let the generic
3833  // conv matcher in the Conv1DGenerator succeed or fail.
3834  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3835  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3836  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3837  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3838  Conv1DGenerator e(rewriter, op, stride, dilation);
3839  auto res = e.generateNonChanneledConv();
3840  if (succeeded(res))
3841  return res;
3842  res = e.generateNwcConv();
3843  if (succeeded(res))
3844  return res;
3845  res = e.generateNcwConv();
3846  if (succeeded(res))
3847  return res;
3848  res = e.generateNwcPooling();
3849  if (succeeded(res))
3850  return res;
3851  res = e.generateNcwPooling();
3852  if (succeeded(res))
3853  return res;
3854 
3855  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3856  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3857  // masked/scalable) is the channel dim (i.e. the trailing dim).
3858  uint64_t vecChDimSize = ShapedType::kDynamic;
3859  bool vecChDimScalableFlag = false;
3860  if (!inputVecSizes.empty()) {
3861  // Only use the input vector size corresponding to the channel dim. Other
3862  // vector dims will be inferred from the Ops.
3863  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3864  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3865  "Not a 1D depthwise conv!");
3866  size_t chDimIdx =
3867  TypeSwitch<Operation *, size_t>(op)
3868  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3869  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3870 
3871  vecChDimSize = inputVecSizes[chDimIdx];
3872  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3873  }
3874  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3875  flatten1DDepthwiseConv);
3876 }
3877 
3878 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3879  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3880 
3881  LogicalResult matchAndRewrite(LinalgOp op,
3882  PatternRewriter &rewriter) const override {
3883  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3884  if (failed(resultOrFail))
3885  return failure();
3886  Operation *newOp = *resultOrFail;
3887  if (newOp->getNumResults() == 0) {
3888  rewriter.eraseOp(op.getOperation());
3889  return success();
3890  }
3891  assert(newOp->getNumResults() == 1 && "expected single result");
3892  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3893  return success();
3894  }
3895 };
3896 
3897 void mlir::linalg::populateConvolutionVectorizationPatterns(
3898  RewritePatternSet &patterns, PatternBenefit benefit) {
3899  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3900 }
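// Typical client usage of the populate* entry points above (a sketch; the
// greedy driver invocation is an assumption about common usage, not code from
// this file):
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));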
static std::optional< VectorShape > vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
Definition: AffineMap.cpp:606
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:454
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
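A short sketch of how InsertionGuard, setInsertionPoint and clone are typically combined with an IRMapping; the helper name and the operand-remapping use case are invented for illustration, not taken from this file.
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
// Clone `op` immediately before `anchor`, substituting one operand.
mlir::Operation *cloneBeforeWithRemap(mlir::OpBuilder &builder,
                                      mlir::Operation *op,
                                      mlir::Operation *anchor,
                                      mlir::Value oldOperand,
                                      mlir::Value newOperand) {
  // The guard restores the caller's insertion point when it goes out of scope.
  mlir::OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(anchor->getBlock(), anchor->getIterator());
  // Any use of `oldOperand` inside `op` is remapped to `newOperand` while cloning.
  mlir::IRMapping mapping;
  mapping.map(oldOperand, newOperand);
  return builder.clone(*op, mapping);
}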
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator range over the underlying Values.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the ...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
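A hedged sketch of the usual rewriter idiom built from the hooks above: report a match failure, or replace the matched op with a newly created one. The arith.addi replacement, the two-operand check, and the assumption of integer-typed operands are placeholders for illustration, not anything this file does.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
mlir::LogicalResult replaceWithAdd(mlir::Operation *op,
                                   mlir::PatternRewriter &rewriter) {
  if (op->getNumOperands() != 2)
    return rewriter.notifyMatchFailure(op, "expected exactly two operands");
  // Replace `op` with an arith.addi over its (assumed integer) operands.
  rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, op->getOperand(0),
                                                   op->getOperand(1));
  return mlir::success();
}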
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
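A small sketch of iterating a Value's use list with the accessors above; the helper name and the "user outside a block" query are invented for illustration.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
bool hasUserOutsideBlock(mlir::Value value, mlir::Block *block) {
  // Each use is an OpOperand; its owner is the operation holding the operand.
  for (mlir::OpOperand &use : value.getUses())
    if (use.getOwner()->getBlock() != block)
      return true;
  return false;
}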
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:216
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
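A sketch of the usual call sequence for the two entry points above: check vectorizeOpPrecondition, then call vectorize with explicit vector sizes. The helper name, the sizes, and the scalable flags are placeholders.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/ADT/SmallVector.h"
mlir::LogicalResult tryVectorize(mlir::RewriterBase &rewriter,
                                 mlir::Operation *op) {
  llvm::SmallVector<int64_t> vectorSizes = {8, 32};
  llvm::SmallVector<bool> scalableDims = {false, false};
  if (mlir::failed(mlir::linalg::vectorizeOpPrecondition(op, vectorSizes,
                                                         scalableDims)))
    return mlir::failure();
  return mlir::linalg::vectorize(rewriter, op, vectorSizes, scalableDims,
                                 /*vectorizeNDExtract=*/false,
                                 /*flatten1DDepthwiseConv=*/false);
}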
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorization patterns for tensor.insert_slice.
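A sketch of wiring the populate* helpers above into a greedy rewrite. The root op is an assumption (e.g. a function op), and the driver entry point is spelled applyPatternsAndFoldGreedily here; newer trees rename it applyPatternsGreedily.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
mlir::LogicalResult applyVectorizationPatterns(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populatePadOpVectorizationPatterns(patterns);
  mlir::linalg::populateConvolutionVectorizationPatterns(patterns);
  mlir::linalg::populateInsertSliceVectorizationPatterns(patterns);
  // Apply the collected patterns until a fixed point is reached.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}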
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of a PackOp. This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2443
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:496
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:815
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
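A sketch combining bindDims, AffineMap::get and applyPermutationMap; the map shape and the static sizes are invented for the example.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"
void projectedPermutationExample(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1, d2;
  mlir::bindDims(ctx, d0, d1, d2);
  // (d0, d1, d2) -> (d2, d0): a projected permutation built from bound exprs.
  mlir::AffineMap map =
      mlir::AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d2, d0}, ctx);
  // {4, 8, 16} projected through (d2, d0) becomes {16, 4}.
  llvm::SmallVector<int64_t> projected =
      mlir::applyPermutationMap(map, llvm::ArrayRef<int64_t>{4, 8, 16});
  (void)projected;
}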
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction ...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
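A one-line sketch of the matcher entry points above; the helper name is invented.
#include "mlir/IR/Matchers.h"
// Return true if `value` is produced by a constant-foldable operation.
bool isDefinedByConstant(mlir::Value value) {
  return mlir::matchPattern(value, mlir::m_Constant());
}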
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
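A sketch of the two permutation-vector helpers above, assuming they come from the dialect indexing utilities header; the values are purely illustrative.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
void permutationVectorExample() {
  llvm::SmallVector<int64_t> shape = {4, 8, 16};
  llvm::SmallVector<int64_t> perm = {2, 0, 1};
  // Permute in place: shape becomes {16, 4, 8}.
  mlir::applyPermutationToVector(shape, perm);
  // The inverse of {2, 0, 1} is {1, 2, 0}.
  llvm::SmallVector<int64_t> inverse = mlir::invertPermutationVector(perm);
  (void)inverse;
}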
Rewrite tensor.insert_slice as a vector.transfer_read + vector.transfer_write pair.
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp, PatternRewriter &rewriter) const final
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.