1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58  ArrayRef<int64_t> inputVecSizes = {},
59  ArrayRef<bool> inputVecScalableFlags = {},
60  bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if none exist or if more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66  OpType res;
67  block.walk([&](OpType op) {
68  if (res) {
69  res = nullptr;
70  return WalkResult::interrupt();
71  }
72  res = op;
73  return WalkResult::advance();
74  });
75  return res;
76 }
77 
78 /// Helper function to extract the input slices after filter is unrolled along
79 /// kw.
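/// For example (an illustrative sketch, not verbatim compiler output): for a
/// non-channeled convolution with wSize = 8, wSizeStep = 4 and kwSize = 2,
/// this extracts four slices of size 4 from the input vector with
/// vector.extract_strided_slice, at offsets [0], [4] (for kw = 0) and
/// [1], [5] (for kw = 1).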
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85  SmallVector<Value> result;
86  if (isSingleChanneled) {
87  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88  // convolution.
89  SmallVector<int64_t> sizes{wSizeStep};
90  SmallVector<int64_t> strides{1};
91  for (int64_t kw = 0; kw < kwSize; ++kw) {
92  for (int64_t w = 0; w < wSize; w += wSizeStep) {
93  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95  }
96  }
97  } else {
98  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99  // for channeled convolution.
100  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101  SmallVector<int64_t> strides{1, 1, 1};
102  for (int64_t kw = 0; kw < kwSize; ++kw) {
103  for (int64_t w = 0; w < wSize; w += wSizeStep) {
104  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105  loc, input,
106  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107  sizes, strides));
108  }
109  }
110  }
111  return result;
112 }
113 
114 /// Helper function to extract the filter slices after filter is unrolled along
115 /// kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119  SmallVector<Value> result;
120  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  result.push_back(rewriter.create<vector::ExtractOp>(
124  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125  }
126  return result;
127 }
128 
129 /// Helper function to extract the result slices after filter is unrolled along
130 /// kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134  int64_t wSizeStep, bool isSingleChanneled) {
135  SmallVector<Value> result;
136  if (isSingleChanneled) {
137  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138  SmallVector<int64_t> sizes{wSizeStep};
139  SmallVector<int64_t> strides{1};
140  for (int64_t w = 0; w < wSize; w += wSizeStep) {
141  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143  }
144  } else {
145  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146  // convolution.
147  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148  SmallVector<int64_t> strides{1, 1, 1};
149  for (int64_t w = 0; w < wSize; w += wSizeStep) {
150  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152  }
153  }
154  return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160  SmallVectorImpl<Value> &resVals,
161  bool isSingleChanneled) {
162 
163  if (isSingleChanneled) {
164  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165  // This does not depend on kw.
166  SmallVector<int64_t> strides{1};
167  for (int64_t w = 0; w < wSize; w += wSizeStep) {
168  res = rewriter.create<vector::InsertStridedSliceOp>(
169  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170  }
171  } else {
172  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173  // convolution. This does not depend on kw.
174  SmallVector<int64_t> strides{1, 1, 1};
175  for (int64_t w = 0; w < wSize; w += wSizeStep) {
176  res = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178  strides);
179  }
180  }
181  return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189  /// Initializes the vectorization state, including the computation of the
190  /// canonical vector shape for vectorization.
191  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192  ArrayRef<int64_t> inputVectorSizes,
193  ArrayRef<bool> inputScalableVecDims);
194 
195  /// Returns the canonical vector shape used to vectorize the iteration space.
196  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198  /// Returns the vector dimensions that are scalable in the canonical vector
199  /// shape.
200  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202  /// Returns a vector type of the provided `elementType` with the canonical
203  /// vector shape and the corresponding fixed/scalable dimensions bit. If
204  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205  /// accordingly.
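  /// For example (an illustrative sketch): with a canonical vector shape of
  /// 4x8 and `dimPermutation = affine_map<(d0, d1) -> (d1, d0)>`, requesting
  /// an f32 element type yields `vector<8x4xf32>`.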
206   VectorType getCanonicalVecType(
207       Type elementType,
208       std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211  if (dimPermutation.has_value()) {
212  vectorShape =
213  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214  scalableDims =
215  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216  } else {
217  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219  }
220 
221  return VectorType::get(vectorShape, elementType, scalableDims);
222  }
223 
224  /// Masks an operation with the canonical vector mask if the operation needs
225  /// masking. Returns the masked operation or the original operation if masking
226  /// is not needed. If provided, the canonical mask for this operation is
227  /// permuted using `maybeMaskingMap`.
228  Operation *
229  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
231 
232 private:
233  /// Initializes the iteration space static sizes using the Linalg op
234  /// information. This may become more complicated in the future.
235  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237  }
238 
239  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240  /// all the static and dynamic dimensions of the iteration space to be
241   /// vectorized and stores them in `iterSpaceValueSizes`.
242  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243  LinalgOp linalgOp);
244 
245  /// Create or retrieve an existing mask value to mask `opToMask` in the
246   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
247  /// permuted using that permutation map. If a new mask is created, it will be
248  /// cached for future users.
249  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250  LinalgOp linalgOp,
251  std::optional<AffineMap> maybeMaskingMap);
252 
253  // Holds the compile-time static sizes of the iteration space to vectorize.
254  // Dynamic dimensions are represented using ShapedType::kDynamic.
255  SmallVector<int64_t> iterSpaceStaticSizes;
256 
257  /// Holds the value sizes of the iteration space to vectorize. Static
258  /// dimensions are represented by 'arith.constant' and dynamic
259  /// dimensions by 'tensor/memref.dim'.
260  SmallVector<Value> iterSpaceValueSizes;
261 
262  /// Holds the canonical vector shape used to vectorize the iteration space.
263  SmallVector<int64_t> canonicalVecShape;
264 
265  /// Holds the vector dimensions that are scalable in the canonical vector
266  /// shape.
267  SmallVector<bool> scalableVecDims;
268 
269  /// Holds the active masks for permutations of the canonical vector iteration
270  /// space.
271  DenseMap<AffineMap, Value> activeMaskCache;
272 
273  /// Global vectorization guard for the incoming rewriter. It's initialized
274  /// when the vectorization state is initialized.
275   OpBuilder::InsertionGuard rewriterGuard;
276 };
277 
278 LogicalResult
279 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
280  LinalgOp linalgOp) {
281  // TODO: Support 0-d vectors.
282  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
283  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
284  // Create constant index op for static dimensions.
285  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
286  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
287  continue;
288  }
289 
290  // Find an operand defined on this dimension of the iteration space to
291  // extract the runtime dimension size.
292  Value operand;
293  unsigned operandDimPos;
294  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
295  operandDimPos)))
296  return failure();
297 
298  Value dynamicDim = linalgOp.hasPureTensorSemantics()
299  ? (Value)rewriter.create<tensor::DimOp>(
300  linalgOp.getLoc(), operand, operandDimPos)
301  : (Value)rewriter.create<memref::DimOp>(
302  linalgOp.getLoc(), operand, operandDimPos);
303  iterSpaceValueSizes.push_back(dynamicDim);
304  }
305 
306  return success();
307 }
308 
309 /// Initializes the vectorization state, including the computation of the
310 /// canonical vector shape for vectorization.
311 // TODO: Move this to the constructor when we can remove the failure cases.
312 LogicalResult
313 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
314  ArrayRef<int64_t> inputVectorSizes,
315  ArrayRef<bool> inputScalableVecDims) {
316  // Initialize the insertion point.
317  rewriter.setInsertionPoint(linalgOp);
318 
319  if (!inputVectorSizes.empty()) {
320  // Get the canonical vector shape from the input vector sizes provided. This
321  // path should be taken to vectorize code with dynamic shapes and when using
322  // vector sizes greater than the iteration space sizes.
323  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
324  scalableVecDims.append(inputScalableVecDims.begin(),
325  inputScalableVecDims.end());
326  } else {
327  // Compute the canonical vector shape from the operation shape. If there are
328  // dynamic shapes, the operation won't be vectorized. We assume all the
329  // vector dimensions are fixed.
330  canonicalVecShape = linalgOp.getStaticLoopRanges();
331  scalableVecDims.append(linalgOp.getNumLoops(), false);
332  }
333 
334  LDBG("Canonical vector shape: ");
335  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
336  LLVM_DEBUG(llvm::dbgs() << "\n");
337  LDBG("Scalable vector dims: ");
338  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
339  LLVM_DEBUG(llvm::dbgs() << "\n");
340 
341  if (ShapedType::isDynamicShape(canonicalVecShape))
342  return failure();
343 
344  // Initialize iteration space static sizes.
345  initIterSpaceStaticSizes(linalgOp);
346 
347  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
348  // all the static and dynamic dimensions of the iteration space, needed to
349  // compute a mask during vectorization.
350  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
351  return failure();
352 
353  return success();
354 }
355 
356 /// Create or retrieve an existing mask value to mask `opToMask` in the
357 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
358 /// using that permutation map. If a new mask is created, it will be cached for
359 /// future users.
360 Value VectorizationState::getOrCreateMaskFor(
361  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362  std::optional<AffineMap> maybeMaskingMap) {
363  // No mask is needed if the operation is not maskable.
364  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365  if (!maskableOp)
366  return Value();
367 
368  assert(!maskableOp.isMasked() &&
369  "Masking an operation that is already masked");
370 
371  // If no masking map was provided, use an identity map with the loop dims.
372  assert((!maybeMaskingMap || *maybeMaskingMap) &&
373  "Unexpected null mask permutation map");
374  AffineMap maskingMap =
375  maybeMaskingMap ? *maybeMaskingMap
376                       : AffineMap::getMultiDimIdentityMap(
377                             linalgOp.getNumLoops(), rewriter.getContext());
378 
379  LDBG("Masking map: " << maskingMap << "\n");
380 
381  // Return the active mask for the masking map of this operation if it was
382  // already created.
383  auto activeMaskIt = activeMaskCache.find(maskingMap);
384  if (activeMaskIt != activeMaskCache.end()) {
385  Value mask = activeMaskIt->second;
386  LDBG("Reusing mask: " << mask << "\n");
387  return mask;
388  }
389 
390  // Compute permuted projection of the iteration space to be masked and the
391  // corresponding mask shape. If the resulting iteration space dimensions are
392  // static and identical to the mask shape, masking is not needed for this
393  // operation.
394  // TODO: Improve this check. Only projected permutation indexing maps are
395  // supported.
396  SmallVector<int64_t> permutedStaticSizes =
397  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
398  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
399  auto maskShape = maskType.getShape();
400 
401  LDBG("Mask shape: ");
402  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
403  LLVM_DEBUG(llvm::dbgs() << "\n");
404 
405  if (permutedStaticSizes == maskShape) {
406  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
407  activeMaskCache[maskingMap] = Value();
408  return Value();
409  }
410 
411  // Permute the iteration space value sizes to compute the mask upper bounds.
412  SmallVector<Value> upperBounds =
413  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
414  assert(!maskShape.empty() && !upperBounds.empty() &&
415  "Masked 0-d vectors are not supported yet");
416 
417  // Create the mask based on the dimension values.
418  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
419  maskType, upperBounds);
420  LDBG("Creating new mask: " << mask << "\n");
421  activeMaskCache[maskingMap] = mask;
422  return mask;
423 }
424 
425 /// Masks an operation with the canonical vector mask if the operation needs
426 /// masking. Returns the masked operation or the original operation if masking
427 /// is not needed. If provided, the canonical mask for this operation is
428 /// permuted using `maybeMaskingMap`.
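/// For example (an illustrative sketch): with a canonical vector shape of 4x8
/// and dynamic iteration-space sizes %d0 and %d1, a vectorized read is
/// wrapped roughly as
///   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
///   %read = vector.mask %mask { vector.transfer_read ... } : ...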
429 Operation *
430 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
431                                   LinalgOp linalgOp,
432  std::optional<AffineMap> maybeMaskingMap) {
433  LDBG("Trying to mask: " << *opToMask << "\n");
434 
435  // Create or retrieve mask for this operation.
436  Value mask =
437  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
438 
439  if (!mask) {
440  LDBG("No mask required\n");
441  return opToMask;
442  }
443 
444  // Wrap the operation with a new `vector.mask` and update D-U chain.
445  assert(opToMask && "Expected a valid operation to mask");
446  auto maskOp = cast<vector::MaskOp>(
447  mlir::vector::maskOperation(rewriter, opToMask, mask));
448  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
449 
450  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
451  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
452  maskOpTerminator);
453 
454  LDBG("Masked operation: " << *maskOp << "\n");
455  return maskOp;
456 }
457 
458 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
459 /// projectedPermutation, compress the unused dimensions to serve as a
460 /// permutation_map for a vector transfer operation.
461 /// For example, given a linalg op such as:
462 ///
463 /// ```
464 /// %0 = linalg.generic {
465 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
466 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
467 /// }
468 /// ins(%0 : tensor<2x3x4xf32>)
469 /// outs(%1 : tensor<5x6xf32>)
470 /// ```
471 ///
472 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
473 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
474 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
475 static AffineMap reindexIndexingMap(AffineMap map) {
476   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
477  "expected projected permutation");
478  auto res = compressUnusedDims(map);
479  assert(res.getNumDims() == res.getNumResults() &&
480  "expected reindexed map with same number of dims and results");
481  return res;
482 }
483 
484 /// Helper enum to represent conv1d input traversal order.
485 enum class Conv1DOpOrder {
486  W, // Corresponds to non-channeled 1D convolution operation.
487  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
488  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
489 };
490 
491 /// Helper data structure to represent the result of vectorization.
492 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
493 enum VectorizationStatus {
494   /// Op failed to vectorize.
495   Failure = 0,
496   /// Op vectorized and custom function took care of replacement logic.
497   NoReplace,
498   /// Op vectorized into a new Op whose results will replace original Op's
499   /// results.
500   NewOp
501   // TODO: support values if Op vectorized to Many-Ops whose results we need to
502   // aggregate for replacement.
503 };
504 struct VectorizationResult {
505   /// Return status from vectorizing the current op.
506   enum VectorizationStatus status = VectorizationStatus::Failure;
507   /// New vectorized operation to replace the current op.
508   /// Replacement behavior is specified by `status`.
509   Operation *newOp;
510 };
511 
512 std::optional<vector::CombiningKind>
513 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
514   using ::mlir::vector::CombiningKind;
515 
516  if (!combinerOp)
517  return std::nullopt;
518   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
519       .Case<arith::AddIOp, arith::AddFOp>(
520  [&](auto op) { return CombiningKind::ADD; })
521  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
522  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
523  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
524  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
525  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
526  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
527  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
528  .Case<arith::MulIOp, arith::MulFOp>(
529  [&](auto op) { return CombiningKind::MUL; })
530  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
531  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
532  .Default([&](auto op) { return std::nullopt; });
533 }
534 
535 /// Check whether `outputOperand` is a reduction with a single combiner
536 /// operation. Return the combiner operation of the reduction. Return
537 /// nullptr otherwise. Multiple reduction operations would impose an
538 /// ordering between reduction dimensions and are currently unsupported in
539 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
540 /// max(min(X)).
541 // TODO: use in LinalgOp verification, there is a circular dependency atm.
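// For example (an illustrative sketch): in
//   linalg.generic {iterator_types = ["parallel", "reduction"]}
//       ins(%in : tensor<4x8xf32>) outs(%acc : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %sum = arith.addf %a, %b : f32
//     linalg.yield %sum : f32
//   }
// the combiner returned for the init operand is the single `arith.addf`.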
542 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
543  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
544  unsigned outputPos =
545  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
546  // Only single combiner operations are supported for now.
547  SmallVector<Operation *, 4> combinerOps;
548  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
549  combinerOps.size() != 1)
550  return nullptr;
551 
552  // Return the combiner operation.
553  return combinerOps[0];
554 }
555 
556 /// Broadcast `value` to a vector of `dstType` if possible. Return `value`
557 /// otherwise.
558 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
559  auto dstVecType = dyn_cast<VectorType>(dstType);
560  // If no shape to broadcast to, just return `value`.
561  if (dstVecType.getRank() == 0)
562  return value;
563  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
564           vector::BroadcastableToResult::Success)
565     return value;
566  Location loc = b.getInsertionPoint()->getLoc();
567  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
568 }
569 
570 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
571 /// assumes that `reductionOp` has two operands and one of them is the reduction
572 /// initial value.
573 // Note: this is a true builder that notifies the OpBuilder listener.
574 // TODO: Consider moving as a static helper on the ReduceOp.
575 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
576                                       Value valueToReduce, Value acc,
577  ArrayRef<bool> dimsToMask) {
578  auto maybeKind = getCombinerOpKind(reduceOp);
579  assert(maybeKind && "Failed precondition: could not get reduction kind");
580  return b.create<vector::MultiDimReductionOp>(
581  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
582 }
583 
584 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
585  return llvm::to_vector(
586  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
587 }
588 
589 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
590 /// reduction iterator.
591 static bool hasReductionIterator(LinalgOp &op) {
592  return isa<linalg::ReduceOp>(op) ||
593  (isa<linalg::GenericOp>(op) &&
594  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
595 }
596 
597 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
598 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
599 /// currently being vectorized. If `dest` has null rank, build a memref.store.
600 /// Return the produced value or null if no value is produced.
601 // Note: this is a true builder that notifies the OpBuilder listener.
602 // TODO: Consider moving as a static helper on the ReduceOp.
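// For example (an illustrative sketch): for a canonical vector shape of 4x8
// and an output operand of type tensor<4x8xf32>, the yielded vector %v is
// written back roughly as
//   %c0 = arith.constant 0 : index
//   %w = vector.transfer_write %v, %out[%c0, %c0]
//          : vector<4x8xf32>, tensor<4x8xf32>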
603 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
604  OpOperand *outputOperand,
605  VectorizationState &state) {
606  Location loc = value.getLoc();
607  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
608  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
609 
610  // Compute the vector type of the value to store. This type should be an
611  // identity or projection of the canonical vector type without any permutation
612  // applied, given that any permutation in a transfer write happens as part of
613  // the write itself.
614   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
615       opOperandMap.getContext(), opOperandMap.getNumInputs(),
616  [&](AffineDimExpr dimExpr) -> bool {
617  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
618  });
619  auto vectorType = state.getCanonicalVecType(
620  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
621 
622  Operation *write;
623  if (vectorType.getRank() > 0) {
624  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
625  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
626  rewriter.create<arith::ConstantIndexOp>(loc, 0));
627  value = broadcastIfNeeded(rewriter, value, vectorType);
628  assert(value.getType() == vectorType && "Incorrect type");
629  write = rewriter.create<vector::TransferWriteOp>(
630  loc, value, outputOperand->get(), indices, writeMap);
631  } else {
632  // 0-d case is still special: do not invert the reindexing writeMap.
633  if (!isa<VectorType>(value.getType()))
634  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
635  assert(value.getType() == vectorType && "Incorrect type");
636  write = rewriter.create<vector::TransferWriteOp>(
637  loc, value, outputOperand->get(), ValueRange{});
638  }
639 
640  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
641 
642  // If masked, set in-bounds to true. Masking guarantees that the access will
643  // be in-bounds.
644  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
645  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
646  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
647  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
648  }
649 
650  LDBG("vectorized op: " << *write << "\n");
651  if (!write->getResults().empty())
652  return write->getResult(0);
653  return Value();
654 }
655 
656 // Custom vectorization precondition function type. This is intended to be used
657 // with CustomVectorizationHook. Returns success if the corresponding custom
658 // hook can vectorize the op.
659 using CustomVectorizationPrecondition =
660     std::function<LogicalResult(Operation *, bool)>;
661 
662 // Custom vectorization function type. Produce a vector form of Operation*
663 // assuming all its vectorized operands are already in the IRMapping.
664 // Return nullptr if the Operation cannot be vectorized.
665 using CustomVectorizationHook =
666     std::function<VectorizationResult(Operation *, const IRMapping &)>;
667 
668 /// Helper function to vectorize the terminator of a `linalgOp`. New result
669 /// vector values are appended to `newResults`. Return
670 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
671 /// should not try to map produced operations and instead return the results
672 /// using the `newResults` vector making them available to the vectorization
673 /// algorithm for RAUW. This function is meant to be used as a
674 /// CustomVectorizationHook.
675 static VectorizationResult
676 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
677                      const IRMapping &bvm, VectorizationState &state,
678  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
679  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
680  if (!yieldOp)
681     return VectorizationResult{VectorizationStatus::Failure, nullptr};
682   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
683  // TODO: Scan for an opportunity for reuse.
684  // TODO: use a map.
685  Value vectorValue = bvm.lookup(output.value());
686  Value newResult =
687  buildVectorWrite(rewriter, vectorValue,
688  linalgOp.getDpsInitOperand(output.index()), state);
689  if (newResult)
690  newResults.push_back(newResult);
691  }
692 
693   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
694 }
695 
696 /// Helper function to vectorize the index operations of a `linalgOp`. Return
697 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
698 /// should map the produced operations. This function is meant to be used as a
699 /// CustomVectorizationHook.
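/// For example (an illustrative sketch, not verbatim output): with a canonical
/// vector shape of 4x8, `linalg.index 1` (the trailing dim) becomes a single
/// `vector.step : vector<8xindex>`, whereas `linalg.index 0` becomes roughly
///   %step = vector.step : vector<4xindex>
///   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
///   %idx = vector.transpose %bcast, [1, 0]
///            : vector<8x4xindex> to vector<4x8xindex>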
700 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
701                                                 VectorizationState &state,
702  Operation *op,
703  LinalgOp linalgOp) {
704  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
705  if (!indexOp)
706     return VectorizationResult{VectorizationStatus::Failure, nullptr};
707   auto loc = indexOp.getLoc();
708  // Compute the static loop sizes of the index op.
709  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
710  auto dim = indexOp.getDim();
711  // Compute a one-dimensional index vector for the index op dimension.
712  auto indexVectorType =
713  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
714  state.getScalableVecDims()[dim]);
715  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
716  // Return the one-dimensional index vector if it lives in the trailing
717  // dimension of the iteration space since the vectorization algorithm in this
718  // case can handle the broadcast.
719  if (dim == targetShape.size() - 1)
720     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
721   // Otherwise permute the targetShape to move the index dimension last,
722  // broadcast the one-dimensional index vector to the permuted shape, and
723  // finally transpose the broadcasted index vector to undo the permutation.
724  auto permPattern =
725  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
726  std::swap(permPattern[dim], permPattern.back());
727  auto permMap =
728  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
729 
730  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
731  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
732  indexSteps);
733  SmallVector<int64_t> transposition =
734  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
735  std::swap(transposition.back(), transposition[dim]);
736  auto transposeOp =
737  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
738  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
739 }
740 
741 /// Helper function to check if the tensor.extract can be vectorized by the
742 /// custom hook vectorizeTensorExtract.
743 static LogicalResult
744 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
745  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
746  if (!extractOp)
747  return failure();
748 
749  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
750  return failure();
751 
752  // Check the index type, but only for non 0-d tensors (for which we do need
753  // access indices).
754  if (not extractOp.getIndices().empty()) {
755  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
756  return failure();
757  }
758 
759  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
760  return !VectorType::isValidElementType(type);
761  })) {
762  return failure();
763  }
764 
765  return success();
766 }
767 
768 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
769 /// generated from `tensor.extract`. The offset is calculated as follows
770 /// (example using scalar values):
771 ///
772 /// offset = extractOp.indices[0]
773 /// for (i = 1; i < numIndices; i++)
774 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
775 ///
776 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
777 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
778 static Value calculateGatherOffset(RewriterBase &rewriter,
779                                    VectorizationState &state,
780  tensor::ExtractOp extractOp,
781  const IRMapping &bvm) {
782  // The vector of indices for GatherOp should be shaped as the output vector.
783  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
784  auto loc = extractOp.getLoc();
785 
786  Value offset = broadcastIfNeeded(
787  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
788 
789  const size_t numIndices = extractOp.getIndices().size();
790  for (size_t i = 1; i < numIndices; i++) {
791  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
792 
793  auto dimSize = broadcastIfNeeded(
794  rewriter,
795  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
796  indexVecType);
797 
798  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
799 
800  auto extractOpIndex = broadcastIfNeeded(
801  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
802 
803  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
804  }
805 
806  return offset;
807 }
808 
809 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
810 
811 /// Checks whether \p val can be used for calculating a loop invariant index.
812 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
813 
814  auto targetShape = linalgOp.getStaticLoopRanges();
815  assert(((llvm::count_if(targetShape,
816  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
817  "n-D vectors are not yet supported");
818  assert(targetShape.back() != 1 &&
819          "1-D vectors with the trailing dim equal 1 are not yet supported");
820 
821  // Blocks outside _this_ linalg.generic are effectively loop invariant.
822  // However, analysing block arguments for _this_ linalg.generic Op is a bit
823  // tricky. Just bail out in the latter case.
824  // TODO: We could try analysing the corresponding affine map here.
825  auto *block = linalgOp.getBlock();
826  if (isa<BlockArgument>(val))
827  return llvm::all_of(block->getArguments(),
828  [&val](Value v) { return (v != val); });
829 
830  Operation *defOp = val.getDefiningOp();
831  assert(defOp && "This is neither a block argument nor an operation result");
832 
833  // IndexOp is loop invariant as long as its result remains constant across
834  // iterations. Given the assumptions on the loop ranges above, only the
835  // trailing loop dim ever changes.
836  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
837  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
838  return (indexOp.getDim() != trailingLoopDim);
839 
840  auto *ancestor = block->findAncestorOpInBlock(*defOp);
841 
842   // Values defined outside `linalgOp` are loop invariant.
843  if (!ancestor)
844  return true;
845 
846  // Values defined inside `linalgOp`, which are constant, are loop invariant.
847  if (isa<arith::ConstantOp>(ancestor))
848  return true;
849 
850  bool result = true;
851  for (auto op : ancestor->getOperands())
852  result &= isLoopInvariantIdx(linalgOp, op);
853 
854  return result;
855 }
856 
857 /// Check whether \p val could be used for calculating the trailing index for a
858 /// contiguous load operation.
859 ///
860 /// There are currently 3 types of values that are allowed here:
861 /// 1. loop-invariant values,
862 /// 2. values that increment by 1 with every loop iteration,
863 /// 3. results of basic arithmetic operations (linear and continuous)
864 /// involving 1., 2. and 3.
865 /// This method returns True if indeed only such values are used in calculating
866 /// \p val.
867 ///
868 /// Additionally, the trailing index for a contiguous load operation should
869 /// increment by 1 with every loop iteration, i.e. be based on:
870 /// * `linalg.index <dim>` ,
871 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
872 /// updated to `true` when such an op is found.
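/// For example (an illustrative sketch, assuming the trailing dim of the
/// iteration space is dim 1): `%i = linalg.index 1` qualifies under 2.,
/// `%c4 = arith.constant 4 : index` qualifies under 1., and
/// `%j = arith.addi %i, %c4` qualifies under 3., so a trailing extract index
/// computed as %j still yields a contiguous load.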
873 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
874  bool &foundIndexOp) {
875 
876  auto targetShape = linalgOp.getStaticLoopRanges();
877  assert(((llvm::count_if(targetShape,
878  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
879  "n-D vectors are not yet supported");
880  assert(targetShape.back() != 1 &&
881  "1-D vectors with the trailing dim 1 are not yet supported");
882 
883  // Blocks outside _this_ linalg.generic are effectively loop invariant.
884  // However, analysing block arguments for _this_ linalg.generic Op is a bit
885  // tricky. Just bail out in the latter case.
886  // TODO: We could try analysing the corresponding affine map here.
887  auto *block = linalgOp.getBlock();
888  if (isa<BlockArgument>(val))
889  return llvm::all_of(block->getArguments(),
890  [&val](Value v) { return (v != val); });
891 
892  Operation *defOp = val.getDefiningOp();
893  assert(defOp && "This is neither a block argument nor an operation result");
894 
895  // Given the assumption on the loop ranges above, only the trailing loop
896  // index is not constant.
897  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
898  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
899  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
900  return true;
901  }
902 
903  auto *ancestor = block->findAncestorOpInBlock(*defOp);
904 
905  if (!ancestor)
906  return false;
907 
908  // Conservatively reject Ops that could lead to indices with stride other
909  // than 1.
910  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
911  return false;
912 
913  bool result = false;
914  for (auto op : ancestor->getOperands())
915  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
916 
917  return result;
918 }
919 
920 /// Infer the memory access pattern for the input ExtractOp
921 ///
922 /// Based on the operation shapes and indices (usually based on the iteration
923 /// space of the parent `linalgOp` operation), decides whether the input
924 /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
925 /// gather load.
926 ///
927 /// Note that it is always safe to use gather load operations for contiguous
928 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
929 /// that `extractOp` is a gather load.
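/// For example (an illustrative sketch): for an iteration space of 1x4 and
///   %v = tensor.extract %src[%c0, %i] : tensor<8x128xf32>
/// where %i is `linalg.index 1`, the access is treated as contiguous; if %i
/// were instead an arbitrary value loaded from another tensor, the access
/// would be treated as a gather.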
930 static VectorMemoryAccessKind
931 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
932  LinalgOp &linalgOp) {
933 
934  auto targetShape = linalgOp.getStaticLoopRanges();
935  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
936 
937  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
938  if (inputShape.getShape().empty())
939     return VectorMemoryAccessKind::ScalarBroadcast;
940 
941  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
942  // gather load.
943  // TODO: Relax this condition.
944  if (linalgOp.hasDynamicShape())
945     return VectorMemoryAccessKind::Gather;
946 
947  // 1. Assume that it's a gather load when reading _into_:
948  // * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
949  // * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
950  // TODO: Relax these conditions.
951  // FIXME: This condition assumes non-dynamic sizes.
952  if ((llvm::count_if(targetShape,
953  [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
954  targetShape.back() == 1)
955     return VectorMemoryAccessKind::Gather;
956 
957  // 2. Assume that it's a gather load when reading _from_ a tensor for which
958  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
959  // TODO: Relax this condition.
960  if (inputShape.getShape().back() == 1)
961     return VectorMemoryAccessKind::Gather;
962 
963  bool leadingIdxsLoopInvariant = true;
964 
965  // 3. Analyze the leading indices of `extractOp`.
966  // Look at the way each index is calculated and decide whether it is suitable
967  // for a contiguous load, i.e. whether it's loop invariant.
968  auto indices = extractOp.getIndices();
969  auto leadIndices = indices.drop_back(1);
970 
971  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
972  if (inputShape.getShape()[i] == 1)
973  continue;
974 
975  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
976  }
977 
978  if (!leadingIdxsLoopInvariant) {
979  LDBG("Found gather load: " << extractOp);
980     return VectorMemoryAccessKind::Gather;
981   }
982 
983  // 4. Analyze the trailing index for `extractOp`.
984  // At this point we know that the leading indices are loop invariant. This
985   // means that this is potentially a scalar or a contiguous load. We can decide
986  // based on the trailing idx.
987  auto extractOpTrailingIdx = indices.back();
988 
989  // 4a. Scalar broadcast load
990  // If the trailing index is loop invariant then this is a scalar load.
991  if (leadingIdxsLoopInvariant &&
992  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
993  LDBG("Found scalar broadcast load: " << extractOp);
994 
995     return VectorMemoryAccessKind::ScalarBroadcast;
996   }
997 
998  // 4b. Contiguous loads
999  // The trailing `extractOp` index should increment with every loop iteration.
1000  // This effectively means that it must be based on the trailing loop index.
1001  // This is what the following bool captures.
1002  bool foundIndexOp = false;
1003  bool isContiguousLoad =
1004  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
1005  isContiguousLoad &= foundIndexOp;
1006 
1007  if (isContiguousLoad) {
1008     LDBG("Found contiguous load: " << extractOp);
1009     return VectorMemoryAccessKind::Contiguous;
1010   }
1011 
1012  // 5. Fallback case - gather load.
1013  LDBG("Found gather load: " << extractOp);
1014   return VectorMemoryAccessKind::Gather;
1015 }
1016 
1017 /// Helper function to vectorize the tensor.extract operations. Returns
1018 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1019 /// should map the produced operations. This function is meant to be used as a
1020 /// CustomVectorizationHook.
1021 static VectorizationResult
1022 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1023                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1024  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1025  if (!extractOp)
1026     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1027   auto loc = extractOp.getLoc();
1028 
1029  // Compute the static loop sizes of the extract op.
1030  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1031  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1032  loc,
1033  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1034  /*value=*/true));
1035  auto passThruConstantOp =
1036  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1037 
1038  // Base indices are currently set to 0. We will need to re-visit if more
1039  // generic scenarios are to be supported.
1040  SmallVector<Value> baseIndices(
1041  extractOp.getIndices().size(),
1042  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1043 
1044  VectorMemoryAccessKind memAccessKind =
1045  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1046 
1047  // 1. Handle gather access
1048  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1049  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1050 
1051  // Generate the gather load
1052  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1053  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1054  maskConstantOp, passThruConstantOp);
1055  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1056 
1057  LDBG("Vectorised as gather load: " << extractOp << "\n");
1058     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1059   }
1060 
1061  // 2. Handle:
1062  // a. scalar loads + broadcast,
1063  // b. contiguous loads.
1064  // Both cases use vector.transfer_read.
1065 
1066  // Collect indices for `vector.transfer_read`. At this point, the indices will
1067  // either be scalars or would have been broadcast to vectors matching the
1068  // result type. For indices that are vectors, there are two options:
1069  // * for non-trailing indices, all elements are identical (contiguous
1070  // loads are identified by looking for non-trailing indices that are
1071  // invariant with respect to the corresponding linalg.generic), or
1072  // * for trailing indices, the index vector will contain values with stride
1073  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1074  // needed.
1075  // This means that
1076  // * for scalar indices - just re-use it,
1077  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1078  // (0th) element and use that.
1079  SmallVector<Value> transferReadIdxs;
1080  auto zero = rewriter.create<arith::ConstantOp>(
1081  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1082  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1083  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1084  if (idx.getType().isIndex()) {
1085  transferReadIdxs.push_back(idx);
1086  continue;
1087  }
1088 
1089  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1090  loc,
1091  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1092  resultType.getScalableDims().back()),
1093  idx);
1094  transferReadIdxs.push_back(
1095  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1096  }
1097 
1098   // `tensor.extract` is always in-bounds, hence the following holds.
1099  auto dstRank = resultType.getRank();
1100  auto srcRank = extractOp.getTensor().getType().getRank();
1101  SmallVector<bool> inBounds(dstRank, true);
1102 
1103  // 2a. Handle scalar broadcast access.
1104  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1105  MLIRContext *ctx = rewriter.getContext();
1106  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1107  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1108 
1109  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1110  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1111  permutationMap, inBounds);
1112 
1113  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1114  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1115  }
1116 
1117  // 2b. Handle contiguous access.
1118  auto permutationMap = AffineMap::getMinorIdentityMap(
1119  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1120 
1121  int32_t rankDiff = dstRank - srcRank;
1122  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1123  // dims so that the ranks match. This is done by extending the map with 0s.
1124  // For example, for dstRank = 3, srcRank = 2, the following map created
1125  // above:
1126  // (d0, d1) --> (d0, d1)
1127  // is extended as:
1128  // (d0, d1) --> (0, d0, d1)
1129  while (rankDiff > 0) {
1130  permutationMap = permutationMap.insertResult(
1131  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1132  rankDiff--;
1133  }
1134 
1135  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1136  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1137  inBounds);
1138 
1139  LDBG("Vectorised as contiguous load: " << extractOp);
1140  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1141 }
1142 
1143 /// Emit reduction operations if the shape of the value to reduce is different
1144 /// from the result shape.
1145 // Note: this is a true builder that notifies the OpBuilder listener.
1146 // TODO: Consider moving as a static helper on the ReduceOp.
1147 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1148  Value reduceValue, Value initialValue,
1149  const IRMapping &bvm) {
1150  Value reduceVec = bvm.lookup(reduceValue);
1151  Value outputVec = bvm.lookup(initialValue);
1152  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1153  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1154   // Reduce only if needed as the value may already have been reduced for
1155  // contraction vectorization.
1156  if (!reduceType ||
1157  (outputType && reduceType.getShape() == outputType.getShape()))
1158  return nullptr;
1159  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1160  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1161 }
1162 
1163 /// Generic vectorization for a single operation `op`, given already vectorized
1164 /// operands carried by `bvm`. Vectorization occurs as follows:
1165 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1166 /// result on success.
1167 /// 2. Clone any constant in the current scope without vectorization: each
1168 /// consumer of the constant will later determine the shape to which the
1169 ///    constant needs to be broadcast.
1170 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1171 /// of the `customVectorizationHooks` to cover such cases.
1172 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1173 /// operand of maximal rank. Other operands have smaller rank and are
1174 /// broadcast accordingly. It is assumed this broadcast is always legal,
1175 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1176 ///
1177 /// This function assumes all operands of `op` have been vectorized and are in
1178 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1179 /// a topologically-sorted list of ops.
1180 /// This function does not update `bvm` but returns a VectorizationStatus that
1181 /// instructs the caller what `bvm` update needs to occur.
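/// For example (an illustrative sketch): for `%0 = arith.mulf %a, %b : f32`
/// inside a generic whose canonical vector shape is 4x8, if `bvm` maps %a to a
/// vector<4x8xf32> and %b to a scalar f32 (or a lower-ranked vector), %b is
/// broadcast to vector<4x8xf32> and the op is re-created as
/// `arith.mulf ... : vector<4x8xf32>`.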
1182 static VectorizationResult
1183 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1184                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1185  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1186  LDBG("vectorize op " << *op << "\n");
1187 
1188  // 1. Try to apply any CustomVectorizationHook.
1189  if (!customVectorizationHooks.empty()) {
1190  for (auto &customFunc : customVectorizationHooks) {
1191  VectorizationResult result = customFunc(op, bvm);
1192  if (result.status == VectorizationStatus::Failure)
1193  continue;
1194  return result;
1195  }
1196  }
1197 
1198  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1199   // Clone so that the constant is not confined to the linalgOp block.
1200  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1201  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1202 
1203  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1204   if (!OpTrait::hasElementwiseMappableTraits(op))
1205     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1206 
1207   // 4. Check if the operation is a reduction.
1208  SmallVector<std::pair<Value, Value>> reductionOperands;
1209  for (Value operand : op->getOperands()) {
1210  auto blockArg = dyn_cast<BlockArgument>(operand);
1211  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1212  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1213  continue;
1214  SmallVector<Operation *> reductionOps;
1215  Value reduceValue = matchReduction(
1216  linalgOp.getRegionOutputArgs(),
1217  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1218  if (!reduceValue)
1219  continue;
1220  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1221  }
1222  if (!reductionOperands.empty()) {
1223  assert(reductionOperands.size() == 1);
1224  Operation *reduceOp =
1225  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1226  reductionOperands[0].second, bvm);
1227  if (reduceOp)
1228       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1229   }
1230 
1231  // 5. Generic vectorization path for ElementwiseMappable ops.
1232  // a. Get the first max ranked shape.
1233  VectorType firstMaxRankedType;
1234  for (Value operand : op->getOperands()) {
1235  auto vecOperand = bvm.lookup(operand);
1236  assert(vecOperand && "Vector operand couldn't be found");
1237 
1238  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1239  if (vecType && (!firstMaxRankedType ||
1240  firstMaxRankedType.getRank() < vecType.getRank()))
1241  firstMaxRankedType = vecType;
1242  }
1243  // b. Broadcast each op if needed.
1244  SmallVector<Value> vecOperands;
1245  for (Value scalarOperand : op->getOperands()) {
1246  Value vecOperand = bvm.lookup(scalarOperand);
1247  assert(vecOperand && "Vector operand couldn't be found");
1248 
1249  if (firstMaxRankedType) {
1250  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1251  getElementTypeOrSelf(vecOperand.getType()),
1252  firstMaxRankedType.getScalableDims());
1253  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1254  } else {
1255  vecOperands.push_back(vecOperand);
1256  }
1257  }
1258  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1259  SmallVector<Type> resultTypes;
1260  for (Type resultType : op->getResultTypes()) {
1261  resultTypes.push_back(
1262  firstMaxRankedType
1263  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1264  firstMaxRankedType.getScalableDims())
1265  : resultType);
1266  }
1267  // d. Build and return the new op.
1268  return VectorizationResult{
1269       VectorizationStatus::NewOp,
1270       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1271  resultTypes, op->getAttrs())};
1272 }
1273 
1274 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1275 /// vector form. Generic vectorization proceeds as follows:
1276 /// 1. Verify the `linalgOp` has one non-empty region.
1277 /// 2. Values defined above the region are mapped to themselves and will be
1278 /// broadcasted on a per-need basis by their consumers.
1279 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1280 /// load).
1281 /// TODO: Reuse opportunities for RAR dependencies.
1282 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1283 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1284 /// iteration indices.
1285 /// 5. Iteratively call vectorizeOneOp on the region operations.
1286 ///
1287 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1288 /// performed to the maximal common vector size implied by the `linalgOp`
1289 /// iteration space. This eager broadcasting is introduced in the
1290 /// permutation_map of the vector.transfer_read operations. The eager
1291 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1292 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1293 /// the absence of good canonicalizations, the amount of work increases.
1294 /// This is not deemed a problem as we expect canonicalizations and foldings to
1295 /// aggressively clean up the useless work.
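/// For example (an illustrative sketch, not verbatim output): an elementwise
/// addition generic over tensor<4x8xf32> operands roughly becomes
///   %lhs = vector.transfer_read %a[%c0, %c0] ... : vector<4x8xf32>
///   %rhs = vector.transfer_read %b[%c0, %c0] ... : vector<4x8xf32>
///   %add = arith.addf %lhs, %rhs : vector<4x8xf32>
///   %res = vector.transfer_write %add, %init[%c0, %c0] ...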
1296 static LogicalResult
1297 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1298                          LinalgOp linalgOp,
1299  SmallVectorImpl<Value> &newResults) {
1300  LDBG("Vectorizing operation as linalg generic\n");
1301  Block *block = linalgOp.getBlock();
1302 
1303  // 2. Values defined above the region can only be broadcast for now. Make them
1304  // map to themselves.
1305  IRMapping bvm;
1306  SetVector<Value> valuesSet;
1307  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1308  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1309 
1310  if (linalgOp.getNumDpsInits() == 0)
1311  return failure();
1312 
1313  // 3. Turn all BBArgs into vector.transfer_read / load.
1314  Location loc = linalgOp.getLoc();
1315  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1316  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1317  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1318  if (linalgOp.isScalar(opOperand)) {
1319  bvm.map(bbarg, opOperand->get());
1320  continue;
1321  }
1322 
1323  // 3.a. Convert the indexing map for this input/output to a transfer read
1324  // permutation map and masking map.
1325  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1326 
1327  // Remove zeros from indexing map to use it as masking map.
1328  SmallVector<int64_t> zeroPos;
1329  auto results = indexingMap.getResults();
1330  for (const auto &result : llvm::enumerate(results)) {
1331  if (isa<AffineConstantExpr>(result.value())) {
1332  zeroPos.push_back(result.index());
1333  }
1334  }
1335  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1336 
1337  AffineMap readMap;
1338  VectorType readType;
1339  Type elemType = getElementTypeOrSelf(opOperand->get());
1340  if (linalgOp.isDpsInput(opOperand)) {
1341  // 3.a.i. For input reads we use the canonical vector shape.
1342  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1343  readType = state.getCanonicalVecType(elemType);
1344  } else {
1345  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1346  // reductions), the vector shape is computed by mapping the canonical
1347  // vector shape to the output domain and back to the canonical domain.
1348  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1349  readType =
1350  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1351  }
1352 
1353  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1354 
1355  // Make sure that the in_bounds attribute corresponding to a broadcast dim
1356  // is `true`
1357  SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
1358  SmallVector<bool> inBounds(readType.getRank(), false);
1359 
1360  for (auto idx : broadcastedDims)
1361  inBounds[idx] = true;
1362 
1363  Operation *read = rewriter.create<vector::TransferReadOp>(
1364  loc, readType, opOperand->get(), indices, readMap,
1365  ArrayRef<bool>(inBounds));
1366  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1367  Value readValue = read->getResult(0);
1368 
1369  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1370  // will be in-bounds.
1371  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1372  SmallVector<bool> inBounds(readType.getRank(), true);
1373  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1374  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1375  }
1376 
1377  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1378  // TODO: remove this.
1379  if (readType.getRank() == 0)
1380  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1381 
1382  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1383  << "\n");
1384  bvm.map(bbarg, readValue);
1385  bvm.map(opOperand->get(), readValue);
1386  }
1387 
1388  SmallVector<CustomVectorizationHook> hooks;
1389  // 4a. Register CustomVectorizationHook for yieldOp.
1390  CustomVectorizationHook vectorizeYield =
1391  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1392  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1393  };
1394  hooks.push_back(vectorizeYield);
1395 
1396  // 4b. Register CustomVectorizationHook for indexOp.
1397  CustomVectorizationHook vectorizeIndex =
1398  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1399  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1400  };
1401  hooks.push_back(vectorizeIndex);
1402 
1403  // 4c. Register CustomVectorizationHook for extractOp.
1404  CustomVectorizationHook vectorizeExtract =
1405  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1406  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1407  };
1408  hooks.push_back(vectorizeExtract);
1409 
1410  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1411  for (Operation &op : block->getOperations()) {
1412  VectorizationResult result =
1413  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1414  if (result.status == VectorizationStatus::Failure) {
1415  LDBG("failed to vectorize: " << op << "\n");
1416  return failure();
1417  }
1418  if (result.status == VectorizationStatus::NewOp) {
1419  Operation *maybeMaskedOp =
1420  state.maskOperation(rewriter, result.newOp, linalgOp);
1421  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1422  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1423  }
1424  }
1425 
1426  return success();
1427 }
1428 
1429 /// Given a tensor::PackOp, return the `dest` shape before any packing
1430 /// permutations.
1431 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1432  ArrayRef<int64_t> destShape) {
1433  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1434 }
1435 
1436 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1437 /// create an empty destination tensor and a TransferWriteOp that writes the
1438 /// input into it. If the destination shape is not the same as the
1439 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1440 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1441 /// inBounds attribute of the transfer write op instead of masking.
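///
/// A sketch of the masked path (illustrative, made-up value names), writing a
/// vector<8x16xf32> into a destination whose reified sizes are %d0 and %d1:
///
///   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
///   %write = vector.mask %mask {
///       vector.transfer_write %input, %empty[%c0, %c0]
///           {in_bounds = [true, true]} : vector<8x16xf32>, tensor<?x?xf32>
///   } : vector<8x16xi1> -> tensor<?x?xf32>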
1442 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1443  Value input,
1444  SmallVector<OpFoldResult> destSizes,
1445  ArrayRef<int64_t> inputVectorSizes,
1446  bool useInBoundsInsteadOfMasking) {
1447 
1448  auto inputType = cast<VectorType>(input.getType());
1449  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1450  inputType.getElementType());
1451  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1452  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1453  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1454  SmallVector<bool> inBoundsVal(rank, true);
1455  if (useInBoundsInsteadOfMasking) {
1456  // Update the inBounds attribute.
1457  for (unsigned i = 0; i < rank; i++)
1458  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1459  !ShapedType::isDynamic(destShape[i]);
1460  }
1461  Operation *write = builder.create<vector::TransferWriteOp>(
1462  loc,
1463  /*vector=*/input,
1464  /*source=*/dest,
1465  /*indices=*/SmallVector<Value>(rank, zero),
1466  /*inBounds=*/inBoundsVal);
1467  assert(llvm::none_of(
1468  destShape.drop_front(inputVectorSizes.size()),
1469  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1470  "Only dims aligned with inputVectorSizes may be dynamic");
1471  if (useInBoundsInsteadOfMasking)
1472  return write;
1473  bool needMaskForWrite = !llvm::equal(
1474  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1475  if (needMaskForWrite) {
1476  SmallVector<int64_t> writeMaskShape;
1477  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1478  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1479  destShape.end());
1480  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1481  Value maskForWrite =
1482  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1483  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1484  }
1485  return write;
1486 }
1487 
1488 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1489 /// padding value and (3) input vector sizes into:
1490 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1491 /// As in the following example:
1492 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1493 /// into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1494 ///
1495 /// This pack would be vectorized to:
1496 ///
1497 /// %load = vector.mask %mask {
1498 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1499 /// {in_bounds = [true, true, true]} :
1500 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1501 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1502 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1503 /// to vector<32x4x2x1x16xf32>
1504 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1505 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1506 /// %write = vector.transfer_write %transpose,
1507 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1508 /// {in_bounds = [true, true, true, true, true]}
1509 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1510 ///
1511 /// If the (3) input vector sizes are not provided, the vector sizes are
1512 /// determined by the result tensor shape. Also, we update the inBounds
1513 /// attribute instead of masking.
1514 static LogicalResult
1515 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1516  ArrayRef<int64_t> inputVectorSizes,
1517  SmallVectorImpl<Value> &newResults) {
1518  OpBuilder::InsertionGuard g(rewriter);
1519  rewriter.setInsertionPoint(packOp);
1520 
1521  Location loc = packOp.getLoc();
1522  auto padValue = packOp.getPaddingValue();
1523  if (!padValue) {
1524  padValue = rewriter.create<arith::ConstantOp>(
1525  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1526  }
1527  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1528  LogicalResult status =
1529  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1530  .reifyResultShapes(rewriter, reifiedReturnShapes);
1531  (void)status; // prevent unused variable warning on non-assert builds.
1532  assert(succeeded(status) && "failed to reify result shapes");
1533 
1534  // If the input vector sizes are not provided, then the vector sizes are
1535  // determined by the result tensor shape. In case the vector sizes aren't
1536  // provided, we update the inBounds attribute instead of masking.
1537  bool useInBoundsInsteadOfMasking = false;
1538  if (inputVectorSizes.empty()) {
1539  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1540  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1541  useInBoundsInsteadOfMasking = true;
1542  }
1543 
1544  // Create masked TransferReadOp.
1545  SmallVector<int64_t> inputShape(inputVectorSizes);
1546  auto innerTiles = packOp.getStaticInnerTiles();
1547  auto innerDimsPos = packOp.getInnerDimsPos();
1548  auto outerDimsPerm = packOp.getOuterDimsPerm();
1549  if (!outerDimsPerm.empty())
1550  applyPermutationToVector(inputShape,
1551  invertPermutationVector(outerDimsPerm));
1552  for (auto [idx, size] : enumerate(innerTiles))
1553  inputShape[innerDimsPos[idx]] *= size;
1554  auto maskedRead = vector::createReadOrMaskedRead(
1555  rewriter, loc, packOp.getSource(), inputShape, padValue,
1556  useInBoundsInsteadOfMasking);
1557 
1558  // Create ShapeCastOp.
1559  SmallVector<int64_t> destShape(inputVectorSizes);
1560  destShape.append(innerTiles.begin(), innerTiles.end());
1561  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1562  packOp.getDestType().getElementType());
1563  auto shapeCastOp =
1564  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1565 
1566  // Create TransposeOp.
1567  auto destPermutation =
1568  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1569  auto transposeOp = rewriter.create<vector::TransposeOp>(
1570  loc, shapeCastOp.getResult(), destPermutation);
1571 
1572  // Create TransferWriteOp.
1573  Operation *write = createWriteOrMaskedWrite(
1574  rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1575  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1576  newResults.push_back(write->getResult(0));
1577  return success();
1578 }
1579 
1580 /// Vectorize a `tensor::UnPackOp` to these 4 Ops:
1581 ///   vector::TransferReadOp - Reads a vector from the source tensor
1582 ///   vector::TransposeOp - Transposes the read vector
1583 ///   vector::ShapeCastOp - Reshapes the data based on the target shape
1584 ///   vector::TransferWriteOp - Writes the result vector back to the
1585 ///   destination tensor.
1586 /// If the vector sizes are not provided:
1587 /// * the vector sizes are determined by the input operand and attributes,
1588 /// * update the inBounds attribute instead of masking.
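///
/// For example (an illustrative sketch with static shapes, no masking and
/// made-up value names):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [32, 16]
///       into %dst : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
///
/// is vectorized to roughly:
///
///   %read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
///       {in_bounds = [true, true, true, true]}
///       : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
///   %transpose = vector.transpose %read, [0, 2, 1, 3]
///       : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
///   %shape_cast = vector.shape_cast %transpose
///       : vector<8x32x8x16xf32> to vector<256x128xf32>
///   %write = vector.transfer_write %shape_cast, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<256x128xf32>, tensor<256x128xf32>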
1589 static LogicalResult
1590 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1591  ArrayRef<int64_t> inputVectorSizes,
1592  SmallVectorImpl<Value> &newResults) {
1593 
1594  OpBuilder::InsertionGuard g(rewriter);
1595  rewriter.setInsertionPoint(unpackOp);
1596 
1597  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1598 
1599  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1600  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1601  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1602  bool useInBoundsInsteadOfMasking = false;
1603  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1604 
1605  auto destSize = unpackOp.getDestRank();
1606 
1607  if (!inputVectorSizes.empty())
1608  assert(inputVectorSizes.size() == destSize &&
1609  "Incorrect number of input vector sizes");
1610 
1611  // vectorSizes is the shape of the vector that will be used for the final
1612  // write on the destination tensor. It is set as follows: let the source
1613  // tensor have rank 'M' and the dest tensor rank 'N', where N <= M.
1614  // Thus:
1615  // 1. vectorSizes = sourceShape.take_front(N)
1616  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1617  // 3. multiply the entries of vectorSizes at the positions in innerDimPos by
1618  // the corresponding innerTiles values (see the worked example below).
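  // A worked example (hypothetical shapes, for illustration only):
  //   sourceShape = (8, 8, 32, 16), dest rank N = 2, innerDimPos = [0, 1],
  //   innerTiles = [32, 16], no outer_dims_perm:
  //     1. vectorSizes = [8, 8]
  //     3. vectorSizes = [8 * 32, 8 * 16] = [256, 128]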
1619  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1620  if (vectorSizes.empty()) {
1621  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1622  if (!outerDimsPerm.empty())
1623  applyPermutationToVector(vectorSizes, outerDimsPerm);
1624  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1625  vectorSizes[pos] *= innerTiles[i];
1626 
1627  useInBoundsInsteadOfMasking = true;
1628  }
1629 
1630  // readVectorSizes is the shape of the vector used for the (masked) read.
1631  // It is set as follows: let the vectorSizes (VS) array have size 'N', the
1632  // sourceShape (SS) have size 'M' with M >= N, and the innerTiles (IT)
1633  // attribute have size M-N.
1634  // Thus:
1635  // - initially: readVectorSizes = vectorSizes
1636  // - divide the entries of readVectorSizes at the positions in innerDimPos
1637  // by the corresponding innerTiles value.
1638  // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1639  // - append the remaining dimensions of SS.
1640  // E.g., let unpackTensorType.getShape() = <8x8x32x16>,
1641  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
1642  // 128] and outer_dims_perm = [1, 0]; then the read shape is:
1643  // ReadVectorSizes(initial): [512, 128]
1644  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1645  // = [16, 8]
1646  // After applying outer_dims_perm: [8, 16]
1647  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1648 
1649  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1650 
1651  for (auto [index, size] : enumerate(innerTiles)) {
1652  readVectorSizes[innerDimPos[index]] =
1653  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1654  }
1655  if (!outerDimsPerm.empty()) {
1656  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1657  }
1658  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1659  sourceShape.end());
1660 
1661  ReifiedRankedShapedTypeDims reifiedRetShapes;
1662  LogicalResult status =
1663  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1664  .reifyResultShapes(rewriter, reifiedRetShapes);
1665  if (status.failed()) {
1666  LDBG("Unable to reify result shapes of " << unpackOp);
1667  return failure();
1668  }
1669  Location loc = unpackOp->getLoc();
1670 
1671  auto padValue = rewriter.create<arith::ConstantOp>(
1672  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1673 
1674  // Read result, mask if necessary. If transferReadOp shape is not equal
1675  // to shape of source, then a mask is necessary.
1676  Value readResult = vector::createReadOrMaskedRead(
1677  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1678  /*useInBoundsInsteadOfMasking=*/false);
1679 
1680  PackingMetadata packMetadata;
1681  SmallVector<int64_t> lastDimToInsertPosPerm =
1682  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1683  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1684  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1685  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1686  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1687  RankedTensorType stripMineTensorType =
1688  RankedTensorType::get(stripMineShape, stripMineElemType);
1689  // Transpose the appropriate rows to match output.
1690  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1691  loc, readResult, lastDimToInsertPosPerm);
1692 
1693  // Collapse the vector to the size required by result.
1694  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1695  stripMineTensorType, packMetadata.reassociations);
1696  mlir::VectorType vecCollapsedType =
1697  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1698  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1699  loc, vecCollapsedType, transposeOp->getResult(0));
1700 
1701  // writeVectorSizes must match the shape_cast result shape for dynamic
1702  // sizes; otherwise the verifier complains that the mask size is invalid.
1703  SmallVector<int64_t> writeVectorSizes(
1704  unpackOp.getDestType().hasStaticShape()
1705  ? vectorSizes
1706  : shapeCastOp.getResultVectorType().getShape());
1707  Operation *write = createWriteOrMaskedWrite(
1708  rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1709  writeVectorSizes, useInBoundsInsteadOfMasking);
1710  newResults.push_back(write->getResult(0));
1711  return success();
1712 }
1713 
1714 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1715 /// and (3) all-zero lowPad to
1716 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
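///
/// For example (an illustrative sketch with made-up value names): padding a
/// tensor<6x5xf32> to tensor<8x8xf32> with `inputVectorSizes = [8, 8]` becomes
/// roughly:
///
///   %mask = vector.create_mask %c6, %c5 : vector<8x8xi1>
///   %read = vector.mask %mask {
///       vector.transfer_read %src[%c0, %c0], %pad_value
///           {in_bounds = [true, true]} : tensor<6x5xf32>, vector<8x8xf32>
///   } : vector<8x8xi1> -> vector<8x8xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>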
1717 static LogicalResult
1718 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1719  ArrayRef<int64_t> inputVectorSizes,
1720  SmallVectorImpl<Value> &newResults) {
1721  auto padValue = padOp.getConstantPaddingValue();
1722  Location loc = padOp.getLoc();
1723 
1724  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1725  OpBuilder::InsertionGuard g(rewriter);
1726  rewriter.setInsertionPoint(padOp);
1727 
1728  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1729  LogicalResult status =
1730  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1731  .reifyResultShapes(rewriter, reifiedReturnShapes);
1732  (void)status; // prevent unused variable warning on non-assert builds
1733  assert(succeeded(status) && "failed to reify result shapes");
1734  auto maskedRead = vector::createReadOrMaskedRead(
1735  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1736  /*useInBoundsInsteadOfMasking=*/false);
1737  Operation *write = createWriteOrMaskedWrite(
1738  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1739  /*useInBoundsInsteadOfMasking=*/false);
1740  newResults.push_back(write->getResult(0));
1741  return success();
1742 }
1743 
1744 // TODO: probably need some extra checks for reduction followed by consumer
1745 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1746 static LogicalResult reductionPreconditions(LinalgOp op) {
1747  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1748  LDBG("reduction precondition failed: no reduction iterator\n");
1749  return failure();
1750  }
1751  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1752  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1753  if (indexingMap.isPermutation())
1754  continue;
1755 
1756  Operation *reduceOp = matchLinalgReduction(&opOperand);
1757  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1758  LDBG("reduction precondition failed: reduction detection failed\n");
1759  return failure();
1760  }
1761  }
1762  return success();
1763 }
1764 
1765 static LogicalResult
1766 vectorizeDynamicConvOpPrecondition(LinalgOp conv,
1767  bool flatten1DDepthwiseConv) {
1768  if (flatten1DDepthwiseConv) {
1769  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1770  "supported\n");
1771  return failure();
1772  }
1773 
1774  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1775  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1776  return failure();
1777  }
1778 
1779  // Support dynamic shapes in 1D depthwise convolution, but only in the
1780  // _channel_ dimension.
1781  Value lhs = conv.getDpsInputOperand(0)->get();
1782  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1783  auto shapeWithoutCh = lhsShape.drop_back(1);
1784  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1785  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1786  "channel dim can be dynamic\n");
1787  return failure();
1788  }
1789 
1790  return success();
1791 }
1792 
1793 static LogicalResult
1794 vectorizeDynamicLinalgOpPrecondition(LinalgOp op,
1795  bool flatten1DDepthwiseConv) {
1796  if (isa<ConvolutionOpInterface>(op.getOperation()))
1797  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1798 
1799  if (hasReductionIterator(op))
1800  return reductionPreconditions(op);
1801 
1802  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1803  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1804  if (!isElementwise(op) &&
1805  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1806  op.getOperation()))
1807  return failure();
1808 
1809  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1810  return success();
1811 }
1812 
1813 /// Need to check if the inner-tiles are static/constant.
1814 static LogicalResult
1815 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1816  ArrayRef<int64_t> inputVectorSizes) {
1817 
1818  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1819  return !getConstantIntValue(res).has_value();
1820  })) {
1821  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1822  return failure();
1823  }
1824  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1825  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1826  unpackOp.getDestType().hasStaticShape() &&
1827  unpackOp.getSourceType().hasStaticShape();
1828  if (!satisfyEmptyCond &&
1829  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1830  return failure();
1831 
1832  return success();
1833 }
1834 
1835 static LogicalResult vectorizeLinalgOpPrecondition(
1836  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1837  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1838  // Tensors with a dimension of size 0 cannot be vectorized.
1839  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1840  return failure();
1841  // Check API contract for input vector sizes.
1842  if (!inputVectorSizes.empty() &&
1843  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1844  inputVectorSizes)))
1845  return failure();
1846 
1847  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1848  linalgOp, flatten1DDepthwiseConv))) {
1849  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1850  return failure();
1851  }
1852 
1853  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1854 
1855  // Register CustomVectorizationPrecondition for extractOp.
1856  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1857 
1858  // All types in the body should be a supported element type for VectorType.
1859  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1860  // Check if any custom hook can vectorize the inner op.
1861  if (llvm::any_of(
1862  customPreconditions,
1863  [&](const CustomVectorizationPrecondition &customPrecondition) {
1864  return succeeded(
1865  customPrecondition(&innerOp, vectorizeNDExtract));
1866  })) {
1867  continue;
1868  }
1869  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1870  return !VectorType::isValidElementType(type);
1871  })) {
1872  return failure();
1873  }
1874  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1875  return !VectorType::isValidElementType(type);
1876  })) {
1877  return failure();
1878  }
1879  }
1880  if (isElementwise(linalgOp))
1881  return success();
1882 
1883  // TODO: isaConvolutionOpInterface that can also infer from generic
1884  // features. But we will still need stride/dilation attributes that will be
1885  // annoying to reverse-engineer...
1886  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1887  return success();
1888  // TODO: the common vector shape is equal to the static loop sizes only when
1889  // all indexing maps are projected permutations. For convs and stencils the
1890  // logic will need to evolve.
1891  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1892  LDBG("precondition failed: not projected permutations\n");
1893  return failure();
1894  }
1895  if (failed(reductionPreconditions(linalgOp))) {
1896  LDBG("precondition failed: reduction preconditions\n");
1897  return failure();
1898  }
1899  return success();
1900 }
1901 
1902 static LogicalResult
1903 vectorizePackOpPrecondition(tensor::PackOp packOp,
1904  ArrayRef<int64_t> inputVectorSizes) {
1905  auto padValue = packOp.getPaddingValue();
1906  Attribute cstAttr;
1907  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1908  LDBG("pad value is not constant: " << packOp << "\n");
1909  return failure();
1910  }
1911  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1912  bool satisfyEmptyCond = true;
1913  if (inputVectorSizes.empty()) {
1914  if (!packOp.getDestType().hasStaticShape() ||
1915  !packOp.getSourceType().hasStaticShape())
1916  satisfyEmptyCond = false;
1917  }
1918 
1919  if (!satisfyEmptyCond &&
1920  failed(vector::isValidMaskedInputVector(
1921  resultTensorShape.take_front(packOp.getSourceRank()),
1922  inputVectorSizes)))
1923  return failure();
1924 
1925  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1926  return !getConstantIntValue(v).has_value();
1927  })) {
1928  LDBG("inner_tiles must be constant: " << packOp << "\n");
1929  return failure();
1930  }
1931 
1932  return success();
1933 }
1934 
1935 static LogicalResult
1936 vectorizePadOpPrecondition(tensor::PadOp padOp,
1937  ArrayRef<int64_t> inputVectorSizes) {
1938  auto padValue = padOp.getConstantPaddingValue();
1939  if (!padValue) {
1940  LDBG("pad value is not constant: " << padOp << "\n");
1941  return failure();
1942  }
1943 
1944  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1945  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1946  inputVectorSizes)))
1947  return failure();
1948 
1949  if (llvm::any_of(padOp.getLow(), [](Value v) {
1950  std::optional<int64_t> res = getConstantIntValue(v);
1951  return !res.has_value() || res.value() != 0;
1952  })) {
1953  LDBG("low pad must all be zero: " << padOp << "\n");
1954  return failure();
1955  }
1956 
1957  return success();
1958 }
1959 
1960 /// Preconditions for scalable vectors. This is quite restrictive - it models
1961 /// the fact that in practice we would only make selected dimensions scalable.
1962 static LogicalResult
1963 vectorizeScalableVectorPrecondition(Operation *op,
1964  ArrayRef<int64_t> inputVectorSizes,
1965  ArrayRef<bool> inputScalableVecDims) {
1966  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1967  "Number of input vector sizes and scalable dims doesn't match");
1968 
1969  size_t numOfScalableDims =
1970  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
1971 
1972  if (numOfScalableDims == 0)
1973  return success();
1974 
1975  auto linalgOp = dyn_cast<LinalgOp>(op);
1976 
1977  // Cond 1: There's been no need for scalable vectorization of
1978  // non-linalg Ops so far
1979  if (!linalgOp)
1980  return failure();
1981 
1982  // Cond 2: There's been no need for more than 2 scalable dims so far
1983  if (numOfScalableDims > 2)
1984  return failure();
1985 
1986  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
1987  // it matches one of the supported cases:
1988  // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
1989  // 2. exactly 2 dims are scalable and those are the _last two adjacent_
1990  // parallel dims
1991  // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
1992  // The 2nd restriction above means that only Matmul-like Ops are supported
1993  // when 2 dims are scalable, e.g. :
1994  // * iterators = [parallel, parallel, reduction]
1995  // * scalable flags = [true, true, false]
1996 
1997  // Find the trailing (innermost) scalable flag.
1998  bool seenParallel = false;
1999  auto iterators = linalgOp.getIteratorTypesArray();
2000  SmallVector<bool> scalableFlags(inputScalableVecDims);
2001  while (!scalableFlags.back()) {
2002  seenParallel |= (iterators.back() == utils::IteratorType::parallel);
2003 
2004  iterators.pop_back();
2005  scalableFlags.pop_back();
2006  }
2007 
2008  switch (iterators.back()) {
2009  case utils::IteratorType::reduction: {
2010  // Check 3. above is met.
2011  if (iterators.size() != inputVectorSizes.size()) {
2012  LDBG("Non-trailing reduction dim requested for scalable "
2013  "vectorization\n");
2014  return failure();
2015  }
2016  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2017  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2018  "is not supported\n");
2019  return failure();
2020  }
2021  break;
2022  }
2023  case utils::IteratorType::parallel: {
2024  // Check 1. and 2. above are met.
2025  if (seenParallel) {
2026  LDBG("Inner parallel dim not requested for scalable "
2027  "vectorization\n");
2028  return failure();
2029  }
2030  break;
2031  }
2032  }
2033 
2034  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2035  // supported, for which we expect the following config:
2036  // * iterators = [parallel, parallel, reduction]
2037  // * scalable flags = [true, true, false]
2038  if (numOfScalableDims == 2) {
2039  // Disallow below case which breaks 3. above:
2040  // * iterators = [..., parallel, reduction]
2041  // * scalable flags = [..., true, true]
2042  if (iterators.back() == utils::IteratorType::reduction) {
2043  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2044  "vectorization\n");
2045  return failure();
2046  }
2047  scalableFlags.pop_back();
2048  iterators.pop_back();
2049 
2050  if (!scalableFlags.back() ||
2051  (iterators.back() != utils::IteratorType::parallel))
2052  return failure();
2053  }
2054 
2055  // Cond 4: Only the following ops are supported in the
2056  // presence of scalable vectors
2057  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2058  isa<linalg::MatmulTransposeAOp>(op) ||
2059  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2060  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2061 }
2062 
2063 static LogicalResult vectorizeOpPrecondition(
2064  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2065  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2066  bool flatten1DDepthwiseConv) {
2067  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2068  inputScalableVecDims)))
2069  return failure();
2070 
2071  return TypeSwitch<Operation *, LogicalResult>(op)
2072  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2073  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2074  vectorizeNDExtract,
2075  flatten1DDepthwiseConv);
2076  })
2077  .Case<tensor::PadOp>([&](auto padOp) {
2078  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2079  })
2080  .Case<tensor::PackOp>([&](auto packOp) {
2081  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2082  })
2083  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2084  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2085  })
2086  .Default([](auto) { return failure(); });
2087 }
2088 
2089 /// Converts affine.apply Ops to arithmetic operations.
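/// For example (illustrative, single-result map):
///   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
/// is expanded to:
///   %0 = arith.addi %i, %j : index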
2090 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2091  OpBuilder::InsertionGuard g(rewriter);
2092  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2093 
2094  for (auto op : make_early_inc_range(toReplace)) {
2095  rewriter.setInsertionPoint(op);
2096  auto expanded = affine::expandAffineExpr(
2097  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2098  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2099  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2100  rewriter.replaceOp(op, expanded);
2101  }
2102 }
2103 
2104 /// Emit a suitable vector form for an operation. If provided,
2105 /// `inputVectorSizes` are used to vectorize this operation.
2106 /// `inputVectorSizes` must match the rank of the iteration space of the
2107 /// operation and the input vector sizes must be greater than or equal to
2108 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2109 /// also allows the vectorization of operations with dynamic shapes.
2110 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2111  ArrayRef<int64_t> inputVectorSizes,
2112  ArrayRef<bool> inputScalableVecDims,
2113  bool vectorizeNDExtract,
2114  bool flatten1DDepthwiseConv) {
2115  LDBG("Attempting to vectorize:\n" << *op << "\n");
2116  LDBG("Input vector sizes: ");
2117  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2118  LLVM_DEBUG(llvm::dbgs() << "\n");
2119  LDBG("Input scalable vector dims: ");
2120  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2121  LLVM_DEBUG(llvm::dbgs() << "\n");
2122 
2123  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2124  vectorizeNDExtract,
2125  flatten1DDepthwiseConv))) {
2126  LDBG("Vectorization pre-conditions failed\n");
2127  return failure();
2128  }
2129 
2130  // Initialize vectorization state.
2131  VectorizationState state(rewriter);
2132  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2133  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2134  inputScalableVecDims))) {
2135  LDBG("Vectorization state couldn't be initialized\n");
2136  return failure();
2137  }
2138  }
2139 
2140  SmallVector<Value> results;
2141  auto vectorizeResult =
2142  TypeSwitch<Operation *, LogicalResult>(op)
2143  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2144  // TODO: isaConvolutionOpInterface that can also infer from
2145  // generic features. Will require stride/dilation attributes
2146  // inference.
2147  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2148  FailureOr<Operation *> convOr = vectorizeConvolution(
2149  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2150  flatten1DDepthwiseConv);
2151  if (succeeded(convOr)) {
2152  llvm::append_range(results, (*convOr)->getResults());
2153  return success();
2154  }
2155 
2156  LDBG("Unsupported convolution can't be vectorized.\n");
2157  return failure();
2158  }
2159 
2160  LDBG("Vectorize generic by broadcasting to the canonical vector "
2161  "shape\n");
2162 
2163  // Pre-process before proceeding.
2164  convertAffineApply(rewriter, linalgOp);
2165 
2166  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2167  // to 'OpBuilder' when it is passed over to some methods like
2168  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2169  // erase an op within these methods, the actual rewriter won't be
2170  // notified and we will end up with read-after-free issues!
2171  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2172  })
2173  .Case<tensor::PadOp>([&](auto padOp) {
2174  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2175  results);
2176  })
2177  .Case<tensor::PackOp>([&](auto packOp) {
2178  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2179  results);
2180  })
2181  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2182  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2183  inputVectorSizes, results);
2184  })
2185  .Default([](auto) { return failure(); });
2186 
2187  if (failed(vectorizeResult)) {
2188  LDBG("Vectorization failed\n");
2189  return failure();
2190  }
2191 
2192  if (!results.empty())
2193  rewriter.replaceOp(op, results);
2194  else
2195  rewriter.eraseOp(op);
2196 
2197  return success();
2198 }
2199 
2200 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2201  memref::CopyOp copyOp) {
2202  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2203  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2204  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2205  return failure();
2206 
2207  auto srcElementType = getElementTypeOrSelf(srcType);
2208  auto dstElementType = getElementTypeOrSelf(dstType);
2209  if (!VectorType::isValidElementType(srcElementType) ||
2210  !VectorType::isValidElementType(dstElementType))
2211  return failure();
2212 
2213  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2214  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2215 
2216  Location loc = copyOp->getLoc();
2217  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2218  SmallVector<Value> indices(srcType.getRank(), zero);
2219 
2220  Value readValue = rewriter.create<vector::TransferReadOp>(
2221  loc, readType, copyOp.getSource(), indices,
2222  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2223  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2224  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2225  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2226  }
2227  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2228  loc, readValue, copyOp.getTarget(), indices,
2229  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2230  rewriter.replaceOp(copyOp, writeValue->getResults());
2231  return success();
2232 }
2233 
2234 //----------------------------------------------------------------------------//
2235 // Misc. vectorization patterns.
2236 //----------------------------------------------------------------------------//
2237 
2238 /// Helper function that retrieves the value of an IntegerAttr.
2239 static int64_t getIntFromAttr(Attribute attr) {
2240  return cast<IntegerAttr>(attr).getInt();
2241 }
2242 
2243 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2244 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2245 /// not supported.
2246 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2247  ArrayRef<OpFoldResult> ofrs) {
2248  SmallVector<Value> result;
2249  for (auto o : ofrs) {
2250  if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2251  result.push_back(val);
2252  } else {
2253  result.push_back(rewriter.create<arith::ConstantIndexOp>(
2254  loc, getIntFromAttr(o.template get<Attribute>())));
2255  }
2256  }
2257  return result;
2258 }
2259 
2260 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2261 /// InsertSliceOp. For now, only constant padding values are supported.
2262 /// If there is enough static type information, TransferReadOps and
2263 /// TransferWriteOps may be generated instead of InsertSliceOps.
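///
/// For example (an illustrative sketch with made-up value names), padding a
/// tensor<6x5xf32> by [0, 0] low and [2, 3] high with a constant %cst becomes:
///
///   %init = tensor.empty() : tensor<8x8xf32>
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>)
///       -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill[0, 0] [6, 5] [1, 1]
///       : tensor<6x5xf32> into tensor<8x8xf32>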
2264 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2265  GenericPadOpVectorizationPattern(MLIRContext *context,
2266  PatternBenefit benefit = 1)
2267  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2268  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2269 /// each dimension size is statically known in the source type or the result
2270  /// type (or both).
2271  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2272  tensor::PadOp padOp, Value dest) {
2273  auto sourceType = padOp.getSourceType();
2274  auto resultType = padOp.getResultType();
2275  if (!VectorType::isValidElementType(sourceType.getElementType()))
2276  return failure();
2277 
2278  // Copy cannot be vectorized if pad value is non-constant and source shape
2279  // is dynamic. In case of a dynamic source shape, padding must be appended
2280  // by TransferReadOp, but TransferReadOp supports only constant padding.
2281  auto padValue = padOp.getConstantPaddingValue();
2282  if (!padValue) {
2283  if (!sourceType.hasStaticShape())
2284  return failure();
2285  // Create dummy padding value.
2286  auto elemType = sourceType.getElementType();
2287  padValue = rewriter.create<arith::ConstantOp>(
2288  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2289  }
2290 
2291  SmallVector<int64_t> vecShape;
2292  SmallVector<bool> readInBounds;
2293  SmallVector<bool> writeInBounds;
2294  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2295  if (!sourceType.isDynamicDim(i)) {
2296  vecShape.push_back(sourceType.getDimSize(i));
2297  // Source shape is statically known: neither read nor write is
2298  // out-of-bounds.
2299  readInBounds.push_back(true);
2300  writeInBounds.push_back(true);
2301  } else if (!resultType.isDynamicDim(i)) {
2302  // Source shape is not statically known, but result shape is.
2303  // Vectorize with size of result shape. This may be larger than the
2304  // source size.
2305  vecShape.push_back(resultType.getDimSize(i));
2306  // Read may be out-of-bounds because the result size could be larger
2307  // than the source size.
2308  readInBounds.push_back(false);
2309  // Write is out-of-bounds if low padding > 0.
2310  writeInBounds.push_back(
2311  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2312  static_cast<int64_t>(0));
2313  } else {
2314  // Neither source nor result dim of padOp is static. Cannot vectorize
2315  // the copy.
2316  return failure();
2317  }
2318  }
2319  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2320 
2321  // Generate TransferReadOp.
2322  SmallVector<Value> readIndices(
2323  vecType.getRank(),
2324  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2325  auto read = rewriter.create<vector::TransferReadOp>(
2326  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2327  ArrayRef<bool>{readInBounds});
2328 
2329  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2330  // entire tensor, write directly to the FillOp's operand.
2331  if (llvm::equal(vecShape, resultType.getShape()) &&
2332  llvm::all_of(writeInBounds, [](bool b) { return b; }))
2333  if (auto fill = dest.getDefiningOp<FillOp>())
2334  dest = fill.output();
2335 
2336  // Generate TransferWriteOp.
2337  auto writeIndices =
2338  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2339  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2340  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2341 
2342  return success();
2343  }
2344 };
2345 
2346 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2347 /// given operation type OpTy.
2348 template <typename OpTy>
2349 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2350  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2351 
2352  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2353  PatternRewriter &rewriter) const final {
2354  bool changed = false;
2355  // Insert users in vector, because some users may be replaced/removed.
2356  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2357  if (auto op = dyn_cast<OpTy>(user))
2358  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2359  return success(changed);
2360  }
2361 
2362 protected:
2363  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2364  tensor::PadOp padOp, OpTy op) const = 0;
2365 };
2366 
2367 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2368 /// ```
2369 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2370 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2371 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2372 /// ```
2373 /// is rewritten to:
2374 /// ```
2375 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2376 /// {in_bounds = [true, true]}
2377 /// : tensor<?x?xf32>, vector<17x5xf32>
2378 /// ```
2379 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2380 /// sure that the original padding value %cst was never used.
2381 ///
2382 /// This rewrite is possible if:
2383 /// - `xferOp` has no out-of-bounds dims or mask.
2384 /// - Low padding is static 0.
2385 /// - Single, scalar padding value.
2386 struct PadOpVectorizationWithTransferReadPattern
2387  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2388  using VectorizePadOpUserPattern<
2389  vector::TransferReadOp>::VectorizePadOpUserPattern;
2390 
2391  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2392  vector::TransferReadOp xferOp) const override {
2393  // Low padding must be static 0.
2394  if (!padOp.hasZeroLowPad())
2395  return failure();
2396  // Pad value must be a constant.
2397  auto padValue = padOp.getConstantPaddingValue();
2398  if (!padValue)
2399  return failure();
2400  // Padding value of existing `xferOp` is unused.
2401  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2402  return failure();
2403 
2404  rewriter.modifyOpInPlace(xferOp, [&]() {
2405  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2406  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2407  rewriter.getBoolArrayAttr(inBounds));
2408  xferOp.getSourceMutable().assign(padOp.getSource());
2409  xferOp.getPaddingMutable().assign(padValue);
2410  });
2411 
2412  return success();
2413  }
2414 };
2415 
2416 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2417 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2418 /// value, where the same amount of padding is immediately removed again after
2419 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2420 /// tensor value and apply out-of-bounds masking. E.g.:
2421 /// ```
2422 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2423 /// : tensor<...> to tensor<?x?xf32>
2424 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2425 /// %2 = vector.transfer_write %vec, %1[...]
2426 /// : vector<17x5xf32>, tensor<17x5xf32>
2427 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2428 /// : tensor<17x5xf32> to tensor<?x?xf32>
2429 /// ```
2430 /// is rewritten to:
2431 /// ```
2432 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2433 /// : tensor<...> to tensor<?x?xf32>
2434 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2435 /// tensor<?x?xf32>
2436 /// ```
2437 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2438 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2439 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2440 /// from %r's old dimensions.
2441 ///
2442 /// This rewrite is possible if:
2443 /// - Low padding is static 0.
2444 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2445 /// ExtractSliceOp trims the same amount of padding that was added
2446 /// beforehand.
2447 /// - Single, scalar padding value.
2448 struct PadOpVectorizationWithTransferWritePattern
2449  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2450  using VectorizePadOpUserPattern<
2451  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2452 
2453  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2454  vector::TransferWriteOp xferOp) const override {
2455  // TODO: support 0-d corner case.
2456  if (xferOp.getTransferRank() == 0)
2457  return failure();
2458 
2459  // Low padding must be static 0.
2460  if (!padOp.hasZeroLowPad())
2461  return failure();
2462  // Pad value must be a constant.
2463  auto padValue = padOp.getConstantPaddingValue();
2464  if (!padValue)
2465  return failure();
2466  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2467  if (!xferOp->hasOneUse())
2468  return failure();
2469  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2470  if (!trimPadding)
2471  return failure();
2472  // Only static zero offsets supported when trimming padding.
2473  if (!trimPadding.hasZeroOffset())
2474  return failure();
2475  // trimPadding must remove the amount of padding that was added earlier.
2476  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2477  return failure();
2478 
2479  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2480  rewriter.setInsertionPoint(xferOp);
2481 
2482  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2483  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2484  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2485  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2486  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2487  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2488 
2489  return success();
2490  }
2491 
2492  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2493  /// i.e., same dimensions.
2494  ///
2495  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2496  /// dimensions, this function tries to infer the (static) tensor size by
2497  /// looking at the defining op and utilizing op-specific knowledge.
2498  ///
2499  /// This is a conservative analysis. In case equal tensor sizes cannot be
2500  /// proven statically, this analysis returns `false` even though the tensor
2501  /// sizes may turn out to be equal at runtime.
2502  bool hasSameTensorSize(Value beforePadding,
2503  tensor::ExtractSliceOp afterTrimming) const {
2504  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2505  // result and CastOp operand.
2506  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2507  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2508  return true;
2509 
2510  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2511  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2512  // Only RankedTensorType supported.
2513  if (!t1 || !t2)
2514  return false;
2515  // Rank of both values must be the same.
2516  if (t1.getRank() != t2.getRank())
2517  return false;
2518 
2519  // All static dimensions must be the same. Mixed cases (e.g., dimension
2520  // static in `t1` but dynamic in `t2`) are not supported.
2521  for (unsigned i = 0; i < t1.getRank(); ++i) {
2522  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2523  return false;
2524  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2525  return false;
2526  }
2527 
2528  // Nothing more to check if all dimensions are static.
2529  if (t1.getNumDynamicDims() == 0)
2530  return true;
2531 
2532  // All dynamic sizes must be the same. The only supported case at the
2533  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2534  // thereof).
2535 
2536  // Apart from CastOp, only ExtractSliceOp is supported.
2537  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2538  if (!beforeSlice)
2539  return false;
2540 
2541  assert(static_cast<size_t>(t1.getRank()) ==
2542  beforeSlice.getMixedSizes().size());
2543  assert(static_cast<size_t>(t2.getRank()) ==
2544  afterTrimming.getMixedSizes().size());
2545 
2546  for (unsigned i = 0; i < t1.getRank(); ++i) {
2547  // Skip static dimensions.
2548  if (!t1.isDynamicDim(i))
2549  continue;
2550  auto size1 = beforeSlice.getMixedSizes()[i];
2551  auto size2 = afterTrimming.getMixedSizes()[i];
2552 
2553  // Case 1: Same value or same constant int.
2554  if (isEqualConstantIntOrValue(size1, size2))
2555  continue;
2556 
2557  // Other cases: Take a deeper look at defining ops of values.
2558  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2559  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2560  if (!v1 || !v2)
2561  return false;
2562 
2563  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2564  // CSE is run.)
2565  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2566  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2567  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2568  minOp1.getOperands() == minOp2.getOperands())
2569  continue;
2570 
2571  // Add additional cases as needed.
2572  }
2573 
2574  // All tests passed.
2575  return true;
2576  }
2577 };
2578 
2579 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2580 /// ```
2581 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2582 /// %r = tensor.insert_slice %0
2583 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2584 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2585 /// ```
2586 /// is rewritten to:
2587 /// ```
2588 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2589 /// : tensor<?x?xf32>, vector<17x5xf32>
2590 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2591 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2592 /// ```
2593 ///
2594 /// This rewrite is possible if:
2595 /// - Low padding is static 0.
2596 /// - `padOp` result shape is static.
2597 /// - The entire padded tensor is inserted.
2598 /// (Implies that sizes of `insertOp` are all static.)
2599 /// - Only unit strides in `insertOp`.
2600 /// - Single, scalar padding value.
2601 /// - `padOp` result not used as destination.
2602 struct PadOpVectorizationWithInsertSlicePattern
2603  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2604  using VectorizePadOpUserPattern<
2605  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2606 
2607  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2608  tensor::InsertSliceOp insertOp) const override {
2609  // Low padding must be static 0.
2610  if (!padOp.hasZeroLowPad())
2611  return failure();
2612  // Only unit stride supported.
2613  if (!insertOp.hasUnitStride())
2614  return failure();
2615  // Pad value must be a constant.
2616  auto padValue = padOp.getConstantPaddingValue();
2617  if (!padValue)
2618  return failure();
2619  // Dynamic shapes not supported.
2620  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2621  return failure();
2622  // Pad result not used as destination.
2623  if (insertOp.getDest() == padOp.getResult())
2624  return failure();
2625 
2626  auto vecType = VectorType::get(padOp.getType().getShape(),
2627  padOp.getType().getElementType());
2628  unsigned vecRank = vecType.getRank();
2629  unsigned tensorRank = insertOp.getType().getRank();
2630 
2631  // Check if sizes match: Insert the entire tensor into most minor dims.
2632  // (No permutations allowed.)
2633  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2634  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2635  if (!llvm::all_of(
2636  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2637  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2638  }))
2639  return failure();
2640 
2641  // Insert the TransferReadOp and TransferWriteOp at the position of the
2642  // InsertSliceOp.
2643  rewriter.setInsertionPoint(insertOp);
2644 
2645  // Generate TransferReadOp: Read entire source tensor and add high
2646  // padding.
2647  SmallVector<Value> readIndices(
2648  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2649  auto read = rewriter.create<vector::TransferReadOp>(
2650  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2651 
2652  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2653  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2654  // source must fit into the destination at the specified offsets.
2655  auto writeIndices =
2656  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2657  SmallVector<bool> inBounds(vecRank, true);
2658  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2659  insertOp, read, insertOp.getDest(), writeIndices,
2660  ArrayRef<bool>{inBounds});
2661 
2662  return success();
2663  }
2664 };
2665 
2666 void mlir::linalg::populatePadOpVectorizationPatterns(
2667  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2668  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2669  baseBenefit);
2670  // Try these specialized patterns first before resorting to the generic one.
2671  patterns.add<PadOpVectorizationWithTransferReadPattern,
2672  PadOpVectorizationWithTransferWritePattern,
2673  PadOpVectorizationWithInsertSlicePattern>(
2674  patterns.getContext(), baseBenefit.getBenefit() + 1);
2675 }
2676 
2677 //----------------------------------------------------------------------------//
2678 // Forwarding patterns
2679 //----------------------------------------------------------------------------//
2680 
2681 /// Check whether there is any interleaved use of any `values` between
2682 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2683 /// is in a different block.
2684 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2685  ValueRange values) {
2686  if (firstOp->getBlock() != secondOp->getBlock() ||
2687  !firstOp->isBeforeInBlock(secondOp)) {
2688  LDBG("interleavedUses precondition failed, firstOp: "
2689  << *firstOp << ", second op: " << *secondOp << "\n");
2690  return true;
2691  }
2692  for (auto v : values) {
2693  for (auto &u : v.getUses()) {
2694  Operation *owner = u.getOwner();
2695  if (owner == firstOp || owner == secondOp)
2696  continue;
2697  // TODO: this is too conservative, use dominance info in the future.
2698  if (owner->getBlock() == firstOp->getBlock() &&
2699  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2700  continue;
2701  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2702  << ", second op: " << *secondOp << "\n");
2703  return true;
2704  }
2705  }
2706  return false;
2707 }
2708 
2709 /// Return the unique subview use of `v` if it is indeed unique, null
2710 /// otherwise.
2711 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2712  memref::SubViewOp subViewOp;
2713  for (auto &u : v.getUses()) {
2714  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2715  if (subViewOp)
2716  return memref::SubViewOp();
2717  subViewOp = newSubViewOp;
2718  }
2719  }
2720  return subViewOp;
2721 }
2722 
2723 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2724 /// when available.
2725 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2726  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2727 
2728  // TODO: support mask.
2729  if (xferOp.getMask())
2730  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2731 
2732  // Transfer into `view`.
2733  Value viewOrAlloc = xferOp.getSource();
2734  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2735  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2736  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2737 
2738  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2739  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2740  if (!subViewOp)
2741  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2742  Value subView = subViewOp.getResult();
2743 
2744  // Find the copy into `subView` without interleaved uses.
2745  memref::CopyOp copyOp;
2746  for (auto &u : subView.getUses()) {
2747  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2748  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2749  if (newCopyOp.getTarget() != subView)
2750  continue;
2751  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2752  continue;
2753  copyOp = newCopyOp;
2754  break;
2755  }
2756  }
2757  if (!copyOp)
2758  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2759 
2760  // Find the fill into `viewOrAlloc` without interleaved uses before the
2761  // copy.
2762  FillOp maybeFillOp;
2763  for (auto &u : viewOrAlloc.getUses()) {
2764  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2765  assert(isa<MemRefType>(newFillOp.output().getType()));
2766  if (newFillOp.output() != viewOrAlloc)
2767  continue;
2768  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2769  continue;
2770  maybeFillOp = newFillOp;
2771  break;
2772  }
2773  }
2774  // Ensure padding matches.
2775  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2776  return rewriter.notifyMatchFailure(xferOp,
2777  "padding value does not match fill");
2778 
2779  // `in` is the subview that memref.copy reads. Replace it.
2780  Value in = copyOp.getSource();
2781 
2782  // memref.copy + linalg.fill can be used to create a padded local buffer.
2783  // The `masked` attribute is only valid on this padded buffer.
2784  // When forwarding to vector.transfer_read, the attribute must be reset
2785  // conservatively.
2786  auto vectorType = xferOp.getVectorType();
2787  Value res = rewriter.create<vector::TransferReadOp>(
2788  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2789  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2790  rewriter.getBoolArrayAttr(
2791  SmallVector<bool>(vectorType.getRank(), false)));
2792 
2793  if (maybeFillOp)
2794  rewriter.eraseOp(maybeFillOp);
2795  rewriter.eraseOp(copyOp);
2796  rewriter.replaceOp(xferOp, res);
2797 
2798  return success();
2799 }
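// Editorial sketch of the rewrite performed above (illustrative only; types,
// sizes and value names are assumed placeholders):
//
//   %alloc = memref.alloc() : memref<8x8xf32>
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<8x8xf32>)
//   %sv = memref.subview %alloc[0, 0] [4, 4] [1, 1]
//   memref.copy %in, %sv
//   %v = vector.transfer_read %alloc[%c0, %c0], %pad
//
// becomes, after erasing the fill and the copy:
//
//   %v = vector.transfer_read %in[%c0, %c0], %pad
//
// with the in_bounds attribute conservatively reset to all-false.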
2800 
2801 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2802 /// when available.
2803 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2804  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2805  // TODO: support mask.
2806  if (xferOp.getMask())
2807  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2808 
2809  // Transfer into `viewOrAlloc`.
2810  Value viewOrAlloc = xferOp.getSource();
2811  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2812  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2813  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2814 
2815  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2816  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2817  if (!subViewOp)
2818  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2819  Value subView = subViewOp.getResult();
2820 
2821  // Find the copy from `subView` without interleaved uses.
2822  memref::CopyOp copyOp;
2823  for (auto &u : subViewOp.getResult().getUses()) {
2824  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2825  if (newCopyOp.getSource() != subView)
2826  continue;
2827  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2828  continue;
2829  copyOp = newCopyOp;
2830  break;
2831  }
2832  }
2833  if (!copyOp)
2834  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2835 
2836  // `out` is the copy target; the transfer_write is forwarded to write into it directly.
2837  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2838  Value out = copyOp.getTarget();
2839 
2840  // Forward vector.transfer into copy.
2841  // memref.copy + linalg.fill can be used to create a padded local buffer.
2842  // The `masked` attribute is only valid on this padded buffer.
2843  // When forwarding to vector.transfer_write, the attribute must be reset
2844  // conservatively.
2845  auto vector = xferOp.getVector();
2846  rewriter.create<vector::TransferWriteOp>(
2847  xferOp.getLoc(), vector, out, xferOp.getIndices(),
2848  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2849  rewriter.getBoolArrayAttr(
2850  SmallVector<bool>(vector.getType().getRank(), false)));
2851 
2852  rewriter.eraseOp(copyOp);
2853  rewriter.eraseOp(xferOp);
2854 
2855  return success();
2856 }
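// Editorial sketch of the rewrite performed above (illustrative only; names
// and sizes are assumed placeholders):
//
//   %sv = memref.subview %alloc[0, 0] [4, 4] [1, 1]
//   vector.transfer_write %v, %alloc[%c0, %c0]
//   memref.copy %sv, %out
//
// becomes, after erasing the copy and the original transfer:
//
//   vector.transfer_write %v, %out[%c0, %c0]
//
// again with the in_bounds attribute conservatively reset to all-false.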
2857 
2858 //===----------------------------------------------------------------------===//
2859 // Convolution vectorization patterns
2860 //===----------------------------------------------------------------------===//
2861 
2862 template <int N>
2863 static void bindShapeDims(ShapedType shapedType) {}
2864 
2865 template <int N, typename IntTy, typename... IntTy2>
2866 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2867  val = shapedType.getShape()[N];
2868  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2869 }
2870 
2871 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2872 template <typename... IntTy>
2873 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2874  bindShapeDims<0>(shapedType, vals...);
2875 }
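// Usage sketch (editorial): for a ShapedType with shape [4, 8, 16],
//
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c);  // n == 4, w == 8, c == 16
//
// binds each reference to the corresponding leading dimension of the shape.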
2876 
2877 namespace {
2878 bool isCastOfBlockArgument(Operation *op) {
2879  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2880  isa<BlockArgument>(op->getOperand(0));
2881 }
2882 
2883 bool isSupportedPoolKind(vector::CombiningKind kind) {
2884  switch (kind) {
2885  case vector::CombiningKind::ADD:
2886  case vector::CombiningKind::MAXNUMF:
2887  case vector::CombiningKind::MAXIMUMF:
2888  case vector::CombiningKind::MAXSI:
2889  case vector::CombiningKind::MAXUI:
2890  case vector::CombiningKind::MINNUMF:
2891  case vector::CombiningKind::MINIMUMF:
2892  case vector::CombiningKind::MINSI:
2893  case vector::CombiningKind::MINUI:
2894  return true;
2895  default:
2896  return false;
2897  }
2898 }
2899 
2900 /// Generate a vector implementation for either:
2901 /// ```
2902 /// Op def: ( w, kw )
2903 /// Iters: ({Par(), Red()})
2904 /// Layout: {{w + kw}, {kw}, {w}}
2905 /// ```
2906 /// kw is unrolled.
2907 ///
2908 /// or
2909 ///
2910 /// ```
2911 /// Op def: ( n, w, c, kw, f )
2912 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2913 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2914 /// ```
2915 /// kw is unrolled, w is unrolled iff dilationW > 1.
2916 ///
2917 /// or
2918 ///
2919 /// ```
2920 /// Op def: ( n, c, w, f, kw )
2921 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2922 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2923 /// ```
2924 /// kw is unrolled, w is unrolled iff dilationW > 1.
2925 ///
2926 /// or
2927 ///
2928 /// ```
2929 /// Op def: ( n, w, c, kw )
2930 /// Iters: ({Par(), Par(), Par(), Red()})
2931 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2932 /// ```
2933 /// kw is unrolled, w is unrolled iff dilationW > 1.
2934 struct Conv1DGenerator
2935  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2936  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2937  int dilationW)
2938  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2939  strideW(strideW), dilationW(dilationW) {
2940  // Determine whether `linalgOp` can be generated with this generator
2941  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2942  return;
2943  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2944  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2945  resShaped = linalgOp.getDpsInitOperand(0)->get();
2946  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2947  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2948  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2949  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2950  return;
2951  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2952  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2953  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2954  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2955  return;
2956 
2957  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2958  if (!reduceOp)
2959  return;
2960  redOp = reduceOp->getName().getIdentifier();
2961 
2962  if (!setOperKind(reduceOp))
2963  return;
2964  auto maybeKind = getCombinerOpKind(reduceOp);
2965  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2966  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2967  return;
2968  }
2969 
2970  auto rhsRank = rhsShapedType.getRank();
2971  switch (oper) {
2972  case Conv:
2973  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2974  return;
2975  break;
2976  case Pool:
2977  if (rhsRank != 1)
2978  return;
2979  break;
2980  }
2981  // The op is now known to be valid.
2982  valid = true;
2983  }
2984 
2985  /// Generate a vector implementation for:
2986  /// ```
2987  /// Op def: ( w, kw )
2988  /// Iters: ({Par(), Red()})
2989  /// Layout: {{w + kw}, {kw}, {w}}
2990  /// ```
2991  /// kw is always unrolled.
2992  ///
2993  /// or
2994  ///
2995  /// ```
2996  /// Op def: ( n, w, c, kw, f )
2997  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2998  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2999  /// ```
3000  /// kw is always unrolled.
3001  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3002  /// > 1.
3003  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3004  if (!valid)
3005  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3006 
3007  int64_t nSize, wSize, cSize, kwSize, fSize;
3008  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3009  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3010  switch (conv1DOpOrder) {
3011  case Conv1DOpOrder::W:
3012  // Initialize unused dimensions
3013  nSize = fSize = cSize = 0;
3014  // out{W}
3015  bindShapeDims(resShapedType, wSize);
3016  // kernel{kw}
3017  bindShapeDims(rhsShapedType, kwSize);
3018  lhsShape = {// iw = ow + kw - 1
3019  // (i.e. 16 convolved with 3 -> 14)
3020  (wSize + kwSize - 1)};
3021  rhsShape = {kwSize};
3022  resShape = {wSize};
3023  break;
3024  case Conv1DOpOrder::Nwc:
3025  // out{n, w, f}
3026  bindShapeDims(resShapedType, nSize, wSize, fSize);
3027  switch (oper) {
3028  case Conv:
3029  // kernel{kw, c, f}
3030  bindShapeDims(rhsShapedType, kwSize, cSize);
3031  break;
3032  case Pool:
3033  // kernel{kw}
3034  bindShapeDims(rhsShapedType, kwSize);
3035  cSize = fSize;
3036  break;
3037  }
3038  lhsShape = {nSize,
3039  // iw = ow * sw + kw * dw - 1
3040  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3041  // Perform the proper inclusive -> exclusive -> inclusive.
3042  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3043  1,
3044  cSize};
3045  switch (oper) {
3046  case Conv:
3047  rhsShape = {kwSize, cSize, fSize};
3048  break;
3049  case Pool:
3050  rhsShape = {kwSize};
3051  break;
3052  }
3053  resShape = {nSize, wSize, fSize};
3054  break;
3055  case Conv1DOpOrder::Ncw:
3056  // out{n, f, w}
3057  bindShapeDims(resShapedType, nSize, fSize, wSize);
3058  switch (oper) {
3059  case Conv:
3060  // kernel{f, c, kw}
3061  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3062  break;
3063  case Pool:
3064  // kernel{kw}
3065  bindShapeDims(rhsShapedType, kwSize);
3066  cSize = fSize;
3067  break;
3068  }
3069  lhsShape = {nSize, cSize,
3070  // iw = ow * sw + kw * dw - 1
3071  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3072  // Perform the proper inclusive -> exclusive -> inclusive.
3073  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3074  1};
3075  switch (oper) {
3076  case Conv:
3077  rhsShape = {fSize, cSize, kwSize};
3078  break;
3079  case Pool:
3080  rhsShape = {kwSize};
3081  break;
3082  }
3083  resShape = {nSize, fSize, wSize};
3084  break;
3085  }
3086 
3087  vector::TransferWriteOp write;
3088  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3089 
3090  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3091  // When strideW == 1, we can batch the contiguous loads and avoid
3092  // unrolling
3093  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3094 
3095  Type lhsEltType = lhsShapedType.getElementType();
3096  Type rhsEltType = rhsShapedType.getElementType();
3097  Type resEltType = resShapedType.getElementType();
3098  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3099  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3100  auto resType = VectorType::get(resShape, resEltType);
3101  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3102  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3103  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3104  SmallVector<Value> resPadding(resShape.size(), zero);
3105 
3106  // Read the whole lhs, rhs and res in one shot (with zero padding).
3107  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3108  lhsPadding);
3109  // This is needed only for Conv.
3110  Value rhs = nullptr;
3111  if (oper == Conv)
3112  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3113  rhsPadding);
3114  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3115  resPadding);
3116 
3117  // The base vectorization case for channeled convolution is input:
3118  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3119  // vectorization case, we pre-transpose the input, weight, and output.
3120  switch (conv1DOpOrder) {
3121  case Conv1DOpOrder::W:
3122  case Conv1DOpOrder::Nwc:
3123  // Base case, so no transposes necessary.
3124  break;
3125  case Conv1DOpOrder::Ncw: {
3126  // To match base vectorization case, we pre-transpose current case.
3127  // ncw -> nwc
3128  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3129  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3130  // fcw -> wcf
3131  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3132 
3133  // This is needed only for Conv.
3134  if (oper == Conv)
3135  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3136  // nfw -> nwf
3137  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3138  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3139  break;
3140  }
3141  }
3142 
3143  //===------------------------------------------------------------------===//
3144  // Begin vector-only rewrite part
3145  //===------------------------------------------------------------------===//
3146  // Unroll along kw and read slices of lhs and rhs.
3147  SmallVector<Value> lhsVals, rhsVals, resVals;
3148  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3149  kwSize, strideW, dilationW, wSizeStep,
3150  isSingleChanneled);
3151  // Do not do for pooling.
3152  if (oper == Conv)
3153  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3154  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3155  wSizeStep, isSingleChanneled);
3156 
3157  auto linearIndex = [&](int64_t kw, int64_t w) {
3158  return kw * (wSize / wSizeStep) + w;
3159  };
3160 
3161  // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3162  // perform an outer product for non-channeled convolution, or apply a simple
3163  // arith operation for pooling.
3164  for (int64_t kw = 0; kw < kwSize; ++kw) {
3165  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3166  switch (oper) {
3167  case Conv:
3168  if (isSingleChanneled) {
3169  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3170  lhsVals[linearIndex(kw, w)],
3171  rhsVals[kw], resVals[w]);
3172  } else {
3173  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3174  lhsVals[linearIndex(kw, w)],
3175  rhsVals[kw], resVals[w]);
3176  }
3177  break;
3178  case Pool:
3179  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3180  resVals[w]);
3181  break;
3182  }
3183  }
3184  }
3185 
3186  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3187  isSingleChanneled);
3188  //===------------------------------------------------------------------===//
3189  // End vector-only rewrite part
3190  //===------------------------------------------------------------------===//
3191 
3192  // The base vectorization case for channeled convolution produces output
3193  // {n,w,f}. To reuse the result from the base vectorization case, we
3194  // post-transpose the base case result.
3195  switch (conv1DOpOrder) {
3196  case Conv1DOpOrder::W:
3197  case Conv1DOpOrder::Nwc:
3198  // Base case, so no transposes necessary.
3199  break;
3200  case Conv1DOpOrder::Ncw: {
3201  // nwf -> nfw
3202  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3203  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3204  break;
3205  }
3206  }
3207 
3208  return rewriter
3209  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3210  .getOperation();
3211  }
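// Editorial sketch of the structure emitted by conv() for the channeled NWC
// case (op spellings are approximate, not a verbatim trace):
//   1. vector.transfer_read of the full lhs, rhs and res with zero padding.
//   2. For the NCW variant only, vector.transpose ops to reach the canonical
//      {n, w, c} / {kw, c, f} / {n, w, f} layouts.
//   3. Per-kw (and per-w when unrolled) vector.extract_strided_slice /
//      vector.extract ops producing the lhs, rhs and res slices.
//   4. One vector.contract, vector.outerproduct or pooling arith op per
//      (kw, w) pair, accumulating into the res slices.
//   5. vector.insert_strided_slice of the accumulated slices, an optional
//      transpose back, and a final vector.transfer_write.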
3212 
3213  // Take a value and widen it to have the same element type as `ty`.
3214  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3215  const Type srcElementType = getElementTypeOrSelf(val.getType());
3216  const Type dstElementType = getElementTypeOrSelf(ty);
3217  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3218  if (srcElementType == dstElementType)
3219  return val;
3220 
3221  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3222  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3223  const Type dstType =
3224  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3225 
3226  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3227  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3228  }
3229 
3230  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3231  srcWidth < dstWidth)
3232  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3233 
3234  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3235  srcWidth < dstWidth)
3236  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3237 
3238  assert(false && "unhandled promotion case");
3239  return nullptr;
3240  }
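// Behavior sketch (editorial) for a vector-typed `val` promoted to match an
// accumulator element type:
//   * i8  -> f32 : arith.sitofp
//   * f16 -> f32 : arith.extf
//   * i8  -> i32 : arith.extsi
// Matching element types are returned unchanged; any other combination hits
// the "unhandled promotion case" assert above.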
3241 
3242  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3243  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3244  Value lhs, Value rhs, Value res) {
3245  vector::IteratorType par = vector::IteratorType::parallel;
3246  vector::IteratorType red = vector::IteratorType::reduction;
3247  AffineExpr n, w, f, c;
3248  bindDims(ctx, n, w, f, c);
3249  lhs = promote(rewriter, loc, lhs, res.getType());
3250  rhs = promote(rewriter, loc, rhs, res.getType());
3251  return rewriter.create<vector::ContractionOp>(
3252  loc, lhs, rhs, res,
3253  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3254  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3255  }
3256 
3257  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3258  // convolution.
3259  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3260  Value lhs, Value rhs, Value res) {
3261  return rewriter.create<vector::OuterProductOp>(
3262  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3263  }
3264 
3265  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3266  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3267  Value res) {
3268  if (isPoolExt)
3269  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3270  return rewriter
3271  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3272  ->getResult(0);
3273  }
3274 
3275  /// Generate a vector implementation for:
3276  /// ```
3277  /// Op def: ( n, w, c, kw)
3278  /// Iters: ({Par(), Par(), Par(), Red()})
3279  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3280  /// ```
3281  /// kw is always unrolled.
3282  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3283  /// > 1.
3284  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3285  bool channelDimScalableFlag,
3286  bool flatten) {
3287  if (!valid)
3288  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3289 
3290  bool scalableChDim = false;
3291  bool useMasking = false;
3292  int64_t nSize, wSize, cSize, kwSize;
3293  // kernel{kw, c}
3294  bindShapeDims(rhsShapedType, kwSize, cSize);
3295  if (ShapedType::isDynamic(cSize)) {
3296  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3297  cSize = channelDimVecSize;
3298  // Scalable vectors are only used when both conditions are met:
3299  // 1. channel dim is dynamic
3300  // 2. channelDimScalableFlag is set
3301  scalableChDim = channelDimScalableFlag;
3302  useMasking = true;
3303  }
3304 
3305  assert(!(useMasking && flatten) &&
3306  "Unsupported flattened conv with dynamic shapes");
3307 
3308  // out{n, w, c}
3309  bindShapeDims(resShapedType, nSize, wSize);
3310 
3311  vector::TransferWriteOp write;
3312  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3313 
3314  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3315  // When strideW == 1, we can batch the contiguous loads and avoid
3316  // unrolling
3317  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3318 
3319  Type lhsEltType = lhsShapedType.getElementType();
3320  Type rhsEltType = rhsShapedType.getElementType();
3321  Type resEltType = resShapedType.getElementType();
3322  VectorType lhsType = VectorType::get(
3323  {nSize,
3324  // iw = ow * sw + kw * dw - 1
3325  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3326  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3327  cSize},
3328  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3329  VectorType rhsType =
3330  VectorType::get({kwSize, cSize}, rhsEltType,
3331  /*scalableDims=*/{false, scalableChDim});
3332  VectorType resType =
3333  VectorType::get({nSize, wSize, cSize}, resEltType,
3334  /*scalableDims=*/{false, false, scalableChDim});
3335 
3336  // Masks the input xfer Op along the channel dim, iff the corresponding
3337  // scalable flag is set.
3338  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3339  ArrayRef<bool> scalableDims,
3340  Operation *opToMask) {
3341  if (!useMasking)
3342  return opToMask;
3343  auto maskType =
3344  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3345 
3346  SmallVector<bool> inBounds(maskShape.size(), true);
3347  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3348  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3349  rewriter.getBoolArrayAttr(inBounds));
3350 
3351  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3352  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3353 
3354  Value maskOp =
3355  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3356 
3357  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3358  };
3359 
3360  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3361  // 0].
3362  Value lhs = rewriter.create<vector::TransferReadOp>(
3363  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3364  auto maybeMaskedLhs = maybeMaskXferOp(
3365  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3366 
3367  // Read rhs slice of size {kw, c} @ [0, 0].
3368  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3369  ValueRange{zero, zero});
3370  auto maybeMaskedRhs = maybeMaskXferOp(
3371  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3372 
3373  // Read res slice of size {n, w, c} @ [0, 0, 0].
3374  Value res = rewriter.create<vector::TransferReadOp>(
3375  loc, resType, resShaped, ValueRange{zero, zero, zero});
3376  auto maybeMaskedRes = maybeMaskXferOp(
3377  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3378 
3379  //===------------------------------------------------------------------===//
3380  // Begin vector-only rewrite part
3381  //===------------------------------------------------------------------===//
3382  // Unroll along kw and read slices of lhs and rhs.
3383  SmallVector<Value> lhsVals, rhsVals, resVals;
3384  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3385  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3386 
3387  // Extract lhs slice of size {n, wSizeStep, c}
3388  // @ [0, sw * w + dw * kw, 0].
3389  for (int64_t kw = 0; kw < kwSize; ++kw) {
3390  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3391  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3392  loc, maybeMaskedLhs->getResult(0),
3393  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3394  inOutSliceSizes, inOutStrides));
3395  }
3396  }
3397  // Extract rhs slice of size {c} @ [kw].
3398  for (int64_t kw = 0; kw < kwSize; ++kw) {
3399  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3400  loc, maybeMaskedRhs->getResult(0),
3401  /*offsets=*/ArrayRef<int64_t>{kw}));
3402  }
3403  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3404  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3405  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3406  loc, maybeMaskedRes->getResult(0),
3407  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3408  inOutStrides));
3409  }
3410 
3411  auto linearIndex = [&](int64_t kw, int64_t w) {
3412  return kw * (wSize / wSizeStep) + w;
3413  };
3414 
3415  // Note - the scalable flags are ignored as flattening combined with
3416  // scalable vectorization is not supported.
3417  auto inOutFlattenSliceSizes =
3418  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3419  auto lhsTypeAfterFlattening =
3420  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3421  auto resTypeAfterFlattening =
3422  VectorType::get(inOutFlattenSliceSizes, resEltType);
3423 
3424  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3425  for (int64_t kw = 0; kw < kwSize; ++kw) {
3426  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3427  Value lhsVal = lhsVals[linearIndex(kw, w)];
3428  Value resVal = resVals[w];
3429  if (flatten) {
3430  // Flatten the input and output vectors (collapse the channel
3431  // dimension)
3432  lhsVal = rewriter.create<vector::ShapeCastOp>(
3433  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3434  resVal = rewriter.create<vector::ShapeCastOp>(
3435  loc, resTypeAfterFlattening, resVals[w]);
3436  }
3437  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3438  rhsVals[kw], resVal, flatten);
3439  if (flatten) {
3440  // Un-flatten the output vector (restore the channel dimension)
3441  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3442  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3443  }
3444  }
3445  }
3446 
3447  // It's possible we failed to create the FMA.
3448  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3449  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3450  for (auto &collection :
3451  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3452  for (Value v : collection)
3453  rewriter.eraseOp(v.getDefiningOp());
3454  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3455  }
3456 
3457  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3458  // This does not depend on kw.
3459  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3460  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3461  loc, resVals[w], maybeMaskedRes->getResult(0),
3462  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3463  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3464  }
3465  //===------------------------------------------------------------------===//
3466  // End vector-only rewrite part
3467  //===------------------------------------------------------------------===//
3468 
3469  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3470  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3471  loc, maybeMaskedRes->getResult(0), resShaped,
3472  ValueRange{zero, zero, zero});
3473  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3474  resOut);
3475  }
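// Masking sketch (editorial): when the channel dimension is dynamic, each of
// the transfer ops above is wrapped by maybeMaskXferOp, roughly:
//
//   %mask = vector.create_mask %n, %iw, %ch : vector<1x6x[4]xi1>
//   %lhs = vector.mask %mask { vector.transfer_read %input[...] ... }
//
// where the mask extents are the runtime sizes of the underlying memref or
// tensor and the channel dimension is marked scalable. The shown sizes
// (n = 1, w = 4, kw = 3, stride/dilation = 1, so iw = 6, channel vector [4])
// are an assumed example.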
3476 
3477  /// Lower:
3478  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3479  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3480  /// to MulAcc.
3481  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3482  Value lhs, Value rhs, Value res,
3483  bool flatten) {
3484  auto rhsTy = cast<ShapedType>(rhs.getType());
3485  auto resTy = cast<ShapedType>(res.getType());
3486 
3487  // TODO(suderman): Change this to use a vector.ima intrinsic.
3488  lhs = promote(rewriter, loc, lhs, resTy);
3489 
3490  if (flatten) {
3491  // NOTE: The following logic won't work for scalable vectors. For this
3492  // reason, "flattening" is not supported when shapes are dynamic (this
3493  // should be captured by one of the pre-conditions).
3494 
3495  // There are two options for handling the filter:
3496  // * shape_cast(broadcast(filter))
3497  // * broadcast(shuffle(filter))
3498  // Opt for the option without shape_cast to simplify the codegen.
3499  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3500  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3501 
3502  SmallVector<int64_t, 16> indices;
3503  for (int i = 0; i < resSize / rhsSize; ++i) {
3504  for (int j = 0; j < rhsSize; ++j)
3505  indices.push_back(j);
3506  }
3507 
3508  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3509  }
3510  // Broadcast the filter to match the output vector
3511  rhs = rewriter.create<vector::BroadcastOp>(
3512  loc, resTy.clone(rhsTy.getElementType()), rhs);
3513 
3514  rhs = promote(rewriter, loc, rhs, resTy);
3515 
3516  if (!lhs || !rhs)
3517  return nullptr;
3518 
3519  if (isa<FloatType>(resTy.getElementType()))
3520  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3521 
3522  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3523  return rewriter.create<arith::AddIOp>(loc, mul, res);
3524  }
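// Worked example (editorial) of the flattened path above: with cSize == 4 and
// wSizeStep == 2, rhs is vector<4xT> and the flattened res slice is
// vector<Nx8xT>, so resSize / rhsSize == 2 and the shuffle indices are
// [0, 1, 2, 3, 0, 1, 2, 3]; the filter is repeated along w before being
// broadcast to the full result shape and multiplied in.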
3525 
3526  /// Entry point for non-channeled convolution:
3527  /// {{w + kw}, {kw}, {w}}
3528  FailureOr<Operation *> generateNonChanneledConv() {
3529  AffineExpr w, kw;
3530  bindDims(ctx, w, kw);
3531  if (!iters({Par(), Red()}))
3532  return rewriter.notifyMatchFailure(op,
3533  "failed to match conv::W 1-par 1-red");
3534 
3535  // No transposition needed.
3536  if (layout({/*lhsIndex*/ {w + kw},
3537  /*rhsIndex*/ {kw},
3538  /*resIndex*/ {w}}))
3539  return conv(Conv1DOpOrder::W);
3540 
3541  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3542  }
3543 
3544  /// Entry point that transposes into the common form:
3545  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3546  FailureOr<Operation *> generateNwcConv() {
3547  AffineExpr n, w, f, kw, c;
3548  bindDims(ctx, n, w, f, kw, c);
3549  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3550  return rewriter.notifyMatchFailure(
3551  op, "failed to match conv::Nwc 3-par 2-red");
3552 
3553  // No transposition needed.
3554  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3555  /*rhsIndex*/ {kw, c, f},
3556  /*resIndex*/ {n, w, f}}))
3557  return conv(Conv1DOpOrder::Nwc);
3558 
3559  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3560  }
3561 
3562  /// Entry point that transposes into the common form:
3563  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3564  FailureOr<Operation *> generateNcwConv() {
3565  AffineExpr n, w, f, kw, c;
3566  bindDims(ctx, n, f, w, c, kw);
3567  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3568  return rewriter.notifyMatchFailure(
3569  op, "failed to match conv::Ncw 3-par 2-red");
3570 
3571  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3572  /*rhsIndex*/ {f, c, kw},
3573  /*resIndex*/ {n, f, w}}))
3574  return conv(Conv1DOpOrder::Ncw);
3575 
3576  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3577  }
3578 
3579  /// Entry point that transposes into the common form:
3580  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3581  FailureOr<Operation *> generateNwcPooling() {
3582  AffineExpr n, w, c, kw;
3583  bindDims(ctx, n, w, c, kw);
3584  if (!iters({Par(), Par(), Par(), Red()}))
3585  return rewriter.notifyMatchFailure(op,
3586  "failed to match pooling 3-par 1-red");
3587 
3588  // No transposition needed.
3589  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3590  /*rhsIndex*/ {kw},
3591  /*resIndex*/ {n, w, c}}))
3592  return conv(Conv1DOpOrder::Nwc);
3593 
3594  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3595  }
3596 
3597  /// Entry point that transposes into the common form:
3598  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3599  FailureOr<Operation *> generateNcwPooling() {
3600  AffineExpr n, w, c, kw;
3601  bindDims(ctx, n, c, w, kw);
3602  if (!iters({Par(), Par(), Par(), Red()}))
3603  return rewriter.notifyMatchFailure(op,
3604  "failed to match pooling 3-par 1-red");
3605 
3606  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3607  /*rhsIndex*/ {kw},
3608  /*resIndex*/ {n, c, w}}))
3609  return conv(Conv1DOpOrder::Ncw);
3610 
3611  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3612  }
3613 
3614  /// Entry point that transposes into the common form:
3615  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3616  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3617  bool vecChDimScalableFlag = false,
3618  bool flatten = false) {
3619  AffineExpr n, w, c, kw;
3620  bindDims(ctx, n, w, c, kw);
3621  if (!iters({Par(), Par(), Par(), Red()}))
3622  return rewriter.notifyMatchFailure(
3623  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3624 
3625  // No transposition needed.
3626  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3627  /*rhsIndex*/ {kw, c},
3628  /*resIndex*/ {n, w, c}}))
3629  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3630 
3631  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3632  }
3633 
3634 private:
3635  enum OperKind { Conv, Pool };
3636  bool valid = false;
3637  OperKind oper = Conv;
3638  StringAttr redOp;
3639  StringAttr poolExtOp;
3640  bool isPoolExt = false;
3641  int strideW, dilationW;
3642  Value lhsShaped, rhsShaped, resShaped;
3643  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3644 
3645  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3646  // Returns true iff it is a valid conv/pooling op.
3647  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3648  // reduction + yield) and rhs is not used, then it is the body of a pooling.
3649  // If conv, check for a single `mul` predecessor. The `mul` operands must be
3650  // block arguments or extensions of block arguments.
3651  // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
3652  // must be block arguments or extensions of block arguments.
3653  bool setOperKind(Operation *reduceOp) {
3654  int numBlockArguments =
3655  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3656  switch (numBlockArguments) {
3657  case 1: {
3658  // Will be convolution if feeder is a MulOp.
3659  // Otherwise, it may be pooling.
3660  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3661  llvm::IsaPred<BlockArgument>);
3662  Operation *feedOp = (*feedValIt).getDefiningOp();
3663  if (isCastOfBlockArgument(feedOp)) {
3664  oper = Pool;
3665  isPoolExt = true;
3666  poolExtOp = feedOp->getName().getIdentifier();
3667  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3668  llvm::all_of(feedOp->getOperands(), [](Value v) {
3669  if (isa<BlockArgument>(v))
3670  return true;
3671  if (Operation *op = v.getDefiningOp())
3672  return isCastOfBlockArgument(op);
3673  return false;
3674  }))) {
3675  return false;
3676  }
3677  return true;
3678  }
3679  case 2:
3680  // Must be pooling
3681  oper = Pool;
3682  isPoolExt = false;
3683  return true;
3684  default:
3685  return false;
3686  }
3687  }
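// Editorial examples of region bodies this classification accepts (MLIR
// syntax approximate):
//
// Pooling: the reduction consumes two block arguments (case 2).
//   ^bb0(%in: f32, %out: f32):
//     %0 = arith.maximumf %in, %out : f32
//     linalg.yield %0 : f32
//
// Convolution: the reduction is fed by a multiply of block arguments or of
// casts of block arguments (case 1).
//   ^bb0(%in: f32, %flt: f32, %out: f32):
//     %0 = arith.mulf %in, %flt : f32
//     %1 = arith.addf %out, %0 : f32
//     linalg.yield %1 : f32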
3688 };
3689 } // namespace
3690 
3691 /// Helper function to vectorize a LinalgOp with convolution semantics.
3692 // TODO: extend the generic vectorization to support windows and drop this.
3693 static FailureOr<Operation *> vectorizeConvolution(
3694  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3695  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3696  // The ConvolutionOpInterface gives us guarantees of existence for
3697  // strides/dilations. However, we do not need to rely on those; we can
3698  // simply use them if present, otherwise use the default and let the generic
3699  // conv matcher in the Conv1DGenerator succeed or fail.
3700  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3701  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3702  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3703  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3704  Conv1DGenerator e(rewriter, op, stride, dilation);
3705  auto res = e.generateNonChanneledConv();
3706  if (succeeded(res))
3707  return res;
3708  res = e.generateNwcConv();
3709  if (succeeded(res))
3710  return res;
3711  res = e.generateNcwConv();
3712  if (succeeded(res))
3713  return res;
3714  res = e.generateNwcPooling();
3715  if (succeeded(res))
3716  return res;
3717  res = e.generateNcwPooling();
3718  if (succeeded(res))
3719  return res;
3720 
3721  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3722  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3723  // masked/scalable) is the channel dim (i.e. the trailing dim).
3724  uint64_t vecChDimSize = ShapedType::kDynamic;
3725  bool vecChDimScalableFlag = false;
3726  if (!inputVecSizes.empty()) {
3727  // Only use the input vector size corresponding to the channel dim. Other
3728  // vector dims will be inferred from the Ops.
3729  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3730  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3731  "Not a 1D depthwise conv!");
3732  size_t chDimIdx =
3733  llvm::TypeSwitch<Operation *, size_t>(op)
3734  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3735  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3736 
3737  vecChDimSize = inputVecSizes[chDimIdx];
3738  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3739  }
3740  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3741  flatten1DDepthwiseConv);
3742 }
3743 
3743 
3744 struct VectorizeConvolution
3745  : public OpInterfaceRewritePattern<LinalgOp> {
3746  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3747  LogicalResult matchAndRewrite(LinalgOp op,
3748  PatternRewriter &rewriter) const override {
3749  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3750  if (failed(resultOrFail))
3751  return failure();
3752  Operation *newOp = *resultOrFail;
3753  if (newOp->getNumResults() == 0) {
3754  rewriter.eraseOp(op.getOperation());
3755  return success();
3756  }
3757  assert(newOp->getNumResults() == 1 && "expected single result");
3758  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3759  return success();
3760  }
3761 };
3762 
3763 void mlir::linalg::populateConvolutionVectorizationPatterns(
3764  RewritePatternSet &patterns, PatternBenefit benefit) {
3765  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3766 }
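// Usage sketch (editorial, assuming this runs inside a pass over a function):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
//
// Depending on the MLIR revision, the greedy driver entry point may be
// spelled applyPatternsGreedily instead.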
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether /p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:595
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:625
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:290
OpListType & getOperations()
Definition: Block.h:135
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:398
IntegerType getI32Type()
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:77
IndexType getIndexType()
Definition: Builders.cpp:75
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:281
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:448
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:559
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:523
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
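For illustration only, a minimal sketch of how these two results steer an Operation::walk traversal; the `root` operation and the choice of vector::TransferWriteOp as the payload are assumptions of this example, not anything prescribed by this file:

  // Stop the traversal at the first vector.transfer_write encountered.
  WalkResult result = root->walk([](vector::TransferWriteOp writeOp) {
    return WalkResult::interrupt();
  });
  bool foundWrite = result.wasInterrupted();
  // Returning WalkResult::advance() instead keeps the walk going.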
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:215
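A hedged sketch of expandAffineExpr: `builder`, `loc`, `iv` and `step` are placeholder values supplied by the caller, and the affine:: qualification assumes the current namespace layout of the affine dialect utilities:

  // Materialize d0 * 2 + s0 as explicit arith ops.
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, s0;
  bindDims(ctx, d0);
  bindSymbols(ctx, s0);
  SmallVector<Value> dimValues = {iv};
  SmallVector<Value> symbolValues = {step};
  Value expanded = affine::expandAffineExpr(builder, loc, d0 * 2 + s0,
                                            dimValues, symbolValues);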
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
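A hedged sketch of how these populate functions are typically driven; the wrapper name applyVectorizationPatterns is invented for the example, and the greedy driver entry point may be spelled applyPatternsGreedily in newer revisions:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static LogicalResult applyVectorizationPatterns(Operation *root) {
    RewritePatternSet patterns(root->getContext());
    linalg::populatePadOpVectorizationPatterns(patterns);
    linalg::populateConvolutionVectorizationPatterns(patterns, /*benefit=*/2);
    // Apply to a fixed point; fails if the driver did not converge.
    return applyPatternsAndFoldGreedily(root, std::move(patterns));
  }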
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
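A hedged fragment combining this predicate with isReductionIterator and allIndexingsAreProjectedPermutation from above; `linalgOp` is assumed to be a LinalgOp handle obtained by the caller, inside a function returning LogicalResult:

  // Bail out on indexing maps that are not projected permutations.
  if (!allIndexingsAreProjectedPermutation(linalgOp))
    return failure();
  if (isElementwise(linalgOp)) {
    // Purely elementwise: no reduction handling needed.
  }
  // Otherwise, count the reduction dimensions.
  int64_t numReductions =
      llvm::count_if(linalgOp.getIteratorTypesArray(),
                     [](utils::IteratorType t) { return isReductionIterator(t); });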
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
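A hedged sketch of the usual call sequence for these two linalg entry points, inside a hypothetical rewrite returning LogicalResult; `op`, `rewriter` and the chosen vector sizes are placeholders:

  // Request 8x16 vectors with no scalable dimensions.
  SmallVector<int64_t> vectorSizes = {8, 16};
  SmallVector<bool> scalableDims = {false, false};
  if (failed(linalg::vectorizeOpPrecondition(op, vectorSizes, scalableDims,
                                             /*vectorizeNDExtract=*/false,
                                             /*flatten1DDepthwiseConv=*/false)))
    return rewriter.notifyMatchFailure(op, "op cannot be vectorized");
  return linalg::vectorize(rewriter, op, vectorSizes, scalableDims);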
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
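A hedged fragment showing how a combiner might be classified; `combinerOps` is assumed to have been filled by a reduction matcher such as matchReduction, listed further below:

  if (combinerOps.size() != 1)
    return failure();
  std::optional<vector::CombiningKind> kind =
      getCombinerOpKind(combinerOps.front());
  if (!kind)
    return failure(); // Not an add/mul/min/max-style combiner.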
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape, ...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2376
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
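A hedged sketch combining the masking helpers above; the vector:: qualification and the `builder`, `loc`, `source`, `padValue` and `inputVectorSizes` names are assumptions of this example, and maskOperation can then wrap further maskable ops with the same mask:

  // Validate the requested read shape against the static source shape.
  ArrayRef<int64_t> srcShape = cast<ShapedType>(source.getType()).getShape();
  if (failed(vector::isValidMaskedInputVector(srcShape, inputVectorSizes)))
    return failure();
  // Emit a transfer_read, masked where the read shape exceeds the source.
  Value read = vector::createReadOrMaskedRead(
      builder, loc, source, /*readShape=*/inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);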
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:657
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:792
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:768
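A hedged, self-contained example of the permutation-map helpers listed above; the expected results in the comments follow from the documented semantics:

  MLIRContext ctx;
  AffineExpr d0, d1, d2;
  bindDims(&ctx, d0, d1, d2);
  // A projected permutation that drops d1: (d0, d1, d2) -> (d2, d0).
  AffineMap proj = AffineMap::get(3, 0, {d2, d0}, &ctx);
  // Reverse it, turning the dropped dimension into 0: (d0, d1) -> (d1, 0, d0).
  AffineMap projInv = inverseAndBroadcastProjectedPermutation(proj);
  // A full permutation can be inverted directly.
  AffineMap perm = AffineMap::get(3, 0, {d2, d0, d1}, &ctx);
  AffineMap inv = inversePermutation(perm); // (d0, d1, d2) -> (d1, d2, d0)
  // Reorder a shape through the permutation map: yields {16, 4, 8}.
  SmallVector<int64_t> permuted =
      applyPermutationMap<int64_t>(perm, ArrayRef<int64_t>{4, 8, 16});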
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:62
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:699
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:631
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
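A hedged fragment illustrating m_Constant together with matchPattern and getConstantIntValue from above; `value` and `ofr` are placeholder handles:

  // Does `value` come from a constant-foldable op?
  if (matchPattern(value, m_Constant())) {
    // ...
  }
  // Bind the constant integer attribute directly, if there is one.
  IntegerAttr attr;
  if (matchPattern(value, m_Constant(&attr))) {
    int64_t c = attr.getInt();
    (void)c;
  }
  // OpFoldResult (Value-or-Attribute) has its own helper.
  std::optional<int64_t> maybeCst = getConstantIntValue(ofr);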
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
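A hedged, minimal example of the integer-permutation helpers just listed; the expected values in the comments follow from result[i] = input[permutation[i]]:

  SmallVector<int64_t> perm = {2, 0, 1};
  SmallVector<int64_t> shape = {4, 8, 16};
  // Out-of-place: {16, 4, 8}.
  SmallVector<int64_t> permuted = applyPermutation(ArrayRef<int64_t>(shape), perm);
  // In-place variant on `shape`.
  applyPermutationToVector(shape, perm);
  // Inverse permutation: {1, 2, 0}.
  SmallVector<int64_t> inverse = invertPermutationVector(perm);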
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
Definition: Transforms.h:1453
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.