Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Support/LLVM.h"
28 #include "mlir/Transforms/RegionUtils.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/Sequence.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include <optional>
36 #include <type_traits>
37 
38 using namespace mlir;
39 using namespace mlir::linalg;
40 
41 #define DEBUG_TYPE "linalg-vectorization"
42 
43 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
44 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
45 
46 /// Try to vectorize `convOp` as a convolution.
47 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
48                                                    LinalgOp convOp);
49 
50 /// Return the unique instance of OpType in `block` if it is indeed unique.
51 /// Return null if none or more than 1 instances exist.
52 template <typename OpType>
53 static OpType getSingleOpOfType(Block &block) {
54  OpType res;
55  block.walk([&](OpType op) {
56  if (res) {
57  res = nullptr;
58  return WalkResult::interrupt();
59  }
60  res = op;
61  return WalkResult::advance();
62  });
63  return res;
64 }
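// Example usage (illustrative sketch; hypothetical caller):
//   // Returns the unique transfer_write in `block`, or null if there are
//   // zero or several of them.
//   auto writeOp = getSingleOpOfType<vector::TransferWriteOp>(block);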
65 
66 /// Helper function to extract the input slices after filter is unrolled along
67 /// kw.
68 static SmallVector<Value>
69 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
70                        int64_t nSize, int64_t wSize, int64_t cSize,
71  int64_t kwSize, int strideW, int dilationW,
72  int64_t wSizeStep, bool isSingleChanneled) {
73  SmallVector<Value> result;
74  if (isSingleChanneled) {
75  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
76  // convolution.
77  SmallVector<int64_t> sizes{wSizeStep};
78  SmallVector<int64_t> strides{1};
79  for (int64_t kw = 0; kw < kwSize; ++kw) {
80  for (int64_t w = 0; w < wSize; w += wSizeStep) {
81  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
82  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
83  }
84  }
85  } else {
86  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
87  // for channeled convolution.
88  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
89  SmallVector<int64_t> strides{1, 1, 1};
90  for (int64_t kw = 0; kw < kwSize; ++kw) {
91  for (int64_t w = 0; w < wSize; w += wSizeStep) {
92  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
93  loc, input,
94  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
95  sizes, strides));
96  }
97  }
98  }
99  return result;
100 }
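// Worked example (illustrative): for a channeled convolution with nSize = 1,
// cSize = 8, kwSize = 3, strideW = 1, dilationW = 2, wSize = 4 and
// wSizeStep = 1, the loop above extracts {1, 1, 8} slices at offsets
// [0, w + 2 * kw, 0], i.e. w * strideW + kw * dilationW along the W dim.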
101 
102 /// Helper function to extract the filter slices after filter is unrolled along
103 /// kw.
104 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
105                                                   Location loc, Value filter,
106  int64_t kwSize) {
107  SmallVector<Value> result;
108  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
109   // non-channeled convolution] @ [kw].
110  for (int64_t kw = 0; kw < kwSize; ++kw) {
111  result.push_back(rewriter.create<vector::ExtractOp>(
112  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
113  }
114  return result;
115 }
116 
117 /// Helper function to extract the result slices after filter is unrolled along
118 /// kw.
119 static SmallVector<Value>
120 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
121                         int64_t nSize, int64_t wSize, int64_t fSize,
122  int64_t wSizeStep, bool isSingleChanneled) {
123  SmallVector<Value> result;
124  if (isSingleChanneled) {
125  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
126  SmallVector<int64_t> sizes{wSizeStep};
127  SmallVector<int64_t> strides{1};
128  for (int64_t w = 0; w < wSize; w += wSizeStep) {
129  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
130  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
131  }
132  } else {
133  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
134  // convolution.
135  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
136  SmallVector<int64_t> strides{1, 1, 1};
137  for (int64_t w = 0; w < wSize; w += wSizeStep) {
138  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
139  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
140  }
141  }
142  return result;
143 }
144 
145 /// Helper function to insert the computed result slices.
146 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
147                                     Value res, int64_t wSize, int64_t wSizeStep,
148  SmallVectorImpl<Value> &resVals,
149  bool isSingleChanneled) {
150 
151  if (isSingleChanneled) {
152  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
153  // This does not depend on kw.
154  SmallVector<int64_t> strides{1};
155  for (int64_t w = 0; w < wSize; w += wSizeStep) {
156  res = rewriter.create<vector::InsertStridedSliceOp>(
157  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
158  }
159  } else {
160  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
161  // convolution. This does not depend on kw.
162  SmallVector<int64_t> strides{1, 1, 1};
163  for (int64_t w = 0; w < wSize; w += wSizeStep) {
164  res = rewriter.create<vector::InsertStridedSliceOp>(
165  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
166  strides);
167  }
168  }
169  return res;
170 }
171 
172 /// Contains the vectorization state and related methods used across the
173 /// vectorization process of a given operation.
174 struct VectorizationState {
175   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
176 
177  /// Initializes the vectorization state, including the computation of the
178  /// canonical vector shape for vectorization.
179  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
180  ArrayRef<int64_t> inputVectorSizes,
181  ArrayRef<bool> inputScalableVecDims);
182 
183  /// Returns the canonical vector shape used to vectorize the iteration space.
184  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
185 
186  /// Returns a vector type of the provided `elementType` with the canonical
187  /// vector shape and the corresponding fixed/scalable dimensions bit. If
188  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
189  /// accordingly.
190   VectorType getCanonicalVecType(
191       Type elementType,
192       std::optional<AffineMap> dimPermutation = std::nullopt) const {
193     SmallVector<int64_t> vectorShape;
194     SmallVector<bool> scalableDims;
195  if (dimPermutation.has_value()) {
196  vectorShape =
197  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
198  scalableDims =
199  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
200  } else {
201  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
202  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
203  }
204 
205  return VectorType::get(vectorShape, elementType, scalableDims);
206  }
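  // Worked example (illustrative): with canonicalVecShape = [8, 16, 4],
  // scalableVecDims = [false, false, true] and a dimPermutation of
  // (d0, d1, d2) -> (d2, d0), this returns vector<[4]x8xelementType>: the
  // canonical dims are permuted to [4, 8] and the scalable bit follows d2.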
207 
208  /// Masks an operation with the canonical vector mask if the operation needs
209  /// masking. Returns the masked operation or the original operation if masking
210  /// is not needed. If provided, the canonical mask for this operation is
211  /// permuted using `maybeMaskingMap`.
212  Operation *
213  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
214  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
215 
216 private:
217  /// Initializes the iteration space static sizes using the Linalg op
218  /// information. This may become more complicated in the future.
219  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
220  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
221  }
222 
223  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
224  /// all the static and dynamic dimensions of the iteration space to be
225   /// vectorized and stores them in `iterSpaceValueSizes`.
226  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
227  LinalgOp linalgOp);
228 
229  /// Create or retrieve an existing mask value to mask `opToMask` in the
230   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
231  /// permuted using that permutation map. If a new mask is created, it will be
232  /// cached for future users.
233  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
234  LinalgOp linalgOp,
235  std::optional<AffineMap> maybeMaskingMap);
236 
237  // Holds the compile-time static sizes of the iteration space to vectorize.
238  // Dynamic dimensions are represented using ShapedType::kDynamic.
239  SmallVector<int64_t> iterSpaceStaticSizes;
240 
241  /// Holds the value sizes of the iteration space to vectorize. Static
242  /// dimensions are represented by 'arith.constant' and dynamic
243  /// dimensions by 'tensor/memref.dim'.
244  SmallVector<Value> iterSpaceValueSizes;
245 
246  /// Holds the canonical vector shape used to vectorize the iteration space.
247  SmallVector<int64_t> canonicalVecShape;
248 
249  /// Holds the vector dimensions that are scalable in the canonical vector
250  /// shape.
251  SmallVector<bool> scalableVecDims;
252 
253  /// Holds the active masks for permutations of the canonical vector iteration
254  /// space.
255  DenseMap<AffineMap, Value> activeMaskCache;
256 
257  /// Global vectorization guard for the incoming rewriter. It's initialized
258  /// when the vectorization state is initialized.
259   OpBuilder::InsertionGuard rewriterGuard;
260 };
261 
262 LogicalResult
263 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
264  LinalgOp linalgOp) {
265  // TODO: Support 0-d vectors.
266  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
267  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
268  // Create constant index op for static dimensions.
269  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
270  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
271  continue;
272  }
273 
274  // Find an operand defined on this dimension of the iteration space to
275  // extract the runtime dimension size.
276  Value operand;
277  unsigned operandDimPos;
278  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
279  operandDimPos)))
280  return failure();
281 
282  Value dynamicDim = linalgOp.hasTensorSemantics()
283  ? (Value)rewriter.create<tensor::DimOp>(
284  linalgOp.getLoc(), operand, operandDimPos)
285  : (Value)rewriter.create<memref::DimOp>(
286  linalgOp.getLoc(), operand, operandDimPos);
287  iterSpaceValueSizes.push_back(dynamicDim);
288  }
289 
290  return success();
291 }
292 
293 /// Initializes the vectorization state, including the computation of the
294 /// canonical vector shape for vectorization.
295 // TODO: Move this to the constructor when we can remove the failure cases.
296 LogicalResult
297 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
298  ArrayRef<int64_t> inputVectorSizes,
299  ArrayRef<bool> inputScalableVecDims) {
300  // Initialize the insertion point.
301  rewriter.setInsertionPoint(linalgOp);
302 
303  if (!inputVectorSizes.empty()) {
304  // Get the canonical vector shape from the input vector sizes provided. This
305  // path should be taken to vectorize code with dynamic shapes and when using
306  // vector sizes greater than the iteration space sizes.
307  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
308  scalableVecDims.append(inputScalableVecDims.begin(),
309  inputScalableVecDims.end());
310  } else {
311  // Compute the canonical vector shape from the operation shape. If there are
312  // dynamic shapes, the operation won't be vectorized. We assume all the
313  // vector dimensions are fixed.
314  canonicalVecShape = linalgOp.getStaticLoopRanges();
315  scalableVecDims.append(linalgOp.getNumLoops(), false);
316  }
317 
318  LDBG("Canonical vector shape: ");
319  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
320  LLVM_DEBUG(llvm::dbgs() << "\n");
321  LDBG("Scalable vector dims: ");
322  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
323  LLVM_DEBUG(llvm::dbgs() << "\n");
324 
325  if (ShapedType::isDynamicShape(canonicalVecShape))
326  return failure();
327 
328  // Initialize iteration space static sizes.
329  initIterSpaceStaticSizes(linalgOp);
330 
331  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
332  // all the static and dynamic dimensions of the iteration space, needed to
333  // compute a mask during vectorization.
334  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
335  return failure();
336 
337  return success();
338 }
339 
340 /// Create or retrieve an existing mask value to mask `opToMask` in the
341 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
342 /// using that permutation map. If a new mask is created, it will be cached for
343 /// future users.
344 Value VectorizationState::getOrCreateMaskFor(
345  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
346  std::optional<AffineMap> maybeMaskingMap) {
347  // No mask is needed if the operation is not maskable.
348  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
349  if (!maskableOp)
350  return Value();
351 
352  assert(!maskableOp.isMasked() &&
353  "Masking an operation that is already masked");
354 
355  // If no masking map was provided, use an identity map with the loop dims.
356  assert((!maybeMaskingMap || *maybeMaskingMap) &&
357  "Unexpected null mask permutation map");
358  AffineMap maskingMap =
359  maybeMaskingMap ? *maybeMaskingMap
360                       : AffineMap::getMultiDimIdentityMap(
361                             linalgOp.getNumLoops(), rewriter.getContext());
362 
363  LDBG("Masking map: " << maskingMap << "\n");
364 
365  // Return the active mask for the masking map of this operation if it was
366  // already created.
367  auto activeMaskIt = activeMaskCache.find(maskingMap);
368  if (activeMaskIt != activeMaskCache.end()) {
369  Value mask = activeMaskIt->second;
370  LDBG("Reusing mask: " << mask << "\n");
371  return mask;
372  }
373 
374  // Compute permuted projection of the iteration space to be masked and the
375  // corresponding mask shape. If the resulting iteration space dimensions are
376  // static and identical to the mask shape, masking is not needed for this
377  // operation.
378  // TODO: Improve this check. Only projected permutation indexing maps are
379  // supported.
380  SmallVector<int64_t> permutedStaticSizes =
381  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
382  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
383  auto maskShape = maskType.getShape();
384 
385  LDBG("Mask shape: ");
386  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
387  LLVM_DEBUG(llvm::dbgs() << "\n");
388 
389  if (permutedStaticSizes == maskShape) {
390  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
391  activeMaskCache[maskingMap] = Value();
392  return Value();
393  }
394 
395  // Permute the iteration space value sizes to compute the mask upper bounds.
396  SmallVector<Value> upperBounds =
397  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
398  assert(!maskShape.empty() && !upperBounds.empty() &&
399  "Masked 0-d vectors are not supported yet");
400 
401  // Create the mask based on the dimension values.
402  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
403  maskType, upperBounds);
404  LDBG("Creating new mask: " << mask << "\n");
405  activeMaskCache[maskingMap] = mask;
406  return mask;
407 }
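// Illustrative result (sketch, assuming a 2-D canonical shape [8, 16] with
// dynamic iteration-space sizes %d0 and %d1 and an identity masking map):
//
//   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>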
408 
409 /// Masks an operation with the canonical vector mask if the operation needs
410 /// masking. Returns the masked operation or the original operation if masking
411 /// is not needed. If provided, the canonical mask for this operation is
412 /// permuted using `maybeMaskingMap`.
413 Operation *
414 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
415                                   LinalgOp linalgOp,
416  std::optional<AffineMap> maybeMaskingMap) {
417  LDBG("Trying to mask: " << *opToMask << "\n");
418 
419  // Create or retrieve mask for this operation.
420  Value mask =
421  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
422 
423  if (!mask) {
424  LDBG("No mask required\n");
425  return opToMask;
426  }
427 
428  // Wrap the operation with a new `vector.mask` and update D-U chain.
429  assert(opToMask && "Expected a valid operation to mask");
430  auto maskOp = cast<vector::MaskOp>(
431  mlir::vector::maskOperation(rewriter, opToMask, mask));
432  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
433 
434  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
435  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
436  maskOpTerminator);
437 
438  LDBG("Masked operation: " << *maskOp << "\n");
439  return maskOp;
440 }
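// Illustrative result (sketch): wrapping a maskable op with the mask created
// above produces IR roughly of the form:
//
//   %res = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad : tensor<?x?xf32>, vector<8x16xf32>
//   } : vector<8x16xi1> -> vector<8x16xf32>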
441 
442 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
443 /// projectedPermutation, compress the unused dimensions to serve as a
444 /// permutation_map for a vector transfer operation.
445 /// For example, given a linalg op such as:
446 ///
447 /// ```
448 /// %0 = linalg.generic {
449 ///     indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
450 ///                      affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
451 /// }
452 /// ins(%0 : tensor<2x3x4xf32>)
453 /// outs(%1 : tensor<5x6xf32>)
454 /// ```
455 ///
456 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
457 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
458 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
459 static AffineMap reindexIndexingMap(AffineMap map) {
460   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
461  "expected projected permutation");
462  auto res = compressUnusedDims(map);
463  assert(res.getNumDims() == res.getNumResults() &&
464  "expected reindexed map with same number of dims and results");
465  return res;
466 }
467 
468 /// Helper enum to represent conv1d input traversal order.
469 enum class Conv1DOpOrder {
470  W, // Corresponds to non-channeled 1D convolution operation.
471  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
472  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
473 };
474 
475 /// Helper data structure to represent the result of vectorization.
476 /// In certain specific cases, like terminators, we do not want to propagate/replace results automatically.
477 enum VectorizationStatus {
478   /// Op failed to vectorize.
479   Failure = 0,
480   /// Op vectorized and custom function took care of replacement logic
481   NoReplace,
482   /// Op vectorized into a new Op whose results will replace original Op's
483   /// results.
484   NewOp
485   // TODO: support values if Op vectorized to Many-Ops whose results we need to
486   // aggregate for replacement.
487 };
488 struct VectorizationResult {
489   /// Return status from vectorizing the current op.
490   enum VectorizationStatus status = VectorizationStatus::Failure;
491   /// New vectorized operation to replace the current op.
492   /// Replacement behavior is specified by `status`.
493   Operation *newOp = nullptr;
494 };
495 
496 std::optional<vector::CombiningKind>
497 getCombinerOpKind(Operation *combinerOp) {
498   using ::mlir::vector::CombiningKind;
499 
500  if (!combinerOp)
501  return std::nullopt;
502   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
503       .Case<arith::AddIOp, arith::AddFOp>(
504  [&](auto op) { return CombiningKind::ADD; })
505  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
506  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
507  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
508  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
509  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
510  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
511  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
512  .Case<arith::MulIOp, arith::MulFOp>(
513  [&](auto op) { return CombiningKind::MUL; })
514  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
515  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
516  .Default([&](auto op) { return std::nullopt; });
517 }
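// Example (illustrative): a reduction region that combines with
// `arith.addf %in, %acc : f32` maps to CombiningKind::ADD, while an
// unsupported combiner such as `arith.divf` falls through to std::nullopt.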
518 
519 /// Check whether `outputOperand` is a reduction with a single combiner
520 /// operation. Return the combiner operation of the reduction. Return
521 /// nullptr otherwise. Multiple reduction operations would impose an
522 /// ordering between reduction dimensions and is currently unsupported in
523 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
524 /// max(min(X))
525 // TODO: use in LinalgOp verification, there is a circular dependency atm.
526 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
527  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
528  unsigned outputPos =
529  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
530  // Only single combiner operations are supported for now.
531  SmallVector<Operation *, 4> combinerOps;
532  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
533  combinerOps.size() != 1)
534  return nullptr;
535 
536  // Return the combiner operation.
537  return combinerOps[0];
538 }
539 
540 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
541 /// otherwise.
542 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
543  auto dstVecType = dyn_cast<VectorType>(dstType);
544  // If no shape to broadcast to, just return `value`.
545  if (dstVecType.getRank() == 0)
546  return value;
547  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
548           vector::BroadcastableToResult::Success)
549     return value;
550  Location loc = b.getInsertionPoint()->getLoc();
551  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
552 }
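// Example (illustrative): broadcasting a scalar `f32` to `vector<4x8xf32>`
// emits `vector.broadcast %v : f32 to vector<4x8xf32>`; a value already of
// the destination vector type is simply folded back to itself.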
553 
554 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
555 /// assumes that `reductionOp` has two operands and one of them is the reduction
556 /// initial value.
557 // Note: this is a true builder that notifies the OpBuilder listener.
558 // TODO: Consider moving as a static helper on the ReduceOp.
559 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
560                                       Value valueToReduce, Value acc,
561  ArrayRef<bool> dimsToMask) {
562  auto maybeKind = getCombinerOpKind(reduceOp);
563  assert(maybeKind && "Failed precondition: could not get reduction kind");
564  return b.create<vector::MultiDimReductionOp>(
565  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
566 }
567 
568 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
569  return llvm::to_vector(
570  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
571 }
572 
573 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
574 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
575 /// currently being vectorized. If `dest` has null rank, build a memref.store.
576 /// Return the produced value or null if no value is produced.
577 // Note: this is a true builder that notifies the OpBuilder listener.
578 // TODO: Consider moving as a static helper on the ReduceOp.
579 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
580  OpOperand *outputOperand,
581  VectorizationState &state) {
582  Location loc = value.getLoc();
583  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
584  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
585 
586  // Compute the vector type of the value to store. This type should be an
587  // identity or projection of the canonical vector type without any permutation
588  // applied, given that any permutation in a transfer write happens as part of
589  // the write itself.
590   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
591       opOperandMap.getContext(), opOperandMap.getNumInputs(),
592  [&](AffineDimExpr dimExpr) -> bool {
593  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
594  });
595  auto vectorType = state.getCanonicalVecType(
596  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
597 
598  Operation *write;
599  if (vectorType.getRank() > 0) {
600  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
601  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
602  rewriter.create<arith::ConstantIndexOp>(loc, 0));
603  value = broadcastIfNeeded(rewriter, value, vectorType);
604  assert(value.getType() == vectorType && "Incorrect type");
605  write = rewriter.create<vector::TransferWriteOp>(
606  loc, value, outputOperand->get(), indices, writeMap);
607  } else {
608  // 0-d case is still special: do not invert the reindexing writeMap.
609  if (!isa<VectorType>(value.getType()))
610  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
611  assert(value.getType() == vectorType && "Incorrect type");
612  write = rewriter.create<vector::TransferWriteOp>(
613  loc, value, outputOperand->get(), ValueRange{});
614  }
615 
616  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
617 
618  // If masked, set in-bounds to true. Masking guarantees that the access will
619  // be in-bounds.
620  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
621  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
622  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
623  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
624  }
625 
626  LDBG("vectorized op: " << *write << "\n");
627  if (!write->getResults().empty())
628  return write->getResult(0);
629  return Value();
630 }
631 
632 // Custom vectorization precondition function type. This is intended to be used
633 // with CustomVectorizationHook. Returns success if the corresponding custom
634 // hook can vectorize the op.
635 using CustomVectorizationPrecondition =
636     std::function<LogicalResult(Operation *, bool)>;
637 
638 // Custom vectorization function type. Produce a vector form of Operation*
639 // assuming all its vectorized operands are already in the IRMapping.
640 // Return nullptr if the Operation cannot be vectorized.
641 using CustomVectorizationHook =
642     std::function<VectorizationResult(Operation *, const IRMapping &)>;
643 
644 /// Helper function to vectorize the terminator of a `linalgOp`. New result
645 /// vector values are appended to `newResults`. Return
646 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
647 /// should not try to map produced operations and instead return the results
648 /// using the `newResults` vector, making them available to the vectorization
649 /// algorithm for RAUW. This function is meant to be used as a
650 /// CustomVectorizationHook.
651 static VectorizationResult
652 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
653                      const IRMapping &bvm, VectorizationState &state,
654  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
655  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
656  if (!yieldOp)
657     return VectorizationResult{VectorizationStatus::Failure, nullptr};
658   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
659  // TODO: Scan for an opportunity for reuse.
660  // TODO: use a map.
661  Value vectorValue = bvm.lookup(output.value());
662  Value newResult =
663  buildVectorWrite(rewriter, vectorValue,
664  linalgOp.getDpsInitOperand(output.index()), state);
665  if (newResult)
666  newResults.push_back(newResult);
667  }
668 
669   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
670 }
671 
672 /// Helper function to vectorize the index operations of a `linalgOp`. Return
673 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
674 /// should map the produced operations. This function is meant to be used as a
675 /// CustomVectorizationHook.
676 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
677                                                 VectorizationState &state,
678  Operation *op,
679  LinalgOp linalgOp) {
680  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
681  if (!indexOp)
682     return VectorizationResult{VectorizationStatus::Failure, nullptr};
683   auto loc = indexOp.getLoc();
684  // Compute the static loop sizes of the index op.
685  auto targetShape = state.getCanonicalVecShape();
686  // Compute a one-dimensional index vector for the index op dimension.
687  auto constantSeq =
688  llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
689  auto indexSteps = rewriter.create<arith::ConstantOp>(
690  loc, rewriter.getIndexVectorAttr(constantSeq));
691  // Return the one-dimensional index vector if it lives in the trailing
692  // dimension of the iteration space since the vectorization algorithm in this
693  // case can handle the broadcast.
694  if (indexOp.getDim() == targetShape.size() - 1)
695     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
696   // Otherwise permute the targetShape to move the index dimension last,
697  // broadcast the one-dimensional index vector to the permuted shape, and
698  // finally transpose the broadcasted index vector to undo the permutation.
699  auto permPattern =
700  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
701  std::swap(permPattern[indexOp.getDim()], permPattern.back());
702  auto permMap =
703  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
704 
705  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
706  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
707  indexSteps);
708  SmallVector<int64_t> transposition =
709  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
710  std::swap(transposition.back(), transposition[indexOp.getDim()]);
711  auto transposeOp =
712  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
713  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
714 }
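// Worked example (illustrative): with a canonical vector shape of [4, 8] and
// `linalg.index 0` (a non-trailing dim), `indexSteps` is
// `arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>`; it is broadcast to
// vector<8x4xindex> (index dim moved last) and transposed back to
// vector<4x8xindex>.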
715 
716 /// Helper function to check if the tensor.extract can be vectorized by the
717 /// custom hook vectorizeTensorExtract.
718 static LogicalResult
719 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
720  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
721  if (!extractOp)
722  return failure();
723 
724  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
725  return failure();
726 
727  // Check the index type, but only for non 0-d tensors (for which we do need
728  // access indices).
729  if (not extractOp.getIndices().empty()) {
730  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
731  return failure();
732  }
733 
734  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
735  return !VectorType::isValidElementType(type);
736  })) {
737  return failure();
738  }
739 
740  return success();
741 }
742 
743 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
744 /// generated from `tensor.extract`. The offset is calculated as follows
745 /// (example using scalar values):
746 ///
747 /// offset = extractOp.indices[0]
748 /// for (i = 1; i < numIndices; i++)
749 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
750 ///
751 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
752 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
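///        i.e. offset = 1233 in this example.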
753 static Value calculateGatherOffset(RewriterBase &rewriter,
754                                    VectorizationState &state,
755  tensor::ExtractOp extractOp,
756  const IRMapping &bvm) {
757  // The vector of indices for GatherOp should be shaped as the output vector.
758  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
759  auto loc = extractOp.getLoc();
760 
761  Value offset = broadcastIfNeeded(
762  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
763 
764  const size_t numIndices = extractOp.getIndices().size();
765  for (size_t i = 1; i < numIndices; i++) {
766  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
767 
768  auto dimSize = broadcastIfNeeded(
769  rewriter,
770  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
771  indexVecType);
772 
773  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
774 
775  auto extractOpIndex = broadcastIfNeeded(
776  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
777 
778  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
779  }
780 
781  return offset;
782 }
783 
784 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
785 
786 /// Checks whether \p val can be used for calculating a loop invariant index.
787 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
788 
789  auto targetShape = linalgOp.getStaticLoopRanges();
790  assert(((llvm::count_if(targetShape,
791  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
792  "n-D vectors are not yet supported");
793  assert(targetShape.back() != 1 &&
794          "1-D vectors with the trailing dim equal 1 are not yet supported");
795 
796  // Blocks outside _this_ linalg.generic are effectively loop invariant.
797  // However, analysing block arguments for _this_ linalg.generic Op is a bit
798  // tricky. Just bail out in the latter case.
799  // TODO: We could try analysing the corresponding affine map here.
800  auto *block = linalgOp.getBlock();
801  if (isa<BlockArgument>(val))
802  return llvm::all_of(block->getArguments(),
803  [&val](Value v) { return (v != val); });
804 
805  Operation *defOp = val.getDefiningOp();
806  assert(defOp && "This is neither a block argument nor an operation result");
807 
808  // IndexOp is loop invariant as long as its result remains constant across
809  // iterations. Given the assumptions on the loop ranges above, only the
810  // trailing loop dim ever changes.
811  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
812  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
813  return (indexOp.getDim() != trailingLoopDim);
814 
815  auto *ancestor = block->findAncestorOpInBlock(*defOp);
816 
817   // Values defined outside `linalgOp` are loop invariant.
818  if (!ancestor)
819  return true;
820 
821  // Values defined inside `linalgOp`, which are constant, are loop invariant.
822  if (isa<arith::ConstantOp>(ancestor))
823  return true;
824 
825  bool result = true;
826  for (auto op : ancestor->getOperands())
827  result &= isLoopInvariantIdx(linalgOp, op);
828 
829  return result;
830 }
831 
832 /// Check whether \p val could be used for calculating the trailing index for a
833 /// contiguous load operation.
834 ///
835 /// There are currently 3 types of values that are allowed here:
836 /// 1. loop-invariant values,
837 /// 2. values that increment by 1 with every loop iteration,
838 /// 3. results of basic arithmetic operations (linear and continuous)
839 /// involving 1., 2. and 3.
840 /// This method returns True if indeed only such values are used in calculating
841 /// \p val.
842 ///
843 /// Additionally, the trailing index for a contiguous load operation should
844 /// increment by 1 with every loop iteration, i.e. be based on:
845 /// * `linalg.index <dim>` ,
846 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
847 /// updated to `true` when such an op is found.
848 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
849  bool &foundIndexOp) {
850 
851  auto targetShape = linalgOp.getStaticLoopRanges();
852  assert(((llvm::count_if(targetShape,
853  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
854  "n-D vectors are not yet supported");
855  assert(targetShape.back() != 1 &&
856  "1-D vectors with the trailing dim 1 are not yet supported");
857 
858  // Blocks outside _this_ linalg.generic are effectively loop invariant.
859  // However, analysing block arguments for _this_ linalg.generic Op is a bit
860  // tricky. Just bail out in the latter case.
861  // TODO: We could try analysing the corresponding affine map here.
862  auto *block = linalgOp.getBlock();
863  if (isa<BlockArgument>(val))
864  return llvm::all_of(block->getArguments(),
865  [&val](Value v) { return (v != val); });
866 
867  Operation *defOp = val.getDefiningOp();
868  assert(defOp && "This is neither a block argument nor an operation result");
869 
870  // Given the assumption on the loop ranges above, only the trailing loop
871  // index is not constant.
872  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
873  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
874  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
875  return true;
876  }
877 
878  auto *ancestor = block->findAncestorOpInBlock(*defOp);
879 
880  if (!ancestor)
881  return false;
882 
883  // Conservatively reject Ops that could lead to indices with stride other
884  // than 1.
885  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
886  ancestor))
887  return false;
888 
889  bool result = false;
890  for (auto op : ancestor->getOperands())
891  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
892 
893  return result;
894 }
895 
896 /// Check whether \p extractOp would be a gather or a contiguous load Op after
897 /// vectorising \p linalgOp. Note that it is always safe to use gather load
898 /// operations for contiguous loads (albeit slow), but not vice-versa. When in
899 /// doubt, bail out and assume that \p extractOp is a gather load.
900 static VectorMemoryAccessKind
901 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
902  LinalgOp &linalgOp) {
903 
904  auto targetShape = linalgOp.getStaticLoopRanges();
905  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
906 
907  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
908  if (inputShape.getShape().empty())
909     return VectorMemoryAccessKind::ScalarBroadcast;
910 
911  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
912  // gather load.
913  // TODO: Relax this condition.
914  if (linalgOp.hasDynamicShape())
915     return VectorMemoryAccessKind::Gather;
916 
917  // 1. Assume that it's a gather load when reading _into_:
918   //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
919   //   * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
920  // TODO: Relax these conditions.
921  // FIXME: This condition assumes non-dynamic sizes.
922  if ((llvm::count_if(targetShape,
923  [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
924  targetShape.back() == 1)
925     return VectorMemoryAccessKind::Gather;
926 
927  // 2. Assume that it's a gather load when reading _from_ a tensor for which
928  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
929  // TODO: Relax this condition.
930  if (inputShape.getShape().back() == 1)
931     return VectorMemoryAccessKind::Gather;
932 
933  bool leadingIdxsLoopInvariant = true;
934 
935  // 3. Analyze the leading indices of `extractOp`.
936  // Look at the way each index is calculated and decide whether it is suitable
937  // for a contiguous load, i.e. whether it's loop invariant.
938  auto indices = extractOp.getIndices();
939  auto leadIndices = indices.drop_back(1);
940 
941  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
942  if (inputShape.getShape()[i] == 1)
943  continue;
944 
945  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
946  }
947 
948  if (!leadingIdxsLoopInvariant) {
949  LDBG("Found gather load: " << extractOp);
950     return VectorMemoryAccessKind::Gather;
951   }
952 
953  // 4. Analyze the trailing index for `extractOp`.
954  // At this point we know that the leading indices are loop invariant. This
955   // means that it is potentially a scalar or a contiguous load. We can decide
956  // based on the trailing idx.
957  auto extractOpTrailingIdx = indices.back();
958 
959  // 4a. Scalar broadcast load
960  // If the trailing index is loop invariant then this is a scalar load.
961  if (leadingIdxsLoopInvariant &&
962  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
963  LDBG("Found scalar broadcast load: " << extractOp);
964 
965     return VectorMemoryAccessKind::ScalarBroadcast;
966   }
967 
968  // 4b. Contiguous loads
969  // The trailing `extractOp` index should increment with every loop iteration.
970  // This effectively means that it must be based on the trailing loop index.
971  // This is what the following bool captures.
972  bool foundIndexOp = false;
973  bool isContiguousLoad =
974  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
975  isContiguousLoad &= foundIndexOp;
976 
977  if (isContiguousLoad) {
978     LDBG("Found contiguous load: " << extractOp);
979     return VectorMemoryAccessKind::Contiguous;
980   }
981 
982  // 5. Fallback case - gather load.
983  LDBG("Found gather load: " << extractOp);
984   return VectorMemoryAccessKind::Gather;
985 }
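// Example (illustrative): inside a generic with static loop ranges 1x1x4,
// `tensor.extract %src[%c0, %i] : tensor<3x4xf32>` is classified as
// Contiguous when %i is produced by `linalg.index 2` (the trailing dim),
// whereas a trailing index computed through, e.g., an arith.muli is
// conservatively classified as Gather.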
986 
987 /// Helper function to vectorize the tensor.extract operations. Returns
988 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
989 /// should map the produced operations. This function is meant to be used as a
990 /// CustomVectorizationHook.
991 static VectorizationResult
992 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
993                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
994  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
995  if (!extractOp)
996     return VectorizationResult{VectorizationStatus::Failure, nullptr};
997   auto loc = extractOp.getLoc();
998 
999  // Compute the static loop sizes of the extract op.
1000  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1001  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1002  loc,
1003  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1004  /*value=*/true));
1005  auto passThruConstantOp =
1006  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1007 
1008  // Base indices are currently set to 0. We will need to re-visit if more
1009  // generic scenarios are to be supported.
1010  SmallVector<Value> baseIndices(
1011  extractOp.getIndices().size(),
1012  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1013 
1014  VectorMemoryAccessKind memAccessKind =
1015  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1016 
1017  // 1. Handle gather access
1018  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1019  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1020 
1021  // Generate the gather load
1022  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1023  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1024  maskConstantOp, passThruConstantOp);
1025  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1026 
1027  LDBG("Vectorised as gather load: " << extractOp << "\n");
1028     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1029   }
1030 
1031  // 2. Handle:
1032  // a. scalar loads + broadcast,
1033  // b. contiguous loads.
1034  // Both cases use vector.transfer_read.
1035 
1036  // Collect indices for `vector.transfer_read`. At this point, the indices will
1037  // either be scalars or would have been broadcast to vectors matching the
1038  // result type. For indices that are vectors, there are two options:
1039  // * for non-trailing indices, all elements are identical (contiguous
1040  // loads are identified by looking for non-trailing indices that are
1041  // invariant with respect to the corresponding linalg.generic), or
1042  // * for trailing indices, the index vector will contain values with stride
1043  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1044  // needed.
1045  // This means that
1046  // * for scalar indices - just re-use it,
1047  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1048  // (0th) element and use that.
1049  SmallVector<Value> transferReadIdxs;
1050  auto resTrailingDim = resultType.getShape().back();
1051  auto zero = rewriter.create<arith::ConstantOp>(
1052  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1053  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1054  auto idx = bvm.lookup(extractOp.getIndices()[i]);
1055  if (idx.getType().isIndex()) {
1056  transferReadIdxs.push_back(idx);
1057  continue;
1058  }
1059 
1060  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1061  loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
1062  bvm.lookup(extractOp.getIndices()[i]));
1063  transferReadIdxs.push_back(
1064  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1065  }
1066 
1067   // `tensor.extract` is always in-bounds, hence the following holds.
1068  auto dstRank = resultType.getRank();
1069  auto srcRank = extractOp.getTensor().getType().getRank();
1070  SmallVector<bool> inBounds(dstRank, true);
1071 
1072  // 2a. Handle scalar broadcast access.
1073  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1074  MLIRContext *ctx = rewriter.getContext();
1075  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1076  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1077 
1078  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1079  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1080  permutationMap, inBounds);
1081 
1082  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1083  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1084  }
1085 
1086  // 2b. Handle contiguous access.
1087  auto permutationMap = AffineMap::getMinorIdentityMap(
1088  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1089 
1090  int32_t rankDiff = dstRank - srcRank;
1091  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1092  // dims so that the ranks match. This is done by extending the map with 0s.
1093  // For example, for dstRank = 3, srcRank = 2, the following map created
1094  // above:
1095  // (d0, d1) --> (d0, d1)
1096  // is extended as:
1097  // (d0, d1) --> (0, d0, d1)
1098  while (rankDiff > 0) {
1099  permutationMap = permutationMap.insertResult(
1100  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1101  rankDiff--;
1102  }
1103 
1104  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1105  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1106  inBounds);
1107 
1108  LDBG("Vectorised as contiguous load: " << extractOp);
1109  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1110 }
1111 
1112 /// Emit reduction operations if the shape of the value to reduce is different
1113 /// from the result shape.
1114 // Note: this is a true builder that notifies the OpBuilder listener.
1115 // TODO: Consider moving as a static helper on the ReduceOp.
1116 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1117  Value reduceValue, Value initialValue,
1118  const IRMapping &bvm) {
1119  Value reduceVec = bvm.lookup(reduceValue);
1120  Value outputVec = bvm.lookup(initialValue);
1121  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1122  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1123   // Reduce only if needed as the value may already have been reduced for
1124  // contraction vectorization.
1125  if (!reduceType ||
1126  (outputType && reduceType.getShape() == outputType.getShape()))
1127  return nullptr;
1128  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1129  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1130 }
1131 
1132 /// Generic vectorization for a single operation `op`, given already vectorized
1133 /// operands carried by `bvm`. Vectorization occurs as follows:
1134 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1135 /// result on success.
1136 /// 2. Clone any constant in the current scope without vectorization: each
1137 /// consumer of the constant will later determine the shape to which the
1138 /// constant needs to be broadcast to.
1139 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1140 /// of the `customVectorizationHooks` to cover such cases.
1141 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1142 /// operand of maximal rank. Other operands have smaller rank and are
1143 /// broadcast accordingly. It is assumed this broadcast is always legal,
1144 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1145 ///
1146 /// This function assumes all operands of `op` have been vectorized and are in
1147 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1148 /// a topologically-sorted list of ops.
1149 /// This function does not update `bvm` but returns a VectorizationStatus that
1150 /// instructs the caller what `bvm` update needs to occur.
1151 static VectorizationResult
1152 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1153                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1154  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1155  LDBG("vectorize op " << *op << "\n");
1156 
1157  // 1. Try to apply any CustomVectorizationHook.
1158  if (!customVectorizationHooks.empty()) {
1159  for (auto &customFunc : customVectorizationHooks) {
1160  VectorizationResult result = customFunc(op, bvm);
1161  if (result.status == VectorizationStatus::Failure)
1162  continue;
1163  return result;
1164  }
1165  }
1166 
1167  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1168   // Clone so that the constant is not confined to the linalgOp block.
1169  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1170  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1171 
1172  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1173   if (!OpTrait::hasElementwiseMappableTraits(op))
1174     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1175 
1176   // 4. Check if the operation is a reduction.
1177  SmallVector<std::pair<Value, Value>> reductionOperands;
1178  for (Value operand : op->getOperands()) {
1179  auto blockArg = dyn_cast<BlockArgument>(operand);
1180  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1181  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1182  continue;
1183  SmallVector<Operation *> reductionOps;
1184  Value reduceValue = matchReduction(
1185  linalgOp.getRegionOutputArgs(),
1186  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1187  if (!reduceValue)
1188  continue;
1189  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1190  }
1191  if (!reductionOperands.empty()) {
1192  assert(reductionOperands.size() == 1);
1193  Operation *reduceOp =
1194  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1195  reductionOperands[0].second, bvm);
1196  if (reduceOp)
1197       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1198   }
1199 
1200  // 5. Generic vectorization path for ElementwiseMappable ops.
1201  // a. Get the first max ranked shape.
1202  VectorType firstMaxRankedType;
1203  for (Value operand : op->getOperands()) {
1204  auto vecOperand = bvm.lookup(operand);
1205  assert(vecOperand && "Vector operand couldn't be found");
1206 
1207  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1208  if (vecType && (!firstMaxRankedType ||
1209  firstMaxRankedType.getRank() < vecType.getRank()))
1210  firstMaxRankedType = vecType;
1211  }
1212  // b. Broadcast each op if needed.
1213  SmallVector<Value> vecOperands;
1214  for (Value scalarOperand : op->getOperands()) {
1215  Value vecOperand = bvm.lookup(scalarOperand);
1216  assert(vecOperand && "Vector operand couldn't be found");
1217 
1218  if (firstMaxRankedType) {
1219  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1220  getElementTypeOrSelf(vecOperand.getType()),
1221  firstMaxRankedType.getScalableDims());
1222  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1223  } else {
1224  vecOperands.push_back(vecOperand);
1225  }
1226  }
1227  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1228  SmallVector<Type> resultTypes;
1229  for (Type resultType : op->getResultTypes()) {
1230  resultTypes.push_back(
1231  firstMaxRankedType
1232  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1233  firstMaxRankedType.getScalableDims())
1234  : resultType);
1235  }
1236  // d. Build and return the new op.
1237  return VectorizationResult{
1238       VectorizationStatus::NewOp,
1239       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1240  resultTypes, op->getAttrs())};
1241 }
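// Example (illustrative): with a canonical shape of [4, 8], a scalar
// `arith.addf %a, %b : f32` in the linalg body is rebuilt as
// `arith.addf %va, %vb : vector<4x8xf32>` once its operands have been
// vectorized (and broadcast, if needed) to the maximal-rank operand type.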
1242 
1243 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1244 /// vector form. Generic vectorization proceeds as follows:
1245 /// 1. Verify the `linalgOp` has one non-empty region.
1246 /// 2. Values defined above the region are mapped to themselves and will be
1247 /// broadcasted on a per-need basis by their consumers.
1248 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1249 /// load).
1250 /// TODO: Reuse opportunities for RAR dependencies.
1251 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1252 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1253 /// iteration indices.
1254 /// 5. Iteratively call vectorizeOneOp on the region operations.
1255 ///
1256 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1257 /// performed to the maximal common vector size implied by the `linalgOp`
1258 /// iteration space. This eager broadcasting is introduced in the
1259 /// permutation_map of the vector.transfer_read operations. The eager
1260 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1261 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1262 /// the absence of good canonicalizations, the amount of work increases.
1263 /// This is not deemed a problem as we expect canonicalizations and foldings to
1264 /// aggressively clean up the useless work.
1265 static LogicalResult
1266 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1267                          LinalgOp linalgOp,
1268  SmallVectorImpl<Value> &newResults) {
1269  LDBG("Vectorizing operation as linalg generic\n");
1270  Block *block = linalgOp.getBlock();
1271 
1272  // 2. Values defined above the region can only be broadcast for now. Make them
1273  // map to themselves.
1274  IRMapping bvm;
1275  SetVector<Value> valuesSet;
1276  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1277  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1278 
1279  if (linalgOp.getNumDpsInits() == 0)
1280  return failure();
1281 
1282  // 3. Turn all BBArgs into vector.transfer_read / load.
1283  Location loc = linalgOp.getLoc();
1284  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1285  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1286  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1287  if (linalgOp.isScalar(opOperand)) {
1288  bvm.map(bbarg, opOperand->get());
1289  continue;
1290  }
1291 
1292  // 3.a. Convert the indexing map for this input/output to a transfer read
1293  // permutation map and masking map.
1294  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1295 
1296  // Remove zeros from indexing map to use it as masking map.
1297  SmallVector<int64_t> zeroPos;
1298  auto results = indexingMap.getResults();
1299  for (const auto &result : llvm::enumerate(results)) {
1300  if (result.value().isa<AffineConstantExpr>()) {
1301  zeroPos.push_back(result.index());
1302  }
1303  }
1304  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1305 
1306  AffineMap readMap;
1307  VectorType readType;
1308  Type elemType = getElementTypeOrSelf(opOperand->get());
1309  if (linalgOp.isDpsInput(opOperand)) {
1310  // 3.a.i. For input reads we use the canonical vector shape.
1311  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1312  readType = state.getCanonicalVecType(elemType);
1313  } else {
1314  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1315  // reductions), the vector shape is computed by mapping the canonical
1316  // vector shape to the output domain and back to the canonical domain.
1317  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1318  readType =
1319  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1320  }
1321 
1322  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1323 
1324  Operation *read = rewriter.create<vector::TransferReadOp>(
1325  loc, readType, opOperand->get(), indices, readMap);
1326  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1327  Value readValue = read->getResult(0);
1328 
1329  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1330  // will be in-bounds.
1331  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1332  SmallVector<bool> inBounds(readType.getRank(), true);
1333  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1334  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1335  }
1336 
1337  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1338  // TODO: remove this.
1339  if (readType.getRank() == 0)
1340  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1341 
1342  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1343  << "\n");
1344  bvm.map(bbarg, readValue);
1345  bvm.map(opOperand->get(), readValue);
1346  }
1347 
1348   SmallVector<CustomVectorizationHook> hooks;
1349  // 4a. Register CustomVectorizationHook for yieldOp.
1350  CustomVectorizationHook vectorizeYield =
1351  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1352  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1353  };
1354  hooks.push_back(vectorizeYield);
1355 
1356  // 4b. Register CustomVectorizationHook for indexOp.
1357  CustomVectorizationHook vectorizeIndex =
1358  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1359  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1360  };
1361  hooks.push_back(vectorizeIndex);
1362 
1363  // 4c. Register CustomVectorizationHook for extractOp.
1364  CustomVectorizationHook vectorizeExtract =
1365  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1366  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1367  };
1368  hooks.push_back(vectorizeExtract);
1369 
1370  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1371  for (Operation &op : block->getOperations()) {
1372  VectorizationResult result =
1373  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1374  if (result.status == VectorizationStatus::Failure) {
1375  LDBG("failed to vectorize: " << op << "\n");
1376  return failure();
1377  }
1378  if (result.status == VectorizationStatus::NewOp) {
1379  Operation *maybeMaskedOp =
1380  state.maskOperation(rewriter, result.newOp, linalgOp);
1381  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1382  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1383  }
1384  }
1385 
1386  return success();
1387 }
1388 
1389 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1390 /// and (3) all-zero lowPad to
1391 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
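/// For illustration only, with `inputVectorSizes = [2, 4]`, a pad of a
/// dynamically shaped source into `tensor<2x4xf32>` is rewritten roughly as
/// follows (SSA names are hypothetical):
/// ```
///   %init = tensor.empty() : tensor<2x4xf32>
///   %mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %pad_value
///         {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
///   } : vector<2x4xi1> -> vector<2x4xf32>
///   %res = vector.transfer_write %read, %init[%c0, %c0]
///       {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
/// ```
/// When `inputVectorSizes` does not match the static result shape, the write
/// is additionally wrapped in a `vector.mask` built from the reified result
/// dimensions.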
1392 static LogicalResult
1393 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1394  ArrayRef<int64_t> inputVectorSizes,
1395  SmallVectorImpl<Value> &newResults) {
1396  auto padValue = padOp.getConstantPaddingValue();
1397  Location loc = padOp.getLoc();
1398  int64_t rank = inputVectorSizes.size();
1399  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
1400  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
1401 
1402  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1403  OpBuilder::InsertionGuard g(rewriter);
1404  rewriter.setInsertionPoint(padOp);
1405 
1406  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1407  LogicalResult status =
1408  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1409  .reifyResultShapes(rewriter, reifiedReturnShapes);
1410  (void)status; // prevent unused variable warning on non-assert builds
1411  assert(succeeded(status) && "failed to reify result shapes");
1412  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
1413  padValue.getType());
1414  SmallVector<OpFoldResult> mixedSourceDims =
1415  tensor::getMixedSizes(rewriter, loc, padOp.getSource());
1416  Value mask =
1417  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1418  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1419  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1420  loc,
1421  /*vectorType=*/vectorType,
1422  /*source=*/padOp.getSource(),
1423  /*indices=*/SmallVector<Value>(rank, zero),
1424  /*padding=*/padValue,
1425  /*inBounds=*/SmallVector<bool>(rank, true));
1426  auto maskedOp = cast<vector::MaskOp>(
1427  mlir::vector::maskOperation(rewriter, transferReadOp, mask));
1428  Operation *write = rewriter.create<vector::TransferWriteOp>(
1429  loc,
1430  /*vector=*/maskedOp->getResult(0),
1431  /*source=*/emptyOp,
1432  /*indices=*/SmallVector<Value>(rank, zero),
1433  /*inBounds=*/SmallVector<bool>(rank, true));
1434  bool needMaskForWrite = llvm::any_of(
1435  llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
1436  [](auto it) { return std::get<0>(it) != std::get<1>(it); });
1437  if (needMaskForWrite) {
1438  Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
1439  loc, maskType, reifiedReturnShapes[0]);
1440  write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
1441  }
1442  newResults.push_back(write->getResult(0));
1443  return success();
1444 }
1445 
1446 // TODO: probably need some extra checks for reduction followed by consumer
1447 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1448 static LogicalResult reductionPreconditions(LinalgOp op) {
1449  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1450  LDBG("reduction precondition failed: no reduction iterator\n");
1451  return failure();
1452  }
1453  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1454  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1455  if (indexingMap.isPermutation())
1456  continue;
1457 
1458  Operation *reduceOp = matchLinalgReduction(&opOperand);
1459  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1460  LDBG("reduction precondition failed: reduction detection failed\n");
1461  return failure();
1462  }
1463  }
1464  return success();
1465 }
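// As a hedged illustration, a generic such as the following satisfies these
// preconditions: the init map (d0, d1) -> (d0) is not a permutation and the
// body reduces with a recognizable `arith.addf` combiner.
// ```
//   %red = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0)>],
//       iterator_types = ["parallel", "reduction"]}
//       ins(%in : tensor<8x16xf32>) outs(%acc : tensor<8xf32>) {
//     ^bb0(%a: f32, %b: f32):
//       %s = arith.addf %a, %b : f32
//       linalg.yield %s : f32
//   } -> tensor<8xf32>
// ```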
1466 
1467 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
1468  // TODO: Masking only supports dynamic generic ops for now.
1469  if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp,
1470  linalg::ContractionOpInterface>(op.getOperation()))
1471  return failure();
1472 
1473  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1474  return success();
1475 }
1476 
1477 /// Returns success if `inputVectorSizes` is a valid masking configuration for
1478 /// given `shape`, i.e., it meets:
1479 /// 1. The numbers of elements in both array are equal.
1480 /// 2. `inputVectorSizes` does not have dynamic dimensions.
1481 /// 3. All the values in `inputVectorSizes` are greater than or equal to
1482 /// static sizes in `shape`.
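/// For example (illustrative only): with `shape = [?, 32]`, input vector sizes
/// `[8, 32]` and `[4, 64]` are accepted, whereas `[8]` (rank mismatch) and
/// `[8, 16]` (16 < 32) are rejected.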
1483 static LogicalResult
1484 isValidMaskedInputVector(ArrayRef<int64_t> shape,
1485  ArrayRef<int64_t> inputVectorSizes) {
1486  LDBG("Iteration space static sizes:");
1487  LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
1488  LLVM_DEBUG(llvm::dbgs() << "\n");
1489 
1490  if (inputVectorSizes.size() != shape.size()) {
1491  LDBG("Input vector sizes don't match the number of loops");
1492  return failure();
1493  }
1494  if (ShapedType::isDynamicShape(inputVectorSizes)) {
1495  LDBG("Input vector sizes can't have dynamic dimensions");
1496  return failure();
1497  }
1498  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
1499  [](std::tuple<int64_t, int64_t> sizePair) {
1500  int64_t staticSize = std::get<0>(sizePair);
1501  int64_t inputSize = std::get<1>(sizePair);
1502  return ShapedType::isDynamic(staticSize) ||
1503  staticSize <= inputSize;
1504  })) {
1505  LDBG("Input vector sizes must be greater than or equal to iteration space "
1506  "static sizes");
1507  return failure();
1508  }
1509  return success();
1510 }
1511 
1512 static LogicalResult
1513 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
1514  ArrayRef<int64_t> inputVectorSizes,
1515  bool vectorizeNDExtract) {
1516  // A tensor with a dimension of size 0 cannot be vectorized.
1517  if (llvm::any_of(linalgOp.getStaticShape(),
1518  [](int64_t dim) { return dim == 0; }))
1519  return failure();
1520  // Check API contract for input vector sizes.
1521  if (!inputVectorSizes.empty() &&
1522  failed(isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1523  inputVectorSizes)))
1524  return failure();
1525 
1526  if (linalgOp.hasDynamicShape() &&
1527  failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
1528  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1529  return failure();
1530  }
1531 
1532  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1533 
1534  // Register CustomVectorizationPrecondition for extractOp.
1535  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1536 
1537  // All types in the body should be a supported element type for VectorType.
1538  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1539  // Check if any custom hook can vectorize the inner op.
1540  if (llvm::any_of(
1541  customPreconditions,
1542  [&](const CustomVectorizationPrecondition &customPrecondition) {
1543  return succeeded(
1544  customPrecondition(&innerOp, vectorizeNDExtract));
1545  })) {
1546  continue;
1547  }
1548  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1549  return !VectorType::isValidElementType(type);
1550  })) {
1551  return failure();
1552  }
1553  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1554  return !VectorType::isValidElementType(type);
1555  })) {
1556  return failure();
1557  }
1558  }
1559  if (isElementwise(linalgOp))
1560  return success();
1561  // TODO: isaConvolutionOpInterface that can also infer from generic features.
1562  // But we will still need stride/dilation attributes that will be annoying to
1563  // reverse-engineer...
1564  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1565  return success();
1566  // TODO: the common vector shape is equal to the static loop sizes only when
1567  // all indexing maps are projected permutations. For convs and stencils the
1568  // logic will need to evolve.
1569  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1570  LDBG("precondition failed: not projected permutations\n");
1571  return failure();
1572  }
1573  if (failed(reductionPreconditions(linalgOp))) {
1574  LDBG("precondition failed: reduction preconditions\n");
1575  return failure();
1576  }
1577  return success();
1578 }
1579 
1580 static LogicalResult
1581 vectorizePadOpPrecondition(tensor::PadOp padOp,
1582  ArrayRef<int64_t> inputVectorSizes) {
1583  auto padValue = padOp.getConstantPaddingValue();
1584  if (!padValue) {
1585  LDBG("pad value is not constant: " << padOp << "\n");
1586  return failure();
1587  }
1588 
1589  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1590  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
1591  return failure();
1592 
1593  if (llvm::any_of(padOp.getLow(), [](Value v) {
1594  std::optional<int64_t> res = getConstantIntValue(v);
1595  return !res.has_value() || res.value() != 0;
1596  })) {
1597  LDBG("low pad must all be zero: " << padOp << "\n");
1598  return failure();
1599  }
1600 
1601  return success();
1602 }
1603 
1604 /// Preconditions for scalable vectors.
1605 static LogicalResult
1606 vectorizeScalableVectorPrecondition(Operation *op,
1607  ArrayRef<int64_t> inputVectorSizes,
1608  ArrayRef<bool> inputScalableVecDims) {
1609  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1610  "Number of input vector sizes and scalable dims doesn't match");
1611 
1612  if (inputVectorSizes.empty())
1613  return success();
1614 
1615  bool isScalable = inputScalableVecDims.back();
1616  if (!isScalable)
1617  return success();
1618 
1619  // Only element-wise ops supported in the presence of scalable dims.
1620  auto linalgOp = dyn_cast<LinalgOp>(op);
1621  return success(linalgOp && isElementwise(linalgOp));
1622 }
1623 
1624 LogicalResult mlir::linalg::vectorizeOpPrecondition(
1625  Operation *op, ArrayRef<int64_t> inputVectorSizes,
1626  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
1627  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
1628  inputScalableVecDims)))
1629  return failure();
1630 
1631  return TypeSwitch<Operation *, LogicalResult>(op)
1632  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1633  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1634  vectorizeNDExtract);
1635  })
1636  .Case<tensor::PadOp>([&](auto padOp) {
1637  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
1638  })
1639  .Default([](auto) { return failure(); });
1640 }
1641 
1642 /// Converts affine.apply Ops to arithmetic operations.
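/// As a rough illustration, an op such as
/// `affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 4)>(%i)[%j]` inside the
/// linalg body is expanded into plain `arith.muli`/`arith.addi` on index
/// values, so the per-op vectorization below only ever sees arith ops.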
1643 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
1644  OpBuilder::InsertionGuard g(rewriter);
1645  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
1646 
1647  for (auto op : make_early_inc_range(toReplace)) {
1648  rewriter.setInsertionPoint(op);
1649  auto expanded = affine::expandAffineExpr(
1650  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
1651  op.getOperands().take_front(op.getAffineMap().getNumDims()),
1652  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
1653  rewriter.replaceOp(op, expanded);
1654  }
1655 }
1656 
1657 /// Emit a suitable vector form for an operation. If provided,
1658 /// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
1659 /// must match the rank of the iteration space of the operation and the input
1660 /// vector sizes must be greater than or equal to their counterpart iteration
1661 /// space sizes, if static. `inputVectorShapes` also allows the vectorization of
1662 /// operations with dynamic shapes.
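/// A typical driver-side call might look as follows; this is a sketch only,
/// and `linalgOp`, `ctx` and the chosen vector sizes are assumptions of the
/// example:
/// ```
///   IRRewriter rewriter(ctx);
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false},
///                                /*vectorizeNDExtract=*/false)))
///     return failure();
/// ```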
1663 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
1664  ArrayRef<int64_t> inputVectorSizes,
1665  ArrayRef<bool> inputScalableVecDims,
1666  bool vectorizeNDExtract) {
1667  LDBG("Attempting to vectorize:\n" << *op << "\n");
1668  LDBG("Input vector sizes: ");
1669  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
1670  LLVM_DEBUG(llvm::dbgs() << "\n");
1671  LDBG("Input scalable vector dims: ");
1672  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
1673  LLVM_DEBUG(llvm::dbgs() << "\n");
1674 
1675  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
1676  vectorizeNDExtract))) {
1677  LDBG("Vectorization pre-conditions failed\n");
1678  return failure();
1679  }
1680 
1681  // Initialize vectorization state.
1682  VectorizationState state(rewriter);
1683  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
1684  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
1685  inputScalableVecDims))) {
1686  LDBG("Vectorization state couldn't be initialized\n");
1687  return failure();
1688  }
1689  }
1690 
1691  SmallVector<Value> results;
1692  auto vectorizeResult =
1693  TypeSwitch<Operation *, LogicalResult>(op)
1694  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1695  // TODO: isaConvolutionOpInterface that can also infer from generic
1696  // features. Will require stride/dilation attributes inference.
1697  FailureOr<Operation *> convOr =
1698  vectorizeConvolution(rewriter, linalgOp);
1699  if (succeeded(convOr)) {
1700  llvm::append_range(results, (*convOr)->getResults());
1701  return success();
1702  }
1703 
1704  LDBG("Vectorize generic by broadcasting to the canonical vector "
1705  "shape\n");
1706 
1707  // Pre-process before proceeding.
1708  convertAffineApply(rewriter, linalgOp);
1709 
1710  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
1711  // to 'OpBuilder' when it is passed over to some methods like
1712  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
1713  // erase an op within these methods, the actual rewriter won't be
1714  // notified and we will end up with read-after-free issues!
1715  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
1716  })
1717  .Case<tensor::PadOp>([&](auto padOp) {
1718  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
1719  results);
1720  })
1721  .Default([](auto) { return failure(); });
1722 
1723  if (failed(vectorizeResult)) {
1724  LDBG("Vectorization failed\n");
1725  return failure();
1726  }
1727 
1728  if (!results.empty())
1729  rewriter.replaceOp(op, results);
1730  else
1731  rewriter.eraseOp(op);
1732 
1733  return success();
1734 }
1735 
1736 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
1737  memref::CopyOp copyOp) {
1738 
1739  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
1740  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
1741  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
1742  return failure();
1743 
1744  auto srcElementType = getElementTypeOrSelf(srcType);
1745  auto dstElementType = getElementTypeOrSelf(dstType);
1746  if (!VectorType::isValidElementType(srcElementType) ||
1747  !VectorType::isValidElementType(dstElementType))
1748  return failure();
1749 
1750  auto readType = VectorType::get(srcType.getShape(), srcElementType);
1751  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
1752 
1753  Location loc = copyOp->getLoc();
1754  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1755  SmallVector<Value> indices(srcType.getRank(), zero);
1756 
1757  Value readValue = rewriter.create<vector::TransferReadOp>(
1758  loc, readType, copyOp.getSource(), indices,
1759  rewriter.getMultiDimIdentityMap(srcType.getRank()));
1760  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
1761  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1762  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
1763  }
1764  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
1765  loc, readValue, copyOp.getTarget(), indices,
1766  rewriter.getMultiDimIdentityMap(srcType.getRank()));
1767  rewriter.replaceOp(copyOp, writeValue->getResults());
1768  return success();
1769 }
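// As a hedged example of the rewrite above, a static copy such as
// `memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>` becomes,
// roughly:
// ```
//   %c0 = arith.constant 0 : index
//   %v = vector.transfer_read %src[%c0, %c0], %cst
//       : memref<4x8xf32>, vector<4x8xf32>
//   vector.transfer_write %v, %dst[%c0, %c0]
//       : vector<4x8xf32>, memref<4x8xf32>
// ```
// where %cst is the zero padding value implicitly created for the read.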
1770 
1771 //----------------------------------------------------------------------------//
1772 // Misc. vectorization patterns.
1773 //----------------------------------------------------------------------------//
1774 
1775 /// Helper function that retrieves the value of an IntegerAttr.
1776 static int64_t getIntFromAttr(Attribute attr) {
1777  return cast<IntegerAttr>(attr).getInt();
1778 }
1779 
1780 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
1781 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
1782 /// not supported.
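/// For instance (illustrative only), `{%v, IntegerAttr 4}` becomes
/// `{%v, arith.constant 4 : index}`.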
1783 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
1784  ArrayRef<OpFoldResult> ofrs) {
1785  SmallVector<Value> result;
1786  for (auto o : ofrs) {
1787  if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
1788  result.push_back(val);
1789  } else {
1790  result.push_back(rewriter.create<arith::ConstantIndexOp>(
1791  loc, getIntFromAttr(o.template get<Attribute>())));
1792  }
1793  }
1794  return result;
1795 }
1796 
1797 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
1798 /// InsertSliceOp. For now, only constant padding values are supported.
1799 /// If there is enough static type information, TransferReadOps and
1800 /// TransferWriteOps may be generated instead of InsertSliceOps.
1801 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
1802  GenericPadOpVectorizationPattern(MLIRContext *context,
1803  PatternBenefit benefit = 1)
1804  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
1805  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
1806  /// each dimension size is statically known in the source type or the result
1807  /// type (or both).
1808  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
1809  tensor::PadOp padOp, Value dest) {
1810  auto sourceType = padOp.getSourceType();
1811  auto resultType = padOp.getResultType();
1812  if (!VectorType::isValidElementType(sourceType.getElementType()))
1813  return failure();
1814 
1815  // Copy cannot be vectorized if pad value is non-constant and source shape
1816  // is dynamic. In case of a dynamic source shape, padding must be appended
1817  // by TransferReadOp, but TransferReadOp supports only constant padding.
1818  auto padValue = padOp.getConstantPaddingValue();
1819  if (!padValue) {
1820  if (!sourceType.hasStaticShape())
1821  return failure();
1822  // Create dummy padding value.
1823  auto elemType = sourceType.getElementType();
1824  padValue = rewriter.create<arith::ConstantOp>(
1825  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
1826  }
1827 
1828  SmallVector<int64_t> vecShape;
1829  SmallVector<bool> readInBounds;
1830  SmallVector<bool> writeInBounds;
1831  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
1832  if (!sourceType.isDynamicDim(i)) {
1833  vecShape.push_back(sourceType.getDimSize(i));
1834  // Source shape is statically known: Neither read nor write are
1835  // out-of-bounds.
1836  readInBounds.push_back(true);
1837  writeInBounds.push_back(true);
1838  } else if (!resultType.isDynamicDim(i)) {
1839  // Source shape is not statically known, but result shape is.
1840  // Vectorize with size of result shape. This may be larger than the
1841  // source size.
1842  vecShape.push_back(resultType.getDimSize(i));
1843  // Read may be out-of-bounds because the result size could be larger
1844  // than the source size.
1845  readInBounds.push_back(false);
1846  // Write is out-of-bounds if low padding > 0.
1847  writeInBounds.push_back(
1848  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
1849  static_cast<int64_t>(0));
1850  } else {
1851  // Neither source nor result dim of padOp is static. Cannot vectorize
1852  // the copy.
1853  return failure();
1854  }
1855  }
1856  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
1857 
1858  // Generate TransferReadOp.
1859  SmallVector<Value> readIndices(
1860  vecType.getRank(),
1861  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
1862  auto read = rewriter.create<vector::TransferReadOp>(
1863  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
1864  ArrayRef<bool>{readInBounds});
1865 
1866  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
1867  // entire tensor, write directly to the FillOp's operand.
1868  if (llvm::equal(vecShape, resultType.getShape()) &&
1869  llvm::all_of(writeInBounds, [](bool b) { return b; }))
1870  if (auto fill = dest.getDefiningOp<FillOp>())
1871  dest = fill.output();
1872 
1873  // Generate TransferWriteOp.
1874  auto writeIndices =
1875  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
1876  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1877  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
1878 
1879  return success();
1880  }
1881 };
1882 
1883 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
1884 /// given operation type OpTy.
1885 template <typename OpTy>
1886 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
1887  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
1888 
1889  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1890  PatternRewriter &rewriter) const final {
1891  bool changed = false;
1892  // Collect the users in a vector first, because some of them may be replaced or removed.
1893  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
1894  if (auto op = dyn_cast<OpTy>(user))
1895  changed |= rewriteUser(rewriter, padOp, op).succeeded();
1896  return success(changed);
1897  }
1898 
1899 protected:
1900  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
1901  tensor::PadOp padOp, OpTy op) const = 0;
1902 };
1903 
1904 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
1905 /// ```
1906 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
1907 /// %r = vector.transfer_read %0[%c0, %c0], %cst
1908 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
1909 /// ```
1910 /// is rewritten to:
1911 /// ```
1912 /// %r = vector.transfer_read %src[%c0, %c0], %padding
1913 /// {in_bounds = [true, true]}
1914 /// : tensor<?x?xf32>, vector<17x5xf32>
1915 /// ```
1916 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
1917 /// sure that the original padding value %cst was never used.
1918 ///
1919 /// This rewrite is possible if:
1920 /// - `xferOp` has no out-of-bounds dims or mask.
1921 /// - Low padding is static 0.
1922 /// - Single, scalar padding value.
1923 struct PadOpVectorizationWithTransferReadPattern
1924  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
1925  using VectorizePadOpUserPattern<
1926  vector::TransferReadOp>::VectorizePadOpUserPattern;
1927 
1928  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1929  vector::TransferReadOp xferOp) const override {
1930  // Low padding must be static 0.
1931  if (!padOp.hasZeroLowPad())
1932  return failure();
1933  // Pad value must be a constant.
1934  auto padValue = padOp.getConstantPaddingValue();
1935  if (!padValue)
1936  return failure();
1937  // Padding value of existing `xferOp` is unused.
1938  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
1939  return failure();
1940 
1941  rewriter.updateRootInPlace(xferOp, [&]() {
1942  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
1943  xferOp->setAttr(xferOp.getInBoundsAttrName(),
1944  rewriter.getBoolArrayAttr(inBounds));
1945  xferOp.getSourceMutable().assign(padOp.getSource());
1946  xferOp.getPaddingMutable().assign(padValue);
1947  });
1948 
1949  return success();
1950  }
1951 };
1952 
1953 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
1954 /// This pattern rewrites TransferWriteOps that write to a padded tensor
1955 /// value, where the same amount of padding is immediately removed again after
1956 /// the write. In such cases, the TransferWriteOp can write to the non-padded
1957 /// tensor value and apply out-of-bounds masking. E.g.:
1958 /// ```
1959 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
1960 /// : tensor<...> to tensor<?x?xf32>
1961 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
1962 /// %2 = vector.transfer_write %vec, %1[...]
1963 /// : vector<17x5xf32>, tensor<17x5xf32>
1964 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
1965 /// : tensor<17x5xf32> to tensor<?x?xf32>
1966 /// ```
1967 /// is rewritten to:
1968 /// ```
1969 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
1970 /// : tensor<...> to tensor<?x?xf32>
1971 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
1972 /// tensor<?x?xf32>
1973 /// ```
1974 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
1975 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
1976 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
1977 /// from %r's old dimensions.
1978 ///
1979 /// This rewrite is possible if:
1980 /// - Low padding is static 0.
1981 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
1982 /// ExtractSliceOp trims the same amount of padding that was added
1983 /// beforehand.
1984 /// - Single, scalar padding value.
1985 struct PadOpVectorizationWithTransferWritePattern
1986  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
1987  using VectorizePadOpUserPattern<
1988  vector::TransferWriteOp>::VectorizePadOpUserPattern;
1989 
1990  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1991  vector::TransferWriteOp xferOp) const override {
1992  // TODO: support 0-d corner case.
1993  if (xferOp.getTransferRank() == 0)
1994  return failure();
1995 
1996  // Low padding must be static 0.
1997  if (!padOp.hasZeroLowPad())
1998  return failure();
1999  // Pad value must be a constant.
2000  auto padValue = padOp.getConstantPaddingValue();
2001  if (!padValue)
2002  return failure();
2003  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2004  if (!xferOp->hasOneUse())
2005  return failure();
2006  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2007  if (!trimPadding)
2008  return failure();
2009  // Only static zero offsets supported when trimming padding.
2010  if (!trimPadding.hasZeroOffset())
2011  return failure();
2012  // trimPadding must remove the amount of padding that was added earlier.
2013  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2014  return failure();
2015 
2016  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2017  rewriter.setInsertionPoint(xferOp);
2018 
2019  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2020  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2021  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2022  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2023  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2024  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2025 
2026  return success();
2027  }
2028 
2029  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2030  /// i.e., same dimensions.
2031  ///
2032  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2033  /// dimensions, this function tries to infer the (static) tensor size by
2034  /// looking at the defining op and utilizing op-specific knowledge.
2035  ///
2036  /// This is a conservative analysis. In case equal tensor sizes cannot be
2037  /// proven statically, this analysis returns `false` even though the tensor
2038  /// sizes may turn out to be equal at runtime.
2039  bool hasSameTensorSize(Value beforePadding,
2040  tensor::ExtractSliceOp afterTrimming) const {
2041  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2042  // result and CastOp operand.
2043  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2044  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2045  return true;
2046 
2047  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2048  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2049  // Only RankedTensorType supported.
2050  if (!t1 || !t2)
2051  return false;
2052  // Rank of both values must be the same.
2053  if (t1.getRank() != t2.getRank())
2054  return false;
2055 
2056  // All static dimensions must be the same. Mixed cases (e.g., dimension
2057  // static in `t1` but dynamic in `t2`) are not supported.
2058  for (unsigned i = 0; i < t1.getRank(); ++i) {
2059  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2060  return false;
2061  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2062  return false;
2063  }
2064 
2065  // Nothing more to check if all dimensions are static.
2066  if (t1.getNumDynamicDims() == 0)
2067  return true;
2068 
2069  // All dynamic sizes must be the same. The only supported case at the
2070  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2071  // thereof).
2072 
2073  // Apart from CastOp, only ExtractSliceOp is supported.
2074  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2075  if (!beforeSlice)
2076  return false;
2077 
2078  assert(static_cast<size_t>(t1.getRank()) ==
2079  beforeSlice.getMixedSizes().size());
2080  assert(static_cast<size_t>(t2.getRank()) ==
2081  afterTrimming.getMixedSizes().size());
2082 
2083  for (unsigned i = 0; i < t1.getRank(); ++i) {
2084  // Skip static dimensions.
2085  if (!t1.isDynamicDim(i))
2086  continue;
2087  auto size1 = beforeSlice.getMixedSizes()[i];
2088  auto size2 = afterTrimming.getMixedSizes()[i];
2089 
2090  // Case 1: Same value or same constant int.
2091  if (isEqualConstantIntOrValue(size1, size2))
2092  continue;
2093 
2094  // Other cases: Take a deeper look at defining ops of values.
2095  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2096  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2097  if (!v1 || !v2)
2098  return false;
2099 
2100  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2101  // CSE is run.)
2102  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2103  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2104  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2105  minOp1.getOperands() == minOp2.getOperands())
2106  continue;
2107 
2108  // Add additional cases as needed.
2109  }
2110 
2111  // All tests passed.
2112  return true;
2113  }
2114 };
2115 
2116 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2117 /// ```
2118 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2119 /// %r = tensor.insert_slice %0
2120 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2121 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2122 /// ```
2123 /// is rewritten to:
2124 /// ```
2125 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2126 /// : tensor<?x?xf32>, vector<17x5xf32>
2127 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2128 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2129 /// ```
2130 ///
2131 /// This rewrite is possible if:
2132 /// - Low padding is static 0.
2133 /// - `padOp` result shape is static.
2134 /// - The entire padded tensor is inserted.
2135 /// (Implies that sizes of `insertOp` are all static.)
2136 /// - Only unit strides in `insertOp`.
2137 /// - Single, scalar padding value.
2138 /// - `padOp` result not used as destination.
2139 struct PadOpVectorizationWithInsertSlicePattern
2140  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2141  using VectorizePadOpUserPattern<
2142  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2143 
2144  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2145  tensor::InsertSliceOp insertOp) const override {
2146  // Low padding must be static 0.
2147  if (!padOp.hasZeroLowPad())
2148  return failure();
2149  // Only unit stride supported.
2150  if (!insertOp.hasUnitStride())
2151  return failure();
2152  // Pad value must be a constant.
2153  auto padValue = padOp.getConstantPaddingValue();
2154  if (!padValue)
2155  return failure();
2156  // Dynamic shapes not supported.
2157  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2158  return failure();
2159  // Pad result not used as destination.
2160  if (insertOp.getDest() == padOp.getResult())
2161  return failure();
2162 
2163  auto vecType = VectorType::get(padOp.getType().getShape(),
2164  padOp.getType().getElementType());
2165  unsigned vecRank = vecType.getRank();
2166  unsigned tensorRank = insertOp.getType().getRank();
2167 
2168  // Check if sizes match: Insert the entire tensor into most minor dims.
2169  // (No permutations allowed.)
2170  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2171  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2172  if (!llvm::all_of(
2173  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2174  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2175  }))
2176  return failure();
2177 
2178  // Insert the TransferReadOp and TransferWriteOp at the position of the
2179  // InsertSliceOp.
2180  rewriter.setInsertionPoint(insertOp);
2181 
2182  // Generate TransferReadOp: Read entire source tensor and add high
2183  // padding.
2184  SmallVector<Value> readIndices(
2185  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2186  auto read = rewriter.create<vector::TransferReadOp>(
2187  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2188 
2189  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2190  // specified offsets. Write is fully in-bounds because a InsertSliceOp's
2191  // source must fit into the destination at the specified offsets.
2192  auto writeIndices =
2193  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2194  SmallVector<bool> inBounds(vecRank, true);
2195  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2196  insertOp, read, insertOp.getDest(), writeIndices,
2197  ArrayRef<bool>{inBounds});
2198 
2199  return success();
2200  }
2201 };
2202 
2203 void mlir::linalg::populatePadOpVectorizationPatterns(
2204  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2205  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2206  baseBenefit);
2207  // Try these specialized patterns first before resorting to the generic one.
2208  patterns.add<PadOpVectorizationWithTransferReadPattern,
2209  PadOpVectorizationWithTransferWritePattern,
2210  PadOpVectorizationWithInsertSlicePattern>(
2211  patterns.getContext(), baseBenefit.getBenefit() + 1);
2212 }
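// A sketch of how these patterns are typically consumed; the `funcOp` handle
// is an assumption of the example:
// ```
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return failure();
// ```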
2213 
2214 //----------------------------------------------------------------------------//
2215 // Forwarding patterns
2216 //----------------------------------------------------------------------------//
2217 
2218 /// Check whether there is any interleaved use of any `values` between
2219 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2220 /// is in a different block.
2221 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2222  ValueRange values) {
2223  if (firstOp->getBlock() != secondOp->getBlock() ||
2224  !firstOp->isBeforeInBlock(secondOp)) {
2225  LDBG("interleavedUses precondition failed, firstOp: "
2226  << *firstOp << ", second op: " << *secondOp << "\n");
2227  return true;
2228  }
2229  for (auto v : values) {
2230  for (auto &u : v.getUses()) {
2231  Operation *owner = u.getOwner();
2232  if (owner == firstOp || owner == secondOp)
2233  continue;
2234  // TODO: this is too conservative, use dominance info in the future.
2235  if (owner->getBlock() == firstOp->getBlock() &&
2236  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2237  continue;
2238  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2239  << ", second op: " << *secondOp << "\n");
2240  return true;
2241  }
2242  }
2243  return false;
2244 }
2245 
2246 /// Return the unique subview use of `v` if it is indeed unique, null
2247 /// otherwise.
2248 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2249  memref::SubViewOp subViewOp;
2250  for (auto &u : v.getUses()) {
2251  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2252  if (subViewOp)
2253  return memref::SubViewOp();
2254  subViewOp = newSubViewOp;
2255  }
2256  }
2257  return subViewOp;
2258 }
2259 
2260 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2261 /// when available.
2262 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2263  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2264 
2265  // TODO: support mask.
2266  if (xferOp.getMask())
2267  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2268 
2269  // Transfer into `view`.
2270  Value viewOrAlloc = xferOp.getSource();
2271  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2272  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2273  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2274 
2275  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2276  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2277  if (!subViewOp)
2278  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2279  Value subView = subViewOp.getResult();
2280 
2281  // Find the copy into `subView` without interleaved uses.
2282  memref::CopyOp copyOp;
2283  for (auto &u : subView.getUses()) {
2284  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2285  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2286  if (newCopyOp.getTarget() != subView)
2287  continue;
2288  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2289  continue;
2290  copyOp = newCopyOp;
2291  break;
2292  }
2293  }
2294  if (!copyOp)
2295  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2296 
2297  // Find the fill into `viewOrAlloc` without interleaved uses before the
2298  // copy.
2299  FillOp maybeFillOp;
2300  for (auto &u : viewOrAlloc.getUses()) {
2301  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2302  assert(isa<MemRefType>(newFillOp.output().getType()));
2303  if (newFillOp.output() != viewOrAlloc)
2304  continue;
2305  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2306  continue;
2307  maybeFillOp = newFillOp;
2308  break;
2309  }
2310  }
2311  // Ensure padding matches.
2312  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2313  return rewriter.notifyMatchFailure(xferOp,
2314  "padding value does not match fill");
2315 
2316  // `in` is the subview that memref.copy reads. Replace it.
2317  Value in = copyOp.getSource();
2318 
2319  // memref.copy + linalg.fill can be used to create a padded local buffer.
2320  // The `masked` attribute is only valid on this padded buffer.
2321  // When forwarding to vector.transfer_read, the attribute must be reset
2322  // conservatively.
2323  Value res = rewriter.create<vector::TransferReadOp>(
2324  xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2325  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2326  // in_bounds is explicitly reset
2327  /*inBoundsAttr=*/ArrayAttr());
2328 
2329  if (maybeFillOp)
2330  rewriter.eraseOp(maybeFillOp);
2331  rewriter.eraseOp(copyOp);
2332  rewriter.replaceOp(xferOp, res);
2333 
2334  return success();
2335 }
2336 
2337 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2338 /// when available.
2339 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2340  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2341  // TODO: support mask.
2342  if (xferOp.getMask())
2343  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2344 
2345  // Transfer into `viewOrAlloc`.
2346  Value viewOrAlloc = xferOp.getSource();
2347  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2348  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2349  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2350 
2351  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2352  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2353  if (!subViewOp)
2354  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2355  Value subView = subViewOp.getResult();
2356 
2357  // Find the copy from `subView` without interleaved uses.
2358  memref::CopyOp copyOp;
2359  for (auto &u : subViewOp.getResult().getUses()) {
2360  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2361  if (newCopyOp.getSource() != subView)
2362  continue;
2363  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2364  continue;
2365  copyOp = newCopyOp;
2366  break;
2367  }
2368  }
2369  if (!copyOp)
2370  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2371 
2372  // `out` is the target memref that `subView` is copied into; the
2372  // vector.transfer_write is forwarded to it.
2373  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2374  Value out = copyOp.getTarget();
2375 
2376  // Forward vector.transfer into copy.
2377  // memref.copy + linalg.fill can be used to create a padded local buffer.
2378  // The `masked` attribute is only valid on this padded buffer.
2379  // When forwarding to vector.transfer_write, the attribute must be reset
2380  // conservatively.
2381  rewriter.create<vector::TransferWriteOp>(
2382  xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2383  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2384  // in_bounds is explicitly reset
2385  /*inBoundsAttr=*/ArrayAttr());
2386 
2387  rewriter.eraseOp(copyOp);
2388  rewriter.eraseOp(xferOp);
2389 
2390  return success();
2391 }
2392 
2393 //===----------------------------------------------------------------------===//
2394 // Convolution vectorization patterns
2395 //===----------------------------------------------------------------------===//
2396 
2397 template <int N>
2398 static void bindShapeDims(ShapedType shapedType) {}
2399 
2400 template <int N, typename IntTy, typename... IntTy2>
2401 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2402  val = shapedType.getShape()[N];
2403  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2404 }
2405 
2406 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2407 template <typename... IntTy>
2408 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2409  bindShapeDims<0>(shapedType, vals...);
2410 }
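// For example, given a `memref<3x5x7xf32>` shaped type,
// `bindShapeDims(shapedType, n, w, c)` assigns n = 3, w = 5 and c = 7.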
2411 
2412 namespace {
2413 bool isCastOfBlockArgument(Operation *op) {
2414  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2415  isa<BlockArgument>(op->getOperand(0));
2416 }
2417 
2418 bool isSupportedPoolKind(vector::CombiningKind kind) {
2419  switch (kind) {
2420  case vector::CombiningKind::ADD:
2421  case vector::CombiningKind::MAXF:
2422  case vector::CombiningKind::MAXIMUMF:
2423  case vector::CombiningKind::MAXSI:
2424  case vector::CombiningKind::MAXUI:
2425  case vector::CombiningKind::MINF:
2426  case vector::CombiningKind::MINIMUMF:
2427  case vector::CombiningKind::MINSI:
2428  case vector::CombiningKind::MINUI:
2429  return true;
2430  default:
2431  return false;
2432  }
2433 }
2434 
2435 /// Generate a vector implementation for either:
2436 /// ```
2437 /// Op def: ( w, kw )
2438 /// Iters: ({Par(), Red()})
2439 /// Layout: {{w + kw}, {kw}, {w}}
2440 /// ```
2441 /// kw is unrolled.
2442 ///
2443 /// or
2444 ///
2445 /// ```
2446 /// Op def: ( n, w, c, kw, f )
2447 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2448 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2449 /// ```
2450 /// kw is unrolled, w is unrolled iff dilationW > 1.
2451 ///
2452 /// or
2453 ///
2454 /// ```
2455 /// Op def: ( n, c, w, f, kw )
2456 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2457 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2458 /// ```
2459 /// kw is unrolled, w is unrolled iff dilationW > 1.
2460 ///
2461 /// or
2462 ///
2463 /// ```
2464 /// Op def: ( n, w, c, kw )
2465 /// Iters: ({Par(), Par(), Par(), Red()})
2466 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2467 /// ```
2468 /// kw is unrolled, w is unrolled iff dilationW > 1.
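/// As an illustrative (non-exhaustive) mapping to named ops: `linalg.conv_1d`
/// matches the {w, kw} case, `linalg.conv_1d_nwc_wcf` the {n, w, c, kw, f}
/// case, `linalg.conv_1d_ncw_fcw` the {n, c, w, f, kw} case, and
/// `linalg.depthwise_conv_1d_nwc_wc` the {n, w, c, kw} case; pooling ops such
/// as `linalg.pooling_nwc_sum` reuse the same layouts with a rank-1 kernel.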
2469 struct Conv1DGenerator
2470  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2471  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2472  int dilationW)
2473  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2474  strideW(strideW), dilationW(dilationW) {
2475  // Determine whether `linalgOp` can be generated with this generator
2476  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2477  return;
2478  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2479  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2480  resShaped = linalgOp.getDpsInitOperand(0)->get();
2481  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2482  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2483  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2484  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2485  return;
2486  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2487  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2488  if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
2489  (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
2490  return;
2491 
2492  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2493  if (!reduceOp)
2494  return;
2495  redOp = reduceOp->getName().getIdentifier();
2496 
2497  if (!setOperKind(reduceOp))
2498  return;
2499  auto maybeKind = getCombinerOpKind(reduceOp);
2500  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2501  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2502  return;
2503  }
2504 
2505  auto rhsRank = rhsShapedType.getRank();
2506  switch (oper) {
2507  case Conv:
2508  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2509  return;
2510  break;
2511  case Pool:
2512  if (rhsRank != 1)
2513  return;
2514  break;
2515  }
2516  // The op is now known to be valid.
2517  valid = true;
2518  }
2519 
2520  /// Generate a vector implementation for:
2521  /// ```
2522  /// Op def: ( w, kw )
2523  /// Iters: ({Par(), Red()})
2524  /// Layout: {{w + kw}, {kw}, {w}}
2525  /// ```
2526  /// kw is always unrolled.
2527  ///
2528  /// or
2529  ///
2530  /// ```
2531  /// Op def: ( n, w, c, kw, f )
2532  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2533  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2534  /// ```
2535  /// kw is always unrolled.
2536  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
2537  /// > 1.
2538  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2539  if (!valid)
2540  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2541 
2542  int64_t nSize, wSize, cSize, kwSize, fSize;
2543  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
2544  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
2545  switch (conv1DOpOrder) {
2546  case Conv1DOpOrder::W:
2547  // Initialize unused dimensions
2548  nSize = fSize = cSize = 0;
2549  // out{W}
2550  bindShapeDims(resShapedType, wSize);
2551  // kernel{kw}
2552  bindShapeDims(rhsShapedType, kwSize);
2553  lhsShape = {// iw = ow + kw - 1
2554  // (i.e. 16 convolved with 3 -> 14)
2555  (wSize + kwSize - 1)};
2556  rhsShape = {kwSize};
2557  resShape = {wSize};
2558  break;
2559  case Conv1DOpOrder::Nwc:
2560  // out{n, w, f}
2561  bindShapeDims(resShapedType, nSize, wSize, fSize);
2562  switch (oper) {
2563  case Conv:
2564  // kernel{kw, c, f}
2565  bindShapeDims(rhsShapedType, kwSize, cSize);
2566  break;
2567  case Pool:
2568  // kernel{kw}
2569  bindShapeDims(rhsShapedType, kwSize);
2570  cSize = fSize;
2571  break;
2572  }
2573  lhsShape = {nSize,
2574  // iw = ow * sw + kw * dw - 1
2575  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2576  // Perform the proper inclusive -> exclusive -> inclusive.
2577  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2578  1,
2579  cSize};
2580  switch (oper) {
2581  case Conv:
2582  rhsShape = {kwSize, cSize, fSize};
2583  break;
2584  case Pool:
2585  rhsShape = {kwSize};
2586  break;
2587  }
2588  resShape = {nSize, wSize, fSize};
2589  break;
2590  case Conv1DOpOrder::Ncw:
2591  // out{n, f, w}
2592  bindShapeDims(resShapedType, nSize, fSize, wSize);
2593  switch (oper) {
2594  case Conv:
2595  // kernel{f, c, kw}
2596  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
2597  break;
2598  case Pool:
2599  // kernel{kw}
2600  bindShapeDims(rhsShapedType, kwSize);
2601  cSize = fSize;
2602  break;
2603  }
2604  lhsShape = {nSize, cSize,
2605  // iw = ow * sw + kw * dw - 1
2606  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2607  // Perform the proper inclusive -> exclusive -> inclusive.
2608  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2609  1};
2610  switch (oper) {
2611  case Conv:
2612  rhsShape = {fSize, cSize, kwSize};
2613  break;
2614  case Pool:
2615  rhsShape = {kwSize};
2616  break;
2617  }
2618  resShape = {nSize, fSize, wSize};
2619  break;
2620  }
2621 
2622  vector::TransferWriteOp write;
2623  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2624 
2625  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2626  // When strideW == 1, we can batch the contiguous loads and avoid
2627  // unrolling
2628  int64_t wSizeStep = strideW == 1 ? wSize : 1;
2629 
2630  Type lhsEltType = lhsShapedType.getElementType();
2631  Type rhsEltType = rhsShapedType.getElementType();
2632  Type resEltType = resShapedType.getElementType();
2633  auto lhsType = VectorType::get(lhsShape, lhsEltType);
2634  auto rhsType = VectorType::get(rhsShape, rhsEltType);
2635  auto resType = VectorType::get(resShape, resEltType);
2636  // Zero padding with the corresponding dimensions for lhs, rhs and res.
2637  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
2638  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
2639  SmallVector<Value> resPadding(resShape.size(), zero);
2640 
2641  // Read the whole lhs, rhs and res in one shot (with zero padding).
2642  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2643  lhsPadding);
2644  // This is needed only for Conv.
2645  Value rhs = nullptr;
2646  if (oper == Conv)
2647  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2648  rhsPadding);
2649  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
2650  resPadding);
2651 
2652  // The base vectorization case for channeled convolution is input: {n,w,c},
2653  // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
2654  // vectorization case, we do pre transpose on input, weight, and output.
2655  switch (conv1DOpOrder) {
2656  case Conv1DOpOrder::W:
2657  case Conv1DOpOrder::Nwc:
2658  // Base case, so no transposes necessary.
2659  break;
2660  case Conv1DOpOrder::Ncw: {
2661  // To match base vectorization case, we pre-transpose current case.
2662  // ncw -> nwc
2663  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
2664  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
2665  // fcw -> wcf
2666  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
2667 
2668  // This is needed only for Conv.
2669  if (oper == Conv)
2670  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
2671  // nfw -> nwf
2672  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
2673  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
2674  break;
2675  }
2676  }
2677 
2678  //===------------------------------------------------------------------===//
2679  // Begin vector-only rewrite part
2680  //===------------------------------------------------------------------===//
2681  // Unroll along kw and read slices of lhs and rhs.
2682  SmallVector<Value> lhsVals, rhsVals, resVals;
2683  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
2684  kwSize, strideW, dilationW, wSizeStep,
2685  isSingleChanneled);
2686  // Do not do for pooling.
2687  if (oper == Conv)
2688  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
2689  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
2690  wSizeStep, isSingleChanneled);
2691 
2692  auto linearIndex = [&](int64_t kw, int64_t w) {
2693  return kw * (wSize / wSizeStep) + w;
2694  };
2695 
2696  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
2697  // perform outerproduct for non-channeled convolution or
2698  // perform simple arith operation for pooling
2699  for (int64_t kw = 0; kw < kwSize; ++kw) {
2700  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2701  switch (oper) {
2702  case Conv:
2703  if (isSingleChanneled) {
2704  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
2705  lhsVals[linearIndex(kw, w)],
2706  rhsVals[kw], resVals[w]);
2707  } else {
2708  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
2709  lhsVals[linearIndex(kw, w)],
2710  rhsVals[kw], resVals[w]);
2711  }
2712  break;
2713  case Pool:
2714  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
2715  resVals[w]);
2716  break;
2717  }
2718  }
2719  }
2720 
2721  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
2722  isSingleChanneled);
2723  //===------------------------------------------------------------------===//
2724  // End vector-only rewrite part
2725  //===------------------------------------------------------------------===//
2726 
2727  // The base vectorization case for channeled convolution is output: {n,w,f}
2728  // To reuse the result from base pattern vectorization case, we post
2729  // transpose the base case result.
2730  switch (conv1DOpOrder) {
2731  case Conv1DOpOrder::W:
2732  case Conv1DOpOrder::Nwc:
2733  // Base case, so no transposes necessary.
2734  break;
2735  case Conv1DOpOrder::Ncw: {
2736  // nwf -> nfw
2737  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
2738  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
2739  break;
2740  }
2741  }
2742 
2743  return rewriter
2744  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
2745  .getOperation();
2746  }
2747 
2748  // Take a value and widen it to have the same element type as `ty`.
2749  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
2750  const Type srcElementType = getElementTypeOrSelf(val.getType());
2751  const Type dstElementType = getElementTypeOrSelf(ty);
2752  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
2753  if (srcElementType == dstElementType)
2754  return val;
2755 
2756  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
2757  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
2758  const Type dstType =
2759  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
2760 
2761  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
2762  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
2763  }
2764 
2765  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
2766  srcWidth < dstWidth)
2767  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
2768 
2769  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
2770  srcWidth < dstWidth)
2771  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
2772 
2773  assert(false && "unhandled promotion case");
2774  return nullptr;
2775  }
2776 
2777  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
2778  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
2779  Value lhs, Value rhs, Value res) {
2780  vector::IteratorType par = vector::IteratorType::parallel;
2781  vector::IteratorType red = vector::IteratorType::reduction;
2782  AffineExpr n, w, f, c;
2783  bindDims(ctx, n, w, f, c);
2784  lhs = promote(rewriter, loc, lhs, res.getType());
2785  rhs = promote(rewriter, loc, rhs, res.getType());
2786  return rewriter.create<vector::ContractionOp>(
2787  loc, lhs, rhs, res,
2788  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
2789  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
2790  }
2791 
2792  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
2793  // convolution.
2794  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
2795  Value lhs, Value rhs, Value res) {
2796  return rewriter.create<vector::OuterProductOp>(
2797  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
2798  }
2799 
2800  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
2801  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
2802  Value res) {
2803  if (isPoolExt)
2804  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
2805  return rewriter
2806  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
2807  ->getResult(0);
2808  }
2809 
2810  /// Generate a vector implementation for:
2811  /// ```
2812  /// Op def: ( n, w, c, kw)
2813  /// Iters: ({Par(), Par(), Par(), Red()})
2814  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2815  /// ```
2816  /// kw is always unrolled.
2817  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
2818  /// > 1.
2819  FailureOr<Operation *> depthwiseConv() {
2820  if (!valid)
2821  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
2822 
2823  int64_t nSize, wSize, cSize, kwSize;
2824  // kernel{kw, c}
2825  bindShapeDims(rhsShapedType, kwSize, cSize);
2826  // out{n, w, c}
2827  bindShapeDims(resShapedType, nSize, wSize);
2828 
2829  vector::TransferWriteOp write;
2830  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2831 
2832  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2833  // When strideW == 1, we can batch the contiguous loads and avoid
 2834  // unrolling.
2835  int64_t wSizeStep = strideW == 1 ? wSize : 1;
2836 
2837  Type lhsEltType = lhsShapedType.getElementType();
2838  Type rhsEltType = rhsShapedType.getElementType();
2839  Type resEltType = resShapedType.getElementType();
2840  VectorType lhsType = VectorType::get(
2841  {nSize,
 2842  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
2843  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2844  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
2845  cSize},
2846  lhsEltType);
2847  VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
2848  VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
2849 
2850  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
2851  // 0].
2852  Value lhs = rewriter.create<vector::TransferReadOp>(
2853  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
2854  // Read rhs slice of size {kw, c} @ [0, 0].
2855  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2856  ValueRange{zero, zero});
2857  // Read res slice of size {n, w, c} @ [0, 0, 0].
2858  Value res = rewriter.create<vector::TransferReadOp>(
2859  loc, resType, resShaped, ValueRange{zero, zero, zero});
2860 
2861  //===------------------------------------------------------------------===//
2862  // Begin vector-only rewrite part
2863  //===------------------------------------------------------------------===//
2864  // Unroll along kw and read slices of lhs and rhs.
2865  SmallVector<Value> lhsVals, rhsVals, resVals;
2866  // Extract lhs slice of size {n, wSizeStep, c}
2867  // @ [0, sw * w + dw * kw, 0].
2868  for (int64_t kw = 0; kw < kwSize; ++kw) {
2869  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2870  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
2871  loc, lhs,
2872  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
2873  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
2874  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
2875  }
2876  }
2877  // Extract rhs slice of size {c} @ [kw].
2878  for (int64_t kw = 0; kw < kwSize; ++kw) {
2879  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
2880  loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
2881  }
2882  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
2883  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2884  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
2885  loc, res,
2886  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
2887  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
2888  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
2889  }
2890 
2891  auto linearIndex = [&](int64_t kw, int64_t w) {
2892  return kw * (wSize / wSizeStep) + w;
2893  };
2894 
2895  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
2896  for (int64_t kw = 0; kw < kwSize; ++kw) {
2897  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2898  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
2899  lhsVals[linearIndex(kw, w)],
2900  rhsVals[kw], resVals[w]);
2901  }
2902  }
2903 
 2904  // It's possible we failed to create the FMA.
2905  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
2906  // Manually revert (in reverse order) to avoid leaving a bad IR state.
2907  for (auto &collection :
2908  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
2909  for (Value v : collection)
2910  rewriter.eraseOp(v.getDefiningOp());
2911  return rewriter.notifyMatchFailure(op, "failed to create FMA");
2912  }
2913 
2914  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
2915  // This does not depend on kw.
2916  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2917  res = rewriter.create<vector::InsertStridedSliceOp>(
2918  loc, resVals[w], res,
2919  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
2920  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
2921  }
2922  //===------------------------------------------------------------------===//
2923  // End vector-only rewrite part
2924  //===------------------------------------------------------------------===//
2925 
2926  // Write back res slice of size {n, w, c} @ [0, 0, 0].
2927  return rewriter
2928  .create<vector::TransferWriteOp>(loc, res, resShaped,
2929  ValueRange{zero, zero, zero})
2930  .getOperation();
2931  }
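 // [Editorial sketch, not part of the upstream source; shapes are invented.]
 // For a depthwise conv with input memref<1x16x4xf32>, filter memref<3x4xf32>,
 // output memref<1x14x4xf32>, stride 1 and dilation 1, the rewrite above
 // roughly produces:
 //   %lhs = vector.transfer_read %in[%c0, %c0, %c0], %cst
 //        : memref<1x16x4xf32>, vector<1x16x4xf32>
 //   %rhs = vector.transfer_read %flt[%c0, %c0], %cst
 //        : memref<3x4xf32>, vector<3x4xf32>
 //   %res = vector.transfer_read %out[%c0, %c0, %c0], %cst
 //        : memref<1x14x4xf32>, vector<1x14x4xf32>
 // followed by one vector.extract_strided_slice of %lhs per kw, one
 // vector.extract of %rhs per kw, one multiply-accumulate per (kw, w) pair,
 // and a final vector.transfer_write of the re-inserted result slices.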
2932 
2933  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
2934  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
2935  Value lhs, Value rhs, Value res) {
2936  auto rhsTy = cast<ShapedType>(rhs.getType());
2937  auto resTy = cast<ShapedType>(res.getType());
2938 
2939  // TODO(suderman): Change this to use a vector.ima intrinsic.
2940  lhs = promote(rewriter, loc, lhs, resTy);
2941 
2942  rhs = rewriter.create<vector::BroadcastOp>(
2943  loc, resTy.clone(rhsTy.getElementType()), rhs);
2944  rhs = promote(rewriter, loc, rhs, resTy);
2945 
2946  if (!lhs || !rhs)
2947  return nullptr;
2948 
2949  if (isa<FloatType>(resTy.getElementType()))
2950  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
2951 
2952  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
2953  return rewriter.create<arith::AddIOp>(loc, mul, res);
2954  }
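 // [Editorial sketch, not part of the upstream source; types are invented.]
 // With f32 elements the slice update above becomes:
 //   %b = vector.broadcast %filter_kw : vector<4xf32> to vector<1x14x4xf32>
 //   %r = vector.fma %lhs, %b, %res : vector<1x14x4xf32>
 // whereas i8 inputs accumulated into i32 are first promoted and then lowered
 // to:
 //   %m = arith.muli %lhs_i32, %b_i32 : vector<1x14x4xi32>
 //   %r = arith.addi %m, %res : vector<1x14x4xi32>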
2955 
2956  /// Entry point for non-channeled convolution:
2957  /// {{w + kw}, {kw}, {w}}
2958  FailureOr<Operation *> generateNonChanneledConv() {
2959  AffineExpr w, kw;
2960  bindDims(ctx, w, kw);
2961  if (!iters({Par(), Red()}))
2962  return rewriter.notifyMatchFailure(op,
2963  "failed to match conv::W 1-par 1-red");
2964 
2965  // No transposition needed.
2966  if (layout({/*lhsIndex*/ {w + kw},
2967  /*rhsIndex*/ {kw},
2968  /*resIndex*/ {w}}))
2969  return conv(Conv1DOpOrder::W);
2970 
2971  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
2972  }
2973 
2974  /// Entry point that transposes into the common form:
2975  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2976  FailureOr<Operation *> generateNwcConv() {
2977  AffineExpr n, w, f, kw, c;
2978  bindDims(ctx, n, w, f, kw, c);
2979  if (!iters({Par(), Par(), Par(), Red(), Red()}))
2980  return rewriter.notifyMatchFailure(
2981  op, "failed to match conv::Nwc 3-par 2-red");
2982 
2983  // No transposition needed.
2984  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
2985  /*rhsIndex*/ {kw, c, f},
2986  /*resIndex*/ {n, w, f}}))
2987  return conv(Conv1DOpOrder::Nwc);
2988 
2989  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
2990  }
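 // [Editorial sketch, not part of the upstream source; shapes are invented.]
 // A typical op matched by this layout is linalg.conv_1d_nwc_wcf, e.g.:
 //   linalg.conv_1d_nwc_wcf
 //     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
 //     ins(%in, %flt : memref<1x16x4xf32>, memref<3x4x8xf32>)
 //     outs(%out : memref<1x14x8xf32>)
 // whose indexing maps correspond to {n, sw * w + dw * kw, c} / {kw, c, f} /
 // {n, w, f}.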
2991 
2992  /// Entry point that transposes into the common form:
2993  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2994  FailureOr<Operation *> generateNcwConv() {
2995  AffineExpr n, w, f, kw, c;
2996  bindDims(ctx, n, f, w, c, kw);
2997  if (!iters({Par(), Par(), Par(), Red(), Red()}))
2998  return rewriter.notifyMatchFailure(
2999  op, "failed to match conv::Ncw 3-par 2-red");
3000 
3001  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3002  /*rhsIndex*/ {f, c, kw},
3003  /*resIndex*/ {n, f, w}}))
3004  return conv(Conv1DOpOrder::Ncw);
3005 
3006  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3007  }
3008 
3009  /// Entry point that transposes into the common form:
3010  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3011  FailureOr<Operation *> generateNwcPooling() {
3012  AffineExpr n, w, c, kw;
3013  bindDims(ctx, n, w, c, kw);
3014  if (!iters({Par(), Par(), Par(), Red()}))
3015  return rewriter.notifyMatchFailure(op,
3016  "failed to match pooling 3-par 1-red");
3017 
3018  // No transposition needed.
3019  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3020  /*rhsIndex*/ {kw},
3021  /*resIndex*/ {n, w, c}}))
3022  return conv(Conv1DOpOrder::Nwc);
3023 
3024  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3025  }
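 // [Editorial sketch, not part of the upstream source; shapes are invented.]
 // A typical op matched by this layout is linalg.pooling_nwc_max, e.g.:
 //   linalg.pooling_nwc_max
 //     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
 //     ins(%in, %window : memref<1x16x4xf32>, memref<3xf32>)
 //     outs(%out : memref<1x14x4xf32>)
 // where the window operand only carries the {kw} shape.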
3026 
3027  /// Entry point that transposes into the common form:
3028  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3029  FailureOr<Operation *> generateNcwPooling() {
3030  AffineExpr n, w, c, kw;
3031  bindDims(ctx, n, c, w, kw);
3032  if (!iters({Par(), Par(), Par(), Red()}))
3033  return rewriter.notifyMatchFailure(op,
3034  "failed to match pooling 3-par 1-red");
3035 
3036  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3037  /*rhsIndex*/ {kw},
3038  /*resIndex*/ {n, c, w}}))
3039  return conv(Conv1DOpOrder::Ncw);
3040 
3041  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3042  }
3043 
3044  /// Entry point that transposes into the common form:
3045  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3046  FailureOr<Operation *> generateDilatedConv() {
3047  AffineExpr n, w, c, kw;
3048  bindDims(ctx, n, w, c, kw);
3049  if (!iters({Par(), Par(), Par(), Red()}))
3050  return rewriter.notifyMatchFailure(
3051  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3052 
3053  // No transposition needed.
3054  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3055  /*rhsIndex*/ {kw, c},
3056  /*resIndex*/ {n, w, c}}))
3057  return depthwiseConv();
3058 
3059  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3060  }
3061 
3062 private:
3063  enum OperKind { Conv, Pool };
3064  bool valid = false;
3065  OperKind oper = Conv;
3066  StringAttr redOp;
3067  StringAttr poolExtOp;
3068  bool isPoolExt = false;
3069  int strideW, dilationW;
3070  Value lhsShaped, rhsShaped, resShaped;
3071  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3072 
3073  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3074  // Returns true iff it is a valid conv/pooling op.
 3075  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
 3076  // reduction + yield) and rhs is not used, then it is the body of a pooling.
 3077  // If conv, check for a single `mul` predecessor. The `mul` operands must be
 3078  // block arguments or extensions of block arguments.
 3079  // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
 3080  // must be block arguments or extensions of block arguments.
3081  bool setOperKind(Operation *reduceOp) {
3082  int numBlockArguments = llvm::count_if(
3083  reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
3084  switch (numBlockArguments) {
3085  case 1: {
3086  // Will be convolution if feeder is a MulOp.
 3087  // Otherwise, check whether it can be pooling.
3088  auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
3089  return !isa<BlockArgument>(v);
3090  });
3091  Operation *feedOp = (*feedValIt).getDefiningOp();
3092  if (isCastOfBlockArgument(feedOp)) {
3093  oper = Pool;
3094  isPoolExt = true;
3095  poolExtOp = feedOp->getName().getIdentifier();
3096  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3097  llvm::all_of(feedOp->getOperands(), [](Value v) {
3098  if (isa<BlockArgument>(v))
3099  return true;
3100  if (Operation *op = v.getDefiningOp())
3101  return isCastOfBlockArgument(op);
3102  return false;
3103  }))) {
3104  return false;
3105  }
3106  return true;
3107  }
3108  case 2:
3109  // Must be pooling
3110  oper = Pool;
3111  isPoolExt = false;
3112  return true;
3113  default:
3114  return false;
3115  }
3116  }
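 // [Editorial sketch, not part of the upstream source.] The two body shapes
 // recognized by setOperKind look roughly like:
 //   // Convolution: the reduction has one block-argument operand; the other
 //   // operand is fed by a mul of block arguments (case 1).
 //   ^bb0(%in: f32, %flt: f32, %acc: f32):
 //     %m = arith.mulf %in, %flt : f32
 //     %s = arith.addf %acc, %m : f32
 //     linalg.yield %s : f32
 //   // Pooling: both reduction operands are block arguments and the filter
 //   // is shape-only (case 2).
 //   ^bb0(%in: f32, %win: f32, %acc: f32):
 //     %mx = arith.maximumf %acc, %in : f32
 //     linalg.yield %mx : f32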
3117 };
3118 } // namespace
3119 
3120 /// Helper function to vectorize a LinalgOp with convolution semantics.
3121 // TODO: extend the generic vectorization to support windows and drop this.
 3122 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
 3123                                                     LinalgOp op) {
 3124  // The ConvolutionOpInterface gives us guarantees of existence for
 3125  // strides/dilations. However, we do not need to rely on those; we can
 3126  // simply use them if present, otherwise use the default, and let the
 3127  // generic conv matcher in the Conv1DGenerator succeed or fail.
3128  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3129  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3130  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3131  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3132  Conv1DGenerator e(rewriter, op, stride, dilation);
3133  auto res = e.generateNonChanneledConv();
3134  if (succeeded(res))
3135  return res;
3136  res = e.generateNwcConv();
3137  if (succeeded(res))
3138  return res;
3139  res = e.generateNcwConv();
3140  if (succeeded(res))
3141  return res;
3142  res = e.generateNwcPooling();
3143  if (succeeded(res))
3144  return res;
3145  res = e.generateNcwPooling();
3146  if (succeeded(res))
3147  return res;
3148  return e.generateDilatedConv();
3149 }
3150 
 3151 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
 3152  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3153 
 3154  LogicalResult matchAndRewrite(LinalgOp op,
 3155                                PatternRewriter &rewriter) const override {
3156  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3157  if (failed(resultOrFail))
3158  return failure();
3159  Operation *newOp = *resultOrFail;
3160  if (newOp->getNumResults() == 0) {
3161  rewriter.eraseOp(op.getOperation());
3162  return success();
3163  }
3164  assert(newOp->getNumResults() == 1 && "expected single result");
3165  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3166  return success();
3167  }
3168 };
3169 
 3170 void mlir::linalg::populateConvolutionVectorizationPatterns(
 3171     RewritePatternSet &patterns, PatternBenefit benefit) {
3172  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3173 }
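 // [Editorial usage sketch, not part of the upstream source.] A pass that
 // wants these rewrites would typically populate and apply them with the
 // greedy driver, e.g.:
 //   RewritePatternSet patterns(funcOp.getContext());
 //   linalg::populateConvolutionVectorizationPatterns(patterns);
 //   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
 //     signalPassFailure();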