1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58                      ArrayRef<int64_t> inputVecSizes = {},
59                      ArrayRef<bool> inputVecScalableFlags = {},
60                      bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if none or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66  OpType res;
67  block.walk([&](OpType op) {
68  if (res) {
69  res = nullptr;
70  return WalkResult::interrupt();
71  }
72  res = op;
73  return WalkResult::advance();
74  });
75  return res;
76 }
77 
78 /// Helper function to extract the input slices after filter is unrolled along
79 /// kw.
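/// For illustration only (sizes are hypothetical, not taken from a test): with
/// nSize = 1, wSize = 4, cSize = 3, kwSize = 2, strideW = 1, dilationW = 1 and
/// wSizeStep = 4, the channeled path extracts one {1, 4, 3} slice per kw, at
/// offsets [0, w * strideW + kw * dilationW, 0]:
///
/// ```
/// %0 = vector.extract_strided_slice %input
///        {offsets = [0, 0, 0], sizes = [1, 4, 3], strides = [1, 1, 1]}
///        : vector<1x5x3xf32> to vector<1x4x3xf32>
/// %1 = vector.extract_strided_slice %input
///        {offsets = [0, 1, 0], sizes = [1, 4, 3], strides = [1, 1, 1]}
///        : vector<1x5x3xf32> to vector<1x4x3xf32>
/// ```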
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85  SmallVector<Value> result;
86  if (isSingleChanneled) {
87  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88  // convolution.
89  SmallVector<int64_t> sizes{wSizeStep};
90  SmallVector<int64_t> strides{1};
91  for (int64_t kw = 0; kw < kwSize; ++kw) {
92  for (int64_t w = 0; w < wSize; w += wSizeStep) {
93  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95  }
96  }
97  } else {
98  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99  // for channeled convolution.
100  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101  SmallVector<int64_t> strides{1, 1, 1};
102  for (int64_t kw = 0; kw < kwSize; ++kw) {
103  for (int64_t w = 0; w < wSize; w += wSizeStep) {
104  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105  loc, input,
106  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107  sizes, strides));
108  }
109  }
110  }
111  return result;
112 }
113 
114 /// Helper function to extract the filter slices after filter is unrolled along
115 /// kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119  SmallVector<Value> result;
120  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  result.push_back(rewriter.create<vector::ExtractOp>(
124  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125  }
126  return result;
127 }
128 
129 /// Helper function to extract the result slices after filter is unrolled along
130 /// kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134                         int64_t wSizeStep, bool isSingleChanneled) {
135  SmallVector<Value> result;
136  if (isSingleChanneled) {
137  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138  SmallVector<int64_t> sizes{wSizeStep};
139  SmallVector<int64_t> strides{1};
140  for (int64_t w = 0; w < wSize; w += wSizeStep) {
141  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143  }
144  } else {
145  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146  // convolution.
147  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148  SmallVector<int64_t> strides{1, 1, 1};
149  for (int64_t w = 0; w < wSize; w += wSizeStep) {
150  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152  }
153  }
154  return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160                                     SmallVectorImpl<Value> &resVals,
161                                     bool isSingleChanneled) {
162 
163  if (isSingleChanneled) {
164  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165  // This does not depend on kw.
166  SmallVector<int64_t> strides{1};
167  for (int64_t w = 0; w < wSize; w += wSizeStep) {
168  res = rewriter.create<vector::InsertStridedSliceOp>(
169  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170  }
171  } else {
172  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173  // convolution. This does not depend on kw.
174  SmallVector<int64_t> strides{1, 1, 1};
175  for (int64_t w = 0; w < wSize; w += wSizeStep) {
176  res = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178  strides);
179  }
180  }
181  return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189  /// Initializes the vectorization state, including the computation of the
190  /// canonical vector shape for vectorization.
191  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192  ArrayRef<int64_t> inputVectorSizes,
193  ArrayRef<bool> inputScalableVecDims);
194 
195  /// Returns the canonical vector shape used to vectorize the iteration space.
196  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198  /// Returns a vector type of the provided `elementType` with the canonical
199  /// vector shape and the corresponding fixed/scalable dimensions bit. If
200  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
201  /// accordingly.
202   VectorType getCanonicalVecType(
203       Type elementType,
204       std::optional<AffineMap> dimPermutation = std::nullopt) const {
205     SmallVector<int64_t> vectorShape;
206     SmallVector<bool> scalableDims;
207  if (dimPermutation.has_value()) {
208  vectorShape =
209  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
210  scalableDims =
211  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
212  } else {
213  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
214  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
215  }
216 
217  return VectorType::get(vectorShape, elementType, scalableDims);
218  }
219 
220  /// Masks an operation with the canonical vector mask if the operation needs
221  /// masking. Returns the masked operation or the original operation if masking
222  /// is not needed. If provided, the canonical mask for this operation is
223  /// permuted using `maybeMaskingMap`.
224  Operation *
225  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
226  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
227 
228 private:
229  /// Initializes the iteration space static sizes using the Linalg op
230  /// information. This may become more complicated in the future.
231  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
232  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
233  }
234 
235  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
236  /// all the static and dynamic dimensions of the iteration space to be
237  /// vectorized and store them in `iterSpaceValueSizes`.
238  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
239  LinalgOp linalgOp);
240 
241  /// Create or retrieve an existing mask value to mask `opToMask` in the
242   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
243   /// mask is permuted using that permutation map. If a new mask is created, it
244   /// will be cached for future users.
245  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
246  LinalgOp linalgOp,
247  std::optional<AffineMap> maybeMaskingMap);
248 
249  // Holds the compile-time static sizes of the iteration space to vectorize.
250  // Dynamic dimensions are represented using ShapedType::kDynamic.
251  SmallVector<int64_t> iterSpaceStaticSizes;
252 
253  /// Holds the value sizes of the iteration space to vectorize. Static
254  /// dimensions are represented by 'arith.constant' and dynamic
255  /// dimensions by 'tensor/memref.dim'.
256  SmallVector<Value> iterSpaceValueSizes;
257 
258  /// Holds the canonical vector shape used to vectorize the iteration space.
259  SmallVector<int64_t> canonicalVecShape;
260 
261  /// Holds the vector dimensions that are scalable in the canonical vector
262  /// shape.
263  SmallVector<bool> scalableVecDims;
264 
265  /// Holds the active masks for permutations of the canonical vector iteration
266  /// space.
267  DenseMap<AffineMap, Value> activeMaskCache;
268 
269  /// Global vectorization guard for the incoming rewriter. It's initialized
270  /// when the vectorization state is initialized.
271   OpBuilder::InsertionGuard rewriterGuard;
272 };
273 
274 LogicalResult
275 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276  LinalgOp linalgOp) {
277  // TODO: Support 0-d vectors.
278  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
279  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
280  // Create constant index op for static dimensions.
281  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
282  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
283  continue;
284  }
285 
286  // Find an operand defined on this dimension of the iteration space to
287  // extract the runtime dimension size.
288  Value operand;
289  unsigned operandDimPos;
290  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
291  operandDimPos)))
292  return failure();
293 
294  Value dynamicDim = linalgOp.hasPureTensorSemantics()
295  ? (Value)rewriter.create<tensor::DimOp>(
296  linalgOp.getLoc(), operand, operandDimPos)
297  : (Value)rewriter.create<memref::DimOp>(
298  linalgOp.getLoc(), operand, operandDimPos);
299  iterSpaceValueSizes.push_back(dynamicDim);
300  }
301 
302  return success();
303 }
304 
305 /// Initializes the vectorization state, including the computation of the
306 /// canonical vector shape for vectorization.
307 // TODO: Move this to the constructor when we can remove the failure cases.
308 LogicalResult
309 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
310  ArrayRef<int64_t> inputVectorSizes,
311  ArrayRef<bool> inputScalableVecDims) {
312  // Initialize the insertion point.
313  rewriter.setInsertionPoint(linalgOp);
314 
315  if (!inputVectorSizes.empty()) {
316  // Get the canonical vector shape from the input vector sizes provided. This
317  // path should be taken to vectorize code with dynamic shapes and when using
318  // vector sizes greater than the iteration space sizes.
319  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
320  scalableVecDims.append(inputScalableVecDims.begin(),
321  inputScalableVecDims.end());
322  } else {
323  // Compute the canonical vector shape from the operation shape. If there are
324  // dynamic shapes, the operation won't be vectorized. We assume all the
325  // vector dimensions are fixed.
326  canonicalVecShape = linalgOp.getStaticLoopRanges();
327  scalableVecDims.append(linalgOp.getNumLoops(), false);
328  }
329 
330  LDBG("Canonical vector shape: ");
331  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
332  LLVM_DEBUG(llvm::dbgs() << "\n");
333  LDBG("Scalable vector dims: ");
334  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
335  LLVM_DEBUG(llvm::dbgs() << "\n");
336 
337  if (ShapedType::isDynamicShape(canonicalVecShape))
338  return failure();
339 
340  // Initialize iteration space static sizes.
341  initIterSpaceStaticSizes(linalgOp);
342 
343  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
344  // all the static and dynamic dimensions of the iteration space, needed to
345  // compute a mask during vectorization.
346  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
347  return failure();
348 
349  return success();
350 }
351 
352 /// Create or retrieve an existing mask value to mask `opToMask` in the
353 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
354 /// is permuted using that permutation map. If a new mask is created, it will be
355 /// cached for future users.
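/// As a sketch (illustrative only): for a canonical 8x16 iteration space with
/// runtime sizes %d0 and %d1 and an identity masking map, the cached mask is
/// built as:
///
/// ```
/// %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
/// ```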
356 Value VectorizationState::getOrCreateMaskFor(
357  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
358  std::optional<AffineMap> maybeMaskingMap) {
359  // No mask is needed if the operation is not maskable.
360  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
361  if (!maskableOp)
362  return Value();
363 
364  assert(!maskableOp.isMasked() &&
365  "Masking an operation that is already masked");
366 
367  // If no masking map was provided, use an identity map with the loop dims.
368  assert((!maybeMaskingMap || *maybeMaskingMap) &&
369  "Unexpected null mask permutation map");
370  AffineMap maskingMap =
371       maybeMaskingMap ? *maybeMaskingMap
372                       : AffineMap::getMultiDimIdentityMap(
373                             linalgOp.getNumLoops(), rewriter.getContext());
374 
375  LDBG("Masking map: " << maskingMap << "\n");
376 
377  // Return the active mask for the masking map of this operation if it was
378  // already created.
379  auto activeMaskIt = activeMaskCache.find(maskingMap);
380  if (activeMaskIt != activeMaskCache.end()) {
381  Value mask = activeMaskIt->second;
382  LDBG("Reusing mask: " << mask << "\n");
383  return mask;
384  }
385 
386  // Compute permuted projection of the iteration space to be masked and the
387  // corresponding mask shape. If the resulting iteration space dimensions are
388  // static and identical to the mask shape, masking is not needed for this
389  // operation.
390  // TODO: Improve this check. Only projected permutation indexing maps are
391  // supported.
392  SmallVector<int64_t> permutedStaticSizes =
393  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
394  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
395  auto maskShape = maskType.getShape();
396 
397  LDBG("Mask shape: ");
398  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
399  LLVM_DEBUG(llvm::dbgs() << "\n");
400 
401  if (permutedStaticSizes == maskShape) {
402  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
403  activeMaskCache[maskingMap] = Value();
404  return Value();
405  }
406 
407  // Permute the iteration space value sizes to compute the mask upper bounds.
408  SmallVector<Value> upperBounds =
409  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
410  assert(!maskShape.empty() && !upperBounds.empty() &&
411  "Masked 0-d vectors are not supported yet");
412 
413  // Create the mask based on the dimension values.
414  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
415  maskType, upperBounds);
416  LDBG("Creating new mask: " << mask << "\n");
417  activeMaskCache[maskingMap] = mask;
418  return mask;
419 }
420 
421 /// Masks an operation with the canonical vector mask if the operation needs
422 /// masking. Returns the masked operation or the original operation if masking
423 /// is not needed. If provided, the canonical mask for this operation is
424 /// permuted using `maybeMaskingMap`.
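/// For example (illustrative only, assuming a 1-D canonical mask %mask),
/// masking a transfer read yields:
///
/// ```
/// %read = vector.mask %mask {
///   vector.transfer_read %src[%c0], %pad : tensor<?xf32>, vector<8xf32>
/// } : vector<8xi1> -> vector<8xf32>
/// ```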
425 Operation *
426 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
427                                   LinalgOp linalgOp,
428                                   std::optional<AffineMap> maybeMaskingMap) {
429  LDBG("Trying to mask: " << *opToMask << "\n");
430 
431  // Create or retrieve mask for this operation.
432  Value mask =
433  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
434 
435  if (!mask) {
436  LDBG("No mask required\n");
437  return opToMask;
438  }
439 
440  // Wrap the operation with a new `vector.mask` and update D-U chain.
441  assert(opToMask && "Expected a valid operation to mask");
442  auto maskOp = cast<vector::MaskOp>(
443  mlir::vector::maskOperation(rewriter, opToMask, mask));
444  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
445 
446  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
447  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
448  maskOpTerminator);
449 
450  LDBG("Masked operation: " << *maskOp << "\n");
451  return maskOp;
452 }
453 
454 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
455 /// projectedPermutation, compress the unused dimensions to serve as a
456 /// permutation_map for a vector transfer operation.
457 /// For example, given a linalg op such as:
458 ///
459 /// ```
460 /// %0 = linalg.generic {
461 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
462 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
463 /// }
464 /// ins(%0 : tensor<2x3x4xf32>)
465 /// outs(%1 : tensor<5x6xf32>)
466 /// ```
467 ///
468 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
469 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
470 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
471 static AffineMap reindexIndexingMap(AffineMap map) {
472   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
473  "expected projected permutation");
474  auto res = compressUnusedDims(map);
475  assert(res.getNumDims() == res.getNumResults() &&
476  "expected reindexed map with same number of dims and results");
477  return res;
478 }
479 
480 /// Helper enum to represent conv1d input traversal order.
481 enum class Conv1DOpOrder {
482  W, // Corresponds to non-channeled 1D convolution operation.
483  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
484  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
485 };
486 
487 /// Helper data structure to represent the result of vectorization.
488 /// In certain specific cases, like terminators, we do not want to propagate/fold the results.
489 enum VectorizationStatus {
490  /// Op failed to vectorize.
491  Failure = 0,
492   /// Op vectorized and custom function took care of replacement logic.
493   NoReplace,
494  /// Op vectorized into a new Op whose results will replace original Op's
495  /// results.
496  NewOp
497  // TODO: support values if Op vectorized to Many-Ops whose results we need to
498  // aggregate for replacement.
499 };
500 struct VectorizationResult {
501   /// Return status from vectorizing the current op.
502   enum VectorizationStatus status;
503  /// New vectorized operation to replace the current op.
504  /// Replacement behavior is specified by `status`.
505   Operation *newOp;
506 };
507 
508 std::optional<vector::CombiningKind>
509 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
510   using ::mlir::vector::CombiningKind;
511 
512  if (!combinerOp)
513  return std::nullopt;
514   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
515       .Case<arith::AddIOp, arith::AddFOp>(
516  [&](auto op) { return CombiningKind::ADD; })
517  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
518  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
519  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
520  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
521  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
522  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
523  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
524  .Case<arith::MulIOp, arith::MulFOp>(
525  [&](auto op) { return CombiningKind::MUL; })
526  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
527  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
528  .Default([&](auto op) { return std::nullopt; });
529 }
530 
531 /// Check whether `outputOperand` is a reduction with a single combiner
532 /// operation. Return the combiner operation of the reduction. Return
533 /// nullptr otherwise. Multiple reduction operations would impose an
534 /// ordering between reduction dimensions, which is currently unsupported in
535 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
536 /// max(min(X)).
537 // TODO: use in LinalgOp verification, there is a circular dependency atm.
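// For example (illustrative), the following output operand matches with
// `arith.addf` as the single combiner:
//
//   %red = linalg.generic {
//            indexing_maps = [affine_map<(d0) -> (d0)>,
//                             affine_map<(d0) -> ()>],
//            iterator_types = ["reduction"]}
//            ins(%in : tensor<8xf32>) outs(%init : tensor<f32>) {
//   ^bb0(%a: f32, %b: f32):
//     %s = arith.addf %a, %b : f32
//     linalg.yield %s : f32
//   } -> tensor<f32>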
538 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
539  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
540  unsigned outputPos =
541  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
542  // Only single combiner operations are supported for now.
543  SmallVector<Operation *, 4> combinerOps;
544  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
545  combinerOps.size() != 1)
546  return nullptr;
547 
548  // Return the combiner operation.
549  return combinerOps[0];
550 }
551 
552 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
553 /// otherwise.
554 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
555  auto dstVecType = dyn_cast<VectorType>(dstType);
556  // If no shape to broadcast to, just return `value`.
557  if (dstVecType.getRank() == 0)
558  return value;
559  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
560       vector::BroadcastableToResult::Success)
561     return value;
562  Location loc = b.getInsertionPoint()->getLoc();
563  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
564 }
565 
566 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
567 /// assumes that `reductionOp` has two operands and one of them is the reduction
568 /// initial value.
569 // Note: this is a true builder that notifies the OpBuilder listener.
570 // TODO: Consider moving as a static helper on the ReduceOp.
571 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
572                                       Value valueToReduce, Value acc,
573                                       ArrayRef<bool> dimsToMask) {
574  auto maybeKind = getCombinerOpKind(reduceOp);
575  assert(maybeKind && "Failed precondition: could not get reduction kind");
576  return b.create<vector::MultiDimReductionOp>(
577  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
578 }
579 
580 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
581  return llvm::to_vector(
582  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
583 }
584 
585 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
586 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
587 /// currently being vectorized. If `dest` has null rank, build an memref.store.
588 /// currently being vectorized. If `dest` has null rank, build a memref.store.
589 // Note: this is a true builder that notifies the OpBuilder listener.
590 // TODO: Consider moving as a static helper on the ReduceOp.
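// A minimal sketch (illustrative only) of the op built for a rank-2 output
// with an identity indexing map:
//
//   %w = vector.transfer_write %value, %out[%c0, %c0]
//          : vector<4x8xf32>, tensor<4x8xf32>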
591 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
592  OpOperand *outputOperand,
593  VectorizationState &state) {
594  Location loc = value.getLoc();
595  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
596  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
597 
598  // Compute the vector type of the value to store. This type should be an
599  // identity or projection of the canonical vector type without any permutation
600  // applied, given that any permutation in a transfer write happens as part of
601  // the write itself.
602   auto vectorTypeMap = AffineMap::getFilteredIdentityMap(
603       opOperandMap.getContext(), opOperandMap.getNumInputs(),
604  [&](AffineDimExpr dimExpr) -> bool {
605  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
606  });
607  auto vectorType = state.getCanonicalVecType(
608  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
609 
610  Operation *write;
611  if (vectorType.getRank() > 0) {
612  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
613  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
614  rewriter.create<arith::ConstantIndexOp>(loc, 0));
615  value = broadcastIfNeeded(rewriter, value, vectorType);
616  assert(value.getType() == vectorType && "Incorrect type");
617  write = rewriter.create<vector::TransferWriteOp>(
618  loc, value, outputOperand->get(), indices, writeMap);
619  } else {
620  // 0-d case is still special: do not invert the reindexing writeMap.
621  if (!isa<VectorType>(value.getType()))
622  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
623  assert(value.getType() == vectorType && "Incorrect type");
624  write = rewriter.create<vector::TransferWriteOp>(
625  loc, value, outputOperand->get(), ValueRange{});
626  }
627 
628  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
629 
630  // If masked, set in-bounds to true. Masking guarantees that the access will
631  // be in-bounds.
632  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
633  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
634  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
635  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
636  }
637 
638  LDBG("vectorized op: " << *write << "\n");
639  if (!write->getResults().empty())
640  return write->getResult(0);
641  return Value();
642 }
643 
644 // Custom vectorization precondition function type. This is intended to be used
645 // with CustomVectorizationHook. Returns success if the corresponding custom
646 // hook can vectorize the op.
647 using CustomVectorizationPrecondition =
648     std::function<LogicalResult(Operation *, bool)>;
649 
650 // Custom vectorization function type. Produce a vector form of Operation*
651 // assuming all its vectorized operands are already in the IRMapping.
652 // Return nullptr if the Operation cannot be vectorized.
653 using CustomVectorizationHook =
654     std::function<VectorizationResult(Operation *, const IRMapping &)>;
655 
656 /// Helper function to vectorize the terminator of a `linalgOp`. New result
657 /// vector values are appended to `newResults`. Return
658 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
659 /// should not try to map produced operations and instead return the results
660 /// using the `newResults` vector making them available to the vectorization
661 /// algorithm for RAUW. This function is meant to be used as a
662 /// CustomVectorizationHook.
663 static VectorizationResult
664 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
665                      const IRMapping &bvm, VectorizationState &state,
666                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
667  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
668  if (!yieldOp)
669     return VectorizationResult{VectorizationStatus::Failure, nullptr};
670   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
671  // TODO: Scan for an opportunity for reuse.
672  // TODO: use a map.
673  Value vectorValue = bvm.lookup(output.value());
674  Value newResult =
675  buildVectorWrite(rewriter, vectorValue,
676  linalgOp.getDpsInitOperand(output.index()), state);
677  if (newResult)
678  newResults.push_back(newResult);
679  }
680 
681   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
682 }
683 
684 /// Helper function to vectorize the index operations of a `linalgOp`. Return
685 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
686 /// should map the produced operations. This function is meant to be used as a
687 /// CustomVectorizationHook.
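/// For example (illustrative only), for a static 4x8 iteration space,
/// `linalg.index 1` (the trailing dim) vectorizes to just:
///
/// ```
/// %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
/// ```
///
/// whereas `linalg.index 0` is additionally broadcast to the permuted shape
/// and transposed back so that the values vary along dim 0.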
688 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
689                                                 VectorizationState &state,
690                                                 Operation *op,
691                                                 LinalgOp linalgOp) {
692  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
693  if (!indexOp)
694     return VectorizationResult{VectorizationStatus::Failure, nullptr};
695   auto loc = indexOp.getLoc();
696  // Compute the static loop sizes of the index op.
697  auto targetShape = state.getCanonicalVecShape();
698  // Compute a one-dimensional index vector for the index op dimension.
699  auto constantSeq =
700  llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
701  auto indexSteps = rewriter.create<arith::ConstantOp>(
702  loc, rewriter.getIndexVectorAttr(constantSeq));
703  // Return the one-dimensional index vector if it lives in the trailing
704  // dimension of the iteration space since the vectorization algorithm in this
705  // case can handle the broadcast.
706  if (indexOp.getDim() == targetShape.size() - 1)
707     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
708   // Otherwise permute the targetShape to move the index dimension last,
709  // broadcast the one-dimensional index vector to the permuted shape, and
710  // finally transpose the broadcasted index vector to undo the permutation.
711  auto permPattern =
712  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
713  std::swap(permPattern[indexOp.getDim()], permPattern.back());
714  auto permMap =
715  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
716 
717  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
718  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
719  indexSteps);
720  SmallVector<int64_t> transposition =
721  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
722  std::swap(transposition.back(), transposition[indexOp.getDim()]);
723  auto transposeOp =
724  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
725  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
726 }
727 
728 /// Helper function to check if the tensor.extract can be vectorized by the
729 /// custom hook vectorizeTensorExtract.
730 static LogicalResult
731 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
732  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
733  if (!extractOp)
734  return failure();
735 
736  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
737  return failure();
738 
739  // Check the index type, but only for non 0-d tensors (for which we do need
740  // access indices).
741  if (not extractOp.getIndices().empty()) {
742  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
743  return failure();
744  }
745 
746  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
747  return !VectorType::isValidElementType(type);
748  })) {
749  return failure();
750  }
751 
752  return success();
753 }
754 
755 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
756 /// generated from `tensor.extract`. The offset is calculated as follows
757 /// (example using scalar values):
758 ///
759 /// offset = extractOp.indices[0]
760 /// for (i = 1; i < numIndices; i++)
761 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
762 ///
763 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
764 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
765 static Value calculateGatherOffset(RewriterBase &rewriter,
766                                    VectorizationState &state,
767                                    tensor::ExtractOp extractOp,
768                                    const IRMapping &bvm) {
769  // The vector of indices for GatherOp should be shaped as the output vector.
770  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
771  auto loc = extractOp.getLoc();
772 
773  Value offset = broadcastIfNeeded(
774  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
775 
776  const size_t numIndices = extractOp.getIndices().size();
777  for (size_t i = 1; i < numIndices; i++) {
778  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
779 
780  auto dimSize = broadcastIfNeeded(
781  rewriter,
782  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
783  indexVecType);
784 
785  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
786 
787  auto extractOpIndex = broadcastIfNeeded(
788  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
789 
790  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
791  }
792 
793  return offset;
794 }
795 
796 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
797 
798 /// Checks whether \p val can be used for calculating a loop invariant index.
799 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
800 
801  auto targetShape = linalgOp.getStaticLoopRanges();
802  assert(((llvm::count_if(targetShape,
803  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
804  "n-D vectors are not yet supported");
805  assert(targetShape.back() != 1 &&
806          "1-D vectors with the trailing dim equal to 1 are not yet supported");
807 
808  // Blocks outside _this_ linalg.generic are effectively loop invariant.
809  // However, analysing block arguments for _this_ linalg.generic Op is a bit
810  // tricky. Just bail out in the latter case.
811  // TODO: We could try analysing the corresponding affine map here.
812  auto *block = linalgOp.getBlock();
813  if (isa<BlockArgument>(val))
814  return llvm::all_of(block->getArguments(),
815  [&val](Value v) { return (v != val); });
816 
817  Operation *defOp = val.getDefiningOp();
818  assert(defOp && "This is neither a block argument nor an operation result");
819 
820  // IndexOp is loop invariant as long as its result remains constant across
821  // iterations. Given the assumptions on the loop ranges above, only the
822  // trailing loop dim ever changes.
823  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
824  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
825  return (indexOp.getDim() != trailingLoopDim);
826 
827  auto *ancestor = block->findAncestorOpInBlock(*defOp);
828 
829   // Values defined outside `linalgOp` are loop invariant.
830  if (!ancestor)
831  return true;
832 
833  // Values defined inside `linalgOp`, which are constant, are loop invariant.
834  if (isa<arith::ConstantOp>(ancestor))
835  return true;
836 
837  bool result = true;
838  for (auto op : ancestor->getOperands())
839  result &= isLoopInvariantIdx(linalgOp, op);
840 
841  return result;
842 }
843 
844 /// Check whether \p val could be used for calculating the trailing index for a
845 /// contiguous load operation.
846 ///
847 /// There are currently 3 types of values that are allowed here:
848 /// 1. loop-invariant values,
849 /// 2. values that increment by 1 with every loop iteration,
850 /// 3. results of basic arithmetic operations (linear and continuous)
851 /// involving 1., 2. and 3.
852 /// This method returns True if indeed only such values are used in calculating
853 /// \p val.
854 ///
855 /// Additionally, the trailing index for a contiguous load operation should
856 /// increment by 1 with every loop iteration, i.e. be based on:
857 /// * `linalg.index <dim>` ,
858 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
859 /// updated to `true` when such an op is found.
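/// For example (illustrative only), where dim 1 is the trailing
/// iteration-space dim, an index computed as
///
/// ```
/// %c1  = arith.constant 1 : index
/// %i   = linalg.index 1 : index
/// %idx = arith.addi %i, %c1 : index
/// ```
///
/// increments by 1 with every iteration, so it is accepted and `foundIndexOp`
/// is set to `true`.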
860 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
861  bool &foundIndexOp) {
862 
863  auto targetShape = linalgOp.getStaticLoopRanges();
864  assert(((llvm::count_if(targetShape,
865  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
866  "n-D vectors are not yet supported");
867  assert(targetShape.back() != 1 &&
868  "1-D vectors with the trailing dim 1 are not yet supported");
869 
870  // Blocks outside _this_ linalg.generic are effectively loop invariant.
871  // However, analysing block arguments for _this_ linalg.generic Op is a bit
872  // tricky. Just bail out in the latter case.
873  // TODO: We could try analysing the corresponding affine map here.
874  auto *block = linalgOp.getBlock();
875  if (isa<BlockArgument>(val))
876  return llvm::all_of(block->getArguments(),
877  [&val](Value v) { return (v != val); });
878 
879  Operation *defOp = val.getDefiningOp();
880  assert(defOp && "This is neither a block argument nor an operation result");
881 
882  // Given the assumption on the loop ranges above, only the trailing loop
883  // index is not constant.
884  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
885  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
886  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
887  return true;
888  }
889 
890  auto *ancestor = block->findAncestorOpInBlock(*defOp);
891 
892  if (!ancestor)
893  return false;
894 
895  // Conservatively reject Ops that could lead to indices with stride other
896  // than 1.
897  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
898  return false;
899 
900  bool result = false;
901  for (auto op : ancestor->getOperands())
902  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
903 
904  return result;
905 }
906 
907 /// Infer the memory access pattern for the input ExtractOp
908 ///
909 /// Based on the operation shapes and indices (usually based on the iteration
910 /// space of the parent `linalgOp` operation), decides whether the input
911 /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
912 /// gather load.
913 ///
914 /// Note that it is always safe to use gather load operations for contiguous
915 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
916 /// that `extractOp` is a gather load.
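/// For example (illustrative only), when reading into a vector shaped 1x1x4,
///
/// ```
/// %v = tensor.extract %src[%c0, %c0, %i] : tensor<3x3x8xf32>
/// ```
///
/// is classified as a contiguous load when %i is based on `linalg.index 2`, as
/// a scalar broadcast when all indices are loop invariant, and as a gather
/// load otherwise.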
917 static VectorMemoryAccessKind
918 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
919                                     LinalgOp &linalgOp) {
920 
921  auto targetShape = linalgOp.getStaticLoopRanges();
922  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
923 
924  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
925  if (inputShape.getShape().empty())
926     return VectorMemoryAccessKind::ScalarBroadcast;
927 
928  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
929  // gather load.
930  // TODO: Relax this condition.
931  if (linalgOp.hasDynamicShape())
932     return VectorMemoryAccessKind::Gather;
933 
934  // 1. Assume that it's a gather load when reading _into_:
935   //  * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
936   //  * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
937  // TODO: Relax these conditions.
938  // FIXME: This condition assumes non-dynamic sizes.
939  if ((llvm::count_if(targetShape,
940  [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
941  targetShape.back() == 1)
942     return VectorMemoryAccessKind::Gather;
943 
944  // 2. Assume that it's a gather load when reading _from_ a tensor for which
945  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
946  // TODO: Relax this condition.
947  if (inputShape.getShape().back() == 1)
948     return VectorMemoryAccessKind::Gather;
949 
950  bool leadingIdxsLoopInvariant = true;
951 
952  // 3. Analyze the leading indices of `extractOp`.
953  // Look at the way each index is calculated and decide whether it is suitable
954  // for a contiguous load, i.e. whether it's loop invariant.
955  auto indices = extractOp.getIndices();
956  auto leadIndices = indices.drop_back(1);
957 
958  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
959  if (inputShape.getShape()[i] == 1)
960  continue;
961 
962  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
963  }
964 
965  if (!leadingIdxsLoopInvariant) {
966  LDBG("Found gather load: " << extractOp);
967     return VectorMemoryAccessKind::Gather;
968   }
969 
970  // 4. Analyze the trailing index for `extractOp`.
971  // At this point we know that the leading indices are loop invariant. This
972   // means that it is potentially a scalar or a contiguous load. We can decide
973  // based on the trailing idx.
974  auto extractOpTrailingIdx = indices.back();
975 
976  // 4a. Scalar broadcast load
977  // If the trailing index is loop invariant then this is a scalar load.
978  if (leadingIdxsLoopInvariant &&
979  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
980  LDBG("Found scalar broadcast load: " << extractOp);
981 
982     return VectorMemoryAccessKind::ScalarBroadcast;
983   }
984 
985  // 4b. Contiguous loads
986  // The trailing `extractOp` index should increment with every loop iteration.
987  // This effectively means that it must be based on the trailing loop index.
988  // This is what the following bool captures.
989  bool foundIndexOp = false;
990  bool isContiguousLoad =
991  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
992  isContiguousLoad &= foundIndexOp;
993 
994  if (isContiguousLoad) {
995     LDBG("Found contiguous load: " << extractOp);
996     return VectorMemoryAccessKind::Contiguous;
997   }
998 
999  // 5. Fallback case - gather load.
1000  LDBG("Found gather load: " << extractOp);
1001   return VectorMemoryAccessKind::Gather;
1002 }
1003 
1004 /// Helper function to vectorize the tensor.extract operations. Returns
1005 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1006 /// should map the produced operations. This function is meant to be used as a
1007 /// CustomVectorizationHook.
1008 static VectorizationResult
1009 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1010                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1011  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1012  if (!extractOp)
1013     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1014   auto loc = extractOp.getLoc();
1015 
1016  // Compute the static loop sizes of the extract op.
1017  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1018  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1019  loc,
1020  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1021  /*value=*/true));
1022  auto passThruConstantOp =
1023  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1024 
1025  // Base indices are currently set to 0. We will need to re-visit if more
1026  // generic scenarios are to be supported.
1027  SmallVector<Value> baseIndices(
1028  extractOp.getIndices().size(),
1029  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1030 
1031  VectorMemoryAccessKind memAccessKind =
1032  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1033 
1034  // 1. Handle gather access
1035  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1036  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1037 
1038  // Generate the gather load
1039  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1040  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1041  maskConstantOp, passThruConstantOp);
1042  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1043 
1044  LDBG("Vectorised as gather load: " << extractOp << "\n");
1045     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1046   }
1047 
1048  // 2. Handle:
1049  // a. scalar loads + broadcast,
1050  // b. contiguous loads.
1051  // Both cases use vector.transfer_read.
1052 
1053  // Collect indices for `vector.transfer_read`. At this point, the indices will
1054  // either be scalars or would have been broadcast to vectors matching the
1055  // result type. For indices that are vectors, there are two options:
1056  // * for non-trailing indices, all elements are identical (contiguous
1057  // loads are identified by looking for non-trailing indices that are
1058  // invariant with respect to the corresponding linalg.generic), or
1059  // * for trailing indices, the index vector will contain values with stride
1060  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1061  // needed.
1062  // This means that
1063  // * for scalar indices - just re-use it,
1064  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1065  // (0th) element and use that.
1066  SmallVector<Value> transferReadIdxs;
1067  auto resTrailingDim = resultType.getShape().back();
1068  auto zero = rewriter.create<arith::ConstantOp>(
1069  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1070  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1071  auto idx = bvm.lookup(extractOp.getIndices()[i]);
1072  if (idx.getType().isIndex()) {
1073  transferReadIdxs.push_back(idx);
1074  continue;
1075  }
1076 
1077  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1078  loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
1079  bvm.lookup(extractOp.getIndices()[i]));
1080  transferReadIdxs.push_back(
1081  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1082  }
1083 
1084   // `tensor.extract` is always in-bounds, hence the following holds.
1085  auto dstRank = resultType.getRank();
1086  auto srcRank = extractOp.getTensor().getType().getRank();
1087  SmallVector<bool> inBounds(dstRank, true);
1088 
1089  // 2a. Handle scalar broadcast access.
1090  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1091  MLIRContext *ctx = rewriter.getContext();
1092  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1093  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1094 
1095  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1096  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1097  permutationMap, inBounds);
1098 
1099  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1100  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1101  }
1102 
1103  // 2b. Handle contiguous access.
1104  auto permutationMap = AffineMap::getMinorIdentityMap(
1105  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1106 
1107  int32_t rankDiff = dstRank - srcRank;
1108  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1109  // dims so that the ranks match. This is done by extending the map with 0s.
1110  // For example, for dstRank = 3, srcRank = 2, the following map created
1111  // above:
1112  // (d0, d1) --> (d0, d1)
1113  // is extended as:
1114  // (d0, d1) --> (0, d0, d1)
1115  while (rankDiff > 0) {
1116  permutationMap = permutationMap.insertResult(
1117  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1118  rankDiff--;
1119  }
1120 
1121  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1122  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1123  inBounds);
1124 
1125  LDBG("Vectorised as contiguous load: " << extractOp);
1126  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1127 }
1128 
1129 /// Emit reduction operations if the shape of the value to reduce is different
1130 /// from the result shape.
1131 // Note: this is a true builder that notifies the OpBuilder listener.
1132 // TODO: Consider moving as a static helper on the ReduceOp.
1133 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1134  Value reduceValue, Value initialValue,
1135  const IRMapping &bvm) {
1136  Value reduceVec = bvm.lookup(reduceValue);
1137  Value outputVec = bvm.lookup(initialValue);
1138  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1139  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1140   // Reduce only if needed as the value may already have been reduced for
1141  // contraction vectorization.
1142  if (!reduceType ||
1143  (outputType && reduceType.getShape() == outputType.getShape()))
1144  return nullptr;
1145  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1146  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1147 }
1148 
1149 /// Generic vectorization for a single operation `op`, given already vectorized
1150 /// operands carried by `bvm`. Vectorization occurs as follows:
1151 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1152 /// result on success.
1153 /// 2. Clone any constant in the current scope without vectorization: each
1154 /// consumer of the constant will later determine the shape to which the
1155 ///    constant needs to be broadcast.
1156 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1157 /// of the `customVectorizationHooks` to cover such cases.
1158 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1159 /// operand of maximal rank. Other operands have smaller rank and are
1160 /// broadcast accordingly. It is assumed this broadcast is always legal,
1161 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1162 ///
1163 /// This function assumes all operands of `op` have been vectorized and are in
1164 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1165 /// a topologically-sorted list of ops.
1166 /// This function does not update `bvm` but returns a VectorizationStatus that
1167 /// instructs the caller what `bvm` update needs to occur.
1168 static VectorizationResult
1169 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1170                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1171                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1172  LDBG("vectorize op " << *op << "\n");
1173 
1174  // 1. Try to apply any CustomVectorizationHook.
1175  if (!customVectorizationHooks.empty()) {
1176  for (auto &customFunc : customVectorizationHooks) {
1177  VectorizationResult result = customFunc(op, bvm);
1178  if (result.status == VectorizationStatus::Failure)
1179  continue;
1180  return result;
1181  }
1182  }
1183 
1184  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1185   // Clone so that the constant is not confined to the linalgOp block.
1186  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1187  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1188 
1189   // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
1190   if (!OpTrait::hasElementwiseMappableTraits(op))
1191     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1192 
1193   // 4. Check if the operation is a reduction.
1194  SmallVector<std::pair<Value, Value>> reductionOperands;
1195  for (Value operand : op->getOperands()) {
1196  auto blockArg = dyn_cast<BlockArgument>(operand);
1197  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1198  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1199  continue;
1200  SmallVector<Operation *> reductionOps;
1201  Value reduceValue = matchReduction(
1202  linalgOp.getRegionOutputArgs(),
1203  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1204  if (!reduceValue)
1205  continue;
1206  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1207  }
1208  if (!reductionOperands.empty()) {
1209  assert(reductionOperands.size() == 1);
1210  Operation *reduceOp =
1211  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1212  reductionOperands[0].second, bvm);
1213  if (reduceOp)
1214       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1215   }
1216 
1217  // 5. Generic vectorization path for ElementwiseMappable ops.
1218  // a. Get the first max ranked shape.
1219  VectorType firstMaxRankedType;
1220  for (Value operand : op->getOperands()) {
1221  auto vecOperand = bvm.lookup(operand);
1222  assert(vecOperand && "Vector operand couldn't be found");
1223 
1224  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1225  if (vecType && (!firstMaxRankedType ||
1226  firstMaxRankedType.getRank() < vecType.getRank()))
1227  firstMaxRankedType = vecType;
1228  }
1229  // b. Broadcast each op if needed.
1230  SmallVector<Value> vecOperands;
1231  for (Value scalarOperand : op->getOperands()) {
1232  Value vecOperand = bvm.lookup(scalarOperand);
1233  assert(vecOperand && "Vector operand couldn't be found");
1234 
1235  if (firstMaxRankedType) {
1236  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1237  getElementTypeOrSelf(vecOperand.getType()),
1238  firstMaxRankedType.getScalableDims());
1239  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1240  } else {
1241  vecOperands.push_back(vecOperand);
1242  }
1243  }
1244  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1245  SmallVector<Type> resultTypes;
1246  for (Type resultType : op->getResultTypes()) {
1247  resultTypes.push_back(
1248  firstMaxRankedType
1249  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1250  firstMaxRankedType.getScalableDims())
1251  : resultType);
1252  }
1253  // d. Build and return the new op.
1254  return VectorizationResult{
1255       VectorizationStatus::NewOp,
1256       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1257                       resultTypes, op->getAttrs())};
1258 }
1259 
1260 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1261 /// vector form. Generic vectorization proceeds as follows:
1262 /// 1. Verify the `linalgOp` has one non-empty region.
1263 /// 2. Values defined above the region are mapped to themselves and will be
1264 /// broadcasted on a per-need basis by their consumers.
1265 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1266 /// load).
1267 /// TODO: Reuse opportunities for RAR dependencies.
1268 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1269 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1270 /// iteration indices.
1271 /// 5. Iteratively call vectorizeOneOp on the region operations.
1272 ///
1273 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1274 /// performed to the maximal common vector size implied by the `linalgOp`
1275 /// iteration space. This eager broadcasting is introduced in the
1276 /// permutation_map of the vector.transfer_read operations. The eager
1277 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1278 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1279 /// the absence of good canonicalizations, the amount of work increases.
1280 /// This is not deemed a problem as we expect canonicalizations and foldings to
1281 /// aggressively clean up the useless work.
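/// As an end-to-end sketch (illustrative only), an elementwise linalg.generic
/// adding two tensor<8xf32> operands is rewritten roughly as:
///
/// ```
/// %lhs = vector.transfer_read %a[%c0], %pad : tensor<8xf32>, vector<8xf32>
/// %rhs = vector.transfer_read %b[%c0], %pad : tensor<8xf32>, vector<8xf32>
/// %sum = arith.addf %lhs, %rhs : vector<8xf32>
/// %res = vector.transfer_write %sum, %init[%c0]
///          : vector<8xf32>, tensor<8xf32>
/// ```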
1282 static LogicalResult
1283 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1284                          LinalgOp linalgOp,
1285                          SmallVectorImpl<Value> &newResults) {
1286  LDBG("Vectorizing operation as linalg generic\n");
1287  Block *block = linalgOp.getBlock();
1288 
1289  // 2. Values defined above the region can only be broadcast for now. Make them
1290  // map to themselves.
1291  IRMapping bvm;
1292  SetVector<Value> valuesSet;
1293  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1294  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1295 
1296  if (linalgOp.getNumDpsInits() == 0)
1297  return failure();
1298 
1299  // 3. Turn all BBArgs into vector.transfer_read / load.
1300  Location loc = linalgOp.getLoc();
1301  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1302  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1303  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1304  if (linalgOp.isScalar(opOperand)) {
1305  bvm.map(bbarg, opOperand->get());
1306  continue;
1307  }
1308 
1309  // 3.a. Convert the indexing map for this input/output to a transfer read
1310  // permutation map and masking map.
1311  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1312 
1313  // Remove zeros from indexing map to use it as masking map.
1314  SmallVector<int64_t> zeroPos;
1315  auto results = indexingMap.getResults();
1316  for (const auto &result : llvm::enumerate(results)) {
1317  if (isa<AffineConstantExpr>(result.value())) {
1318  zeroPos.push_back(result.index());
1319  }
1320  }
1321  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1322 
1323  AffineMap readMap;
1324  VectorType readType;
1325  Type elemType = getElementTypeOrSelf(opOperand->get());
1326  if (linalgOp.isDpsInput(opOperand)) {
1327  // 3.a.i. For input reads we use the canonical vector shape.
1328  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1329  readType = state.getCanonicalVecType(elemType);
1330  } else {
1331  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1332  // reductions), the vector shape is computed by mapping the canonical
1333  // vector shape to the output domain and back to the canonical domain.
1334  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1335  readType =
1336  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1337  }
1338 
1339  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1340 
1341  Operation *read = rewriter.create<vector::TransferReadOp>(
1342  loc, readType, opOperand->get(), indices, readMap);
1343  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1344  Value readValue = read->getResult(0);
1345 
1346  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1347  // will be in-bounds.
1348  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1349  SmallVector<bool> inBounds(readType.getRank(), true);
1350  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1351  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1352  }
1353 
1354  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1355  // TODO: remove this.
1356  if (readType.getRank() == 0)
1357  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1358 
1359  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1360  << "\n");
1361  bvm.map(bbarg, readValue);
1362  bvm.map(opOperand->get(), readValue);
1363  }
1364 
1365   SmallVector<CustomVectorizationHook> hooks;
1366   // 4a. Register CustomVectorizationHook for yieldOp.
1367  CustomVectorizationHook vectorizeYield =
1368  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1369  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1370  };
1371  hooks.push_back(vectorizeYield);
1372 
1373  // 4b. Register CustomVectorizationHook for indexOp.
1374  CustomVectorizationHook vectorizeIndex =
1375  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1376  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1377  };
1378  hooks.push_back(vectorizeIndex);
1379 
1380  // 4c. Register CustomVectorizationHook for extractOp.
1381  CustomVectorizationHook vectorizeExtract =
1382  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1383  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1384  };
1385  hooks.push_back(vectorizeExtract);
1386 
1387  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1388  for (Operation &op : block->getOperations()) {
1389  VectorizationResult result =
1390  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1391  if (result.status == VectorizationStatus::Failure) {
1392  LDBG("failed to vectorize: " << op << "\n");
1393  return failure();
1394  }
1395  if (result.status == VectorizationStatus::NewOp) {
1396  Operation *maybeMaskedOp =
1397  state.maskOperation(rewriter, result.newOp, linalgOp);
1398  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1399  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1400  }
1401  }
1402 
1403  return success();
1404 }
1405 
1406 /// Given a tensor::PackOp, return the `dest` shape before any packing
1407 /// permutations.
1408 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1409  ArrayRef<int64_t> destShape) {
1410  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1411 }
1412 
1413 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1414 /// create an empty destination tensor and create a TransferWriteOp from the
1415 /// input to the empty tensor. If the destination shape is not the same as the
1416 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1417 /// mask for the write.
1418 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1419  Value input,
1420  SmallVector<OpFoldResult> destSizes,
1421  ArrayRef<int64_t> inputVectorSizes) {
1422  auto inputType = cast<VectorType>(input.getType());
1423  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1424  inputType.getElementType());
1425  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1426  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1427  Operation *write = builder.create<vector::TransferWriteOp>(
1428  loc,
1429  /*vector=*/input,
1430  /*source=*/dest,
1431  /*indices=*/SmallVector<Value>(rank, zero),
1432  /*inBounds=*/SmallVector<bool>(rank, true));
1433  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1434  assert(llvm::none_of(
1435  destShape.drop_front(inputVectorSizes.size()),
1436  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1437  "Only dims aligned with inputVectorSizes may be dynamic");
1438  bool needMaskForWrite = !llvm::equal(
1439  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1440  if (needMaskForWrite) {
1441  SmallVector<int64_t> writeMaskShape;
1442  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1443  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1444  destShape.end());
1445  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1446  Value maskForWrite =
1447  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1448  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1449  }
1450  return write;
1451 }
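// Worked example (illustrative only; the shapes below are assumptions, not
// taken from this file): for `input` of type vector<8x16xf32>,
// inputVectorSizes = [8, 16] and destSizes reifying to tensor<?x16xf32>,
// destShape = [dynamic, 16] differs from the leading inputVectorSizes, so
// the transfer_write is wrapped in a vector.mask whose mask is a
// vector.create_mask built from the destination sizes. With a static
// tensor<8x16xf32> destination the shapes match and the write stays
// unmasked with in_bounds = [true, true].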
1452 
1453 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1454 /// padding value and (3) input vector sizes into:
1455 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1456 /// As in the following example:
1457 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1458 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1459 ///
1460 /// This pack would be vectorized to:
1461 ///
1462 /// %load = vector.mask %mask {
1463 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1464 /// {in_bounds = [true, true, true]} :
1465 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1466 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1467 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1468 /// to vector<32x4x2x1x16xf32>
1469 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1470 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1471 /// %write = vector.transfer_write %transpose,
1472 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1473 /// {in_bounds = [true, true, true, true, true]}
1474 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1475 ///
1476 /// If the (3) input vector sizes are not provided, the vector sizes are
1477 /// determined by the result tensor shape. Also, we update the inBounds
1478 /// attribute instead of masking.
1479 static LogicalResult
1480 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1481  ArrayRef<int64_t> inputVectorSizes,
1482  SmallVectorImpl<Value> &newResults) {
1483  OpBuilder::InsertionGuard g(rewriter);
1484  rewriter.setInsertionPoint(packOp);
1485 
1486  Location loc = packOp.getLoc();
1487  auto padValue = packOp.getPaddingValue();
1488  if (!padValue) {
1489  padValue = rewriter.create<arith::ConstantOp>(
1490  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1491  }
1492  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1493  LogicalResult status =
1494  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1495  .reifyResultShapes(rewriter, reifiedReturnShapes);
1496  (void)status; // prevent unused variable warning on non-assert builds.
1497  assert(succeeded(status) && "failed to reify result shapes");
1498 
1499  // If the input vector sizes are not provided, the vector sizes are
1500  // determined by the result tensor shape and, instead of masking, the
1501  // inBounds attribute is updated.
1502  bool useInBoundsInsteadOfMasking = true;
1503  if (inputVectorSizes.empty()) {
1504  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1505  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1506  useInBoundsInsteadOfMasking = false;
1507  }
1508 
1509  // Create masked TransferReadOp.
1510  SmallVector<int64_t> inputShape(inputVectorSizes);
1511  auto innerTiles = packOp.getStaticInnerTiles();
1512  auto innerDimsPos = packOp.getInnerDimsPos();
1513  auto outerDimsPerm = packOp.getOuterDimsPerm();
1514  if (!outerDimsPerm.empty())
1515  applyPermutationToVector(inputShape,
1516  invertPermutationVector(outerDimsPerm));
1517  for (auto [idx, size] : enumerate(innerTiles))
1518  inputShape[innerDimsPos[idx]] *= size;
1519  auto maskedRead = vector::createReadOrMaskedRead(
1520  rewriter, loc, packOp.getSource(), inputShape, padValue,
1521  useInBoundsInsteadOfMasking);
1522 
1523  // Create ShapeCastOp.
1524  SmallVector<int64_t> destShape(inputVectorSizes);
1525  destShape.append(innerTiles.begin(), innerTiles.end());
1526  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1527  packOp.getDestType().getElementType());
1528  auto shapeCastOp =
1529  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1530 
1531  // Create TransposeOp.
1532  auto destPermutation =
1533  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1534  auto transposeOp = rewriter.create<vector::TransposeOp>(
1535  loc, shapeCastOp.getResult(), destPermutation);
1536 
1537  // Create TransferWriteOp.
1538  Operation *write =
1539  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
1540  reifiedReturnShapes[0], inputVectorSizes);
1541  newResults.push_back(write->getResult(0));
1542  return success();
1543 }
1544 
1545 /// Vectorize a `tensor::UnPackOp` to these 4 Ops:
1546 ///   vector::TransferReadOp - Reads a vector from the source tensor
1547 ///   vector::TransposeOp - Transposes the read vector
1548 ///   vector::ShapeCastOp - Reshapes the data based on the destination
1549 ///   vector::TransferWriteOp - Writes the result vector back to the
1550 ///   destination tensor
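///
/// For example (illustrative sketch, not generated verbatim; assuming input
/// vector sizes [256, 128] that match the static destination shape, so no
/// masking is required), the unpack
///   tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [32, 16]
///     into %dst : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
/// would be vectorized roughly to:
///   %read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
///     : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///     : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
///   %sc = vector.shape_cast %tr
///     : vector<8x32x8x16xf32> to vector<256x128xf32>
///   vector.transfer_write %sc, %empty[%c0, %c0] {in_bounds = [true, true]}
///     : vector<256x128xf32>, tensor<256x128xf32>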
1551 static LogicalResult
1552 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1553  ArrayRef<int64_t> inputVectorSizes,
1554  SmallVectorImpl<Value> &newResults) {
1555 
1556  OpBuilder::InsertionGuard g(rewriter);
1557  rewriter.setInsertionPoint(unpackOp);
1558 
1559  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1560 
1561  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1562  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1563 
1564  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
1565  inputVectorSizes.end());
1566  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1567  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1568 
1569  // readMaskShape is the shape of the vector used for the read and, if
1570  // needed, the mask. It is computed as follows: let the vector sizes (VS)
1571  // array have size 'N', the sourceShape (SS) have size 'M' with M >= N,
1572  // and the inner tile sizes (IT) have size M-N.
1573  // Thus:
1574  // - initially: readMaskShape = inputVectorSizes
1575  // - divide each readMaskShape entry indexed by innerDimPos by the
1576  //   corresponding inner tile size
1577  // - if outer_dims_perm is present: apply that permutation to readMaskShape
1578  // - append the remaining dims of SS
1579  // E.g., let's say unpackTensorType.getShape() = <8x8x32x16>,
1580  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are [512,
1581  // 128] and outer_dims_perm is [1, 0]; then the read shape is:
1582  // ReadMaskShape (initial): [512, 128]
1583  // After the inner-tile adjustment: [512/32, 128/16]
1584  // = [16, 8]
1585  // After applying outer_dims_perm: [8, 16]
1586  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1587 
1588  for (auto [index, size] : enumerate(innerTiles)) {
1589  readMaskShape[innerDimPos[index]] =
1590  llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
1591  }
1592  if (!outerDimsPerm.empty()) {
1593  applyPermutationToVector(readMaskShape, outerDimsPerm);
1594  }
1595  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
1596  sourceShape.end());
1597 
1598  ReifiedRankedShapedTypeDims reifiedRetShapes;
1599  LogicalResult status =
1600  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1601  .reifyResultShapes(rewriter, reifiedRetShapes);
1602  if (status.failed()) {
1603  LDBG("Unable to reify result shapes of " << unpackOp);
1604  return failure();
1605  }
1606  Location loc = unpackOp->getLoc();
1607 
1608  auto padValue = rewriter.create<arith::ConstantOp>(
1609  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1610 
1611  // Read result, mask if necessary. If transferReadOp shape is not equal
1612  // to shape of source, then a mask is necessary.
1613  Value readResult = vector::createReadOrMaskedRead(
1614  rewriter, loc, unpackOp.getSource(),
1615  ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
1616 
1617  PackingMetadata packMetadata;
1618  SmallVector<int64_t> lastDimToInsertPosPerm =
1619  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1620  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1621  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1622  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1623  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1624  RankedTensorType stripMineTensorType =
1625  RankedTensorType::get(stripMineShape, stripMineElemType);
1626  // Transpose the appropriate rows to match output.
1627  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1628  loc, readResult, lastDimToInsertPosPerm);
1629 
1630  // Collapse the vector to the size required by result.
1631  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1632  stripMineTensorType, packMetadata.reassociations);
1633  mlir::VectorType vecCollapsedType =
1634  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1635  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1636  loc, vecCollapsedType, transposeOp->getResult(0));
1637 
1638  // writeMaskShape must match the shape_cast result shape for dynamic sizes;
1639  // otherwise the verifier complains that the mask size is invalid.
1640  SmallVector<int64_t> writeMaskShape(
1641  unpackOp.getDestType().hasStaticShape()
1642  ? inputVectorSizes
1643  : shapeCastOp.getResultVectorType().getShape());
1644  Operation *write =
1645  createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
1646  reifiedRetShapes[0], writeMaskShape);
1647  newResults.push_back(write->getResult(0));
1648  return success();
1649 }
1650 
1651 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1652 /// and (3) all-zero lowPad to
1653 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
1654 static LogicalResult
1655 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1656  ArrayRef<int64_t> inputVectorSizes,
1657  SmallVectorImpl<Value> &newResults) {
1658  auto padValue = padOp.getConstantPaddingValue();
1659  Location loc = padOp.getLoc();
1660 
1661  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1662  OpBuilder::InsertionGuard g(rewriter);
1663  rewriter.setInsertionPoint(padOp);
1664 
1665  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1666  LogicalResult status =
1667  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1668  .reifyResultShapes(rewriter, reifiedReturnShapes);
1669  (void)status; // prevent unused variable warning on non-assert builds
1670  assert(succeeded(status) && "failed to reify result shapes");
1671  auto maskedRead = vector::createReadOrMaskedRead(
1672  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue);
1673  Operation *write = createWriteOrMaskedWrite(
1674  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
1675  newResults.push_back(write->getResult(0));
1676  return success();
1677 }
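// Worked example (illustrative only; shapes are assumptions): for a
// tensor.pad of a tensor<?x5xf32> source to tensor<10x5xf32> with all-zero
// low padding, a constant pad value, and inputVectorSizes = [10, 5], this
// produces a masked transfer_read of vector<10x5xf32> (masked against the
// dynamic source size) followed by an in-bounds transfer_write into a
// tensor.empty of the reified static result shape.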
1678 
1679 // TODO: probably need some extra checks for reduction followed by consumer
1680 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1681 static LogicalResult reductionPreconditions(LinalgOp op) {
1682  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1683  LDBG("reduction precondition failed: no reduction iterator\n");
1684  return failure();
1685  }
1686  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1687  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1688  if (indexingMap.isPermutation())
1689  continue;
1690 
1691  Operation *reduceOp = matchLinalgReduction(&opOperand);
1692  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1693  LDBG("reduction precondition failed: reduction detection failed\n");
1694  return failure();
1695  }
1696  }
1697  return success();
1698 }
1699 
1700 static LogicalResult
1701 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1702  bool flatten1DDepthwiseConv) {
1703  if (flatten1DDepthwiseConv) {
1704  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1705  "supported\n");
1706  return failure();
1707  }
1708 
1709  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1710  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1711  return failure();
1712  }
1713 
1714  // Support dynamic shapes in 1D depthwise convolution, but only in the
1715  // _channel_ dimension.
1716  Value lhs = conv.getDpsInputOperand(0)->get();
1717  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1718  auto shapeWithoutCh = lhsShape.drop_back(1);
1719  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1720  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1721  "channel dim can be dynamic\n");
1722  return failure();
1723  }
1724 
1725  return success();
1726 }
1727 
1728 static LogicalResult
1729 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1730  bool flatten1DDepthwiseConv) {
1731  if (isa<ConvolutionOpInterface>(op.getOperation()))
1732  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1733 
1734  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1735  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1736  if (!isElementwise(op) &&
1737  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1738  op.getOperation()))
1739  return failure();
1740 
1741  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1742  return success();
1743 }
1744 
1745 /// Need to check if the inner-tiles are static/constant.
1746 static LogicalResult
1747 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1748  ArrayRef<int64_t> inputVectorSizes) {
1749 
1750  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1751  return !getConstantIntValue(res).has_value();
1752  })) {
1753  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1754  return failure();
1755  }
1756  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1757  if (!inputVectorSizes.empty() &&
1758  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1759  return failure();
1760 
1761  return success();
1762 }
1763 
1764 static LogicalResult vectorizeLinalgOpPrecondition(
1765  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1766  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1767  // Tensors with a zero-sized dimension cannot be vectorized.
1768  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1769  return failure();
1770  // Check API contract for input vector sizes.
1771  if (!inputVectorSizes.empty() &&
1772  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1773  inputVectorSizes)))
1774  return failure();
1775 
1776  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1777  linalgOp, flatten1DDepthwiseConv))) {
1778  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1779  return failure();
1780  }
1781 
1782  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1783 
1784  // Register CustomVectorizationPrecondition for extractOp.
1785  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1786 
1787  // All types in the body should be a supported element type for VectorType.
1788  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1789  // Check if any custom hook can vectorize the inner op.
1790  if (llvm::any_of(
1791  customPreconditions,
1792  [&](const CustomVectorizationPrecondition &customPrecondition) {
1793  return succeeded(
1794  customPrecondition(&innerOp, vectorizeNDExtract));
1795  })) {
1796  continue;
1797  }
1798  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1799  return !VectorType::isValidElementType(type);
1800  })) {
1801  return failure();
1802  }
1803  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1804  return !VectorType::isValidElementType(type);
1805  })) {
1806  return failure();
1807  }
1808  }
1809  if (isElementwise(linalgOp))
1810  return success();
1811 
1812  // TODO: isaConvolutionOpInterface that can also infer from generic
1813  // features. But we will still need stride/dilation attributes that will be
1814  // annoying to reverse-engineer...
1815  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1816  return success();
1817  // TODO: the common vector shape is equal to the static loop sizes only when
1818  // all indexing maps are projected permutations. For convs and stencils the
1819  // logic will need to evolve.
1820  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1821  LDBG("precondition failed: not projected permutations\n");
1822  return failure();
1823  }
1824  if (failed(reductionPreconditions(linalgOp))) {
1825  LDBG("precondition failed: reduction preconditions\n");
1826  return failure();
1827  }
1828  return success();
1829 }
1830 
1831 static LogicalResult
1832 vectorizePackOpPrecondition(tensor::PackOp packOp,
1833  ArrayRef<int64_t> inputVectorSizes) {
1834  auto padValue = packOp.getPaddingValue();
1835  Attribute cstAttr;
1836  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1837  LDBG("pad value is not constant: " << packOp << "\n");
1838  return failure();
1839  }
1840  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1841  bool satisfyEmptyCond = true;
1842  if (inputVectorSizes.empty()) {
1843  if (!packOp.getDestType().hasStaticShape() ||
1844  !packOp.getSourceType().hasStaticShape())
1845  satisfyEmptyCond = false;
1846  }
1847 
1848  if (!satisfyEmptyCond &&
1849  failed(vector::isValidMaskedInputVector(
1850  resultTensorShape.take_front(packOp.getSourceRank()),
1851  inputVectorSizes)))
1852  return failure();
1853 
1854  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1855  return !getConstantIntValue(v).has_value();
1856  })) {
1857  LDBG("inner_tiles must be constant: " << packOp << "\n");
1858  return failure();
1859  }
1860 
1861  return success();
1862 }
1863 
1864 static LogicalResult
1865 vectorizePadOpPrecondition(tensor::PadOp padOp,
1866  ArrayRef<int64_t> inputVectorSizes) {
1867  auto padValue = padOp.getConstantPaddingValue();
1868  if (!padValue) {
1869  LDBG("pad value is not constant: " << padOp << "\n");
1870  return failure();
1871  }
1872 
1873  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1874  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1875  inputVectorSizes)))
1876  return failure();
1877 
1878  if (llvm::any_of(padOp.getLow(), [](Value v) {
1879  std::optional<int64_t> res = getConstantIntValue(v);
1880  return !res.has_value() || res.value() != 0;
1881  })) {
1882  LDBG("low pad must all be zero: " << padOp << "\n");
1883  return failure();
1884  }
1885 
1886  return success();
1887 }
1888 
1889 /// Preconditions for scalable vectors.
1890 static LogicalResult
1891 vectorizeScalableVectorPrecondition(Operation *op,
1892  ArrayRef<int64_t> inputVectorSizes,
1893  ArrayRef<bool> inputScalableVecDims) {
1894  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1895  "Number of input vector sizes and scalable dims doesn't match");
1896 
1897  if (inputVectorSizes.empty())
1898  return success();
1899 
1900  bool isScalable = inputScalableVecDims.back();
1901  if (!isScalable)
1902  return success();
1903 
1904  // Only element-wise and 1d depthwise conv ops supported in the presence of
1905  // scalable dims.
1906  auto linalgOp = dyn_cast<LinalgOp>(op);
1907  return success(linalgOp && (isElementwise(linalgOp) ||
1908  isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
1909 }
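// For example (illustrative): inputVectorSizes = [8, 16] together with
// inputScalableVecDims = [false, true] requests a shape such as
// vector<8x[16]xf32>. Since the trailing dim is scalable, the check above
// only accepts element-wise linalg ops and linalg.depthwise_conv_1d_nwc_wc.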
1910 
1911 LogicalResult mlir::linalg::vectorizeOpPrecondition(
1912  Operation *op, ArrayRef<int64_t> inputVectorSizes,
1913  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
1914  bool flatten1DDepthwiseConv) {
1915  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
1916  inputScalableVecDims)))
1917  return failure();
1918 
1919  return TypeSwitch<Operation *, LogicalResult>(op)
1920  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1921  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1922  vectorizeNDExtract,
1923  flatten1DDepthwiseConv);
1924  })
1925  .Case<tensor::PadOp>([&](auto padOp) {
1926  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
1927  })
1928  .Case<tensor::PackOp>([&](auto packOp) {
1929  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
1930  })
1931  .Case<tensor::UnPackOp>([&](auto unpackOp) {
1932  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
1933  })
1934  .Default([](auto) { return failure(); });
1935 }
1936 
1937 /// Converts affine.apply Ops to arithmetic operations.
1938 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
1939  OpBuilder::InsertionGuard g(rewriter);
1940  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
1941 
1942  for (auto op : make_early_inc_range(toReplace)) {
1943  rewriter.setInsertionPoint(op);
1944  auto expanded = affine::expandAffineExpr(
1945  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
1946  op.getOperands().take_front(op.getAffineMap().getNumDims()),
1947  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
1948  rewriter.replaceOp(op, expanded);
1949  }
1950 }
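// For example (illustrative): an op such as
//   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%s]
// is expanded by the loop above into plain arithmetic, roughly
//   %0 = arith.addi %i, %s : index
// so that the subsequent vectorization only needs to handle arith ops.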
1951 
1952 /// Emit a suitable vector form for an operation. If provided,
1953 /// `inputVectorSizes` are used to vectorize this operation.
1954 /// `inputVectorSizes` must match the rank of the iteration space of the
1955 /// operation and the input vector sizes must be greater than or equal to
1956 /// their counterpart iteration space sizes, if static. `inputVectorShapes`
1957 /// also allows the vectorization of operations with dynamic shapes.
1958 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
1959  ArrayRef<int64_t> inputVectorSizes,
1960  ArrayRef<bool> inputScalableVecDims,
1961  bool vectorizeNDExtract,
1962  bool flatten1DDepthwiseConv) {
1963  LDBG("Attempting to vectorize:\n" << *op << "\n");
1964  LDBG("Input vector sizes: ");
1965  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
1966  LLVM_DEBUG(llvm::dbgs() << "\n");
1967  LDBG("Input scalable vector dims: ");
1968  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
1969  LLVM_DEBUG(llvm::dbgs() << "\n");
1970 
1971  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
1972  vectorizeNDExtract,
1973  flatten1DDepthwiseConv))) {
1974  LDBG("Vectorization pre-conditions failed\n");
1975  return failure();
1976  }
1977 
1978  // Initialize vectorization state.
1979  VectorizationState state(rewriter);
1980  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
1981  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
1982  inputScalableVecDims))) {
1983  LDBG("Vectorization state couldn't be initialized\n");
1984  return failure();
1985  }
1986  }
1987 
1988  SmallVector<Value> results;
1989  auto vectorizeResult =
1990  TypeSwitch<Operation *, LogicalResult>(op)
1991  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1992  // TODO: isaConvolutionOpInterface that can also infer from
1993  // generic features. Will require stride/dilation attributes
1994  // inference.
1995  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
1996  FailureOr<Operation *> convOr = vectorizeConvolution(
1997  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
1998  flatten1DDepthwiseConv);
1999  if (succeeded(convOr)) {
2000  llvm::append_range(results, (*convOr)->getResults());
2001  return success();
2002  }
2003 
2004  LDBG("Unsupported convolution can't be vectorized.\n");
2005  return failure();
2006  }
2007 
2008  LDBG("Vectorize generic by broadcasting to the canonical vector "
2009  "shape\n");
2010 
2011  // Pre-process before proceeding.
2012  convertAffineApply(rewriter, linalgOp);
2013 
2014  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2015  // to 'OpBuilder' when it is passed over to some methods like
2016  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2017  // erase an op within these methods, the actual rewriter won't be
2018  // notified and we will end up with read-after-free issues!
2019  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2020  })
2021  .Case<tensor::PadOp>([&](auto padOp) {
2022  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2023  results);
2024  })
2025  .Case<tensor::PackOp>([&](auto packOp) {
2026  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2027  results);
2028  })
2029  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2030  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2031  inputVectorSizes, results);
2032  })
2033  .Default([](auto) { return failure(); });
2034 
2035  if (failed(vectorizeResult)) {
2036  LDBG("Vectorization failed\n");
2037  return failure();
2038  }
2039 
2040  if (!results.empty())
2041  rewriter.replaceOp(op, results);
2042  else
2043  rewriter.eraseOp(op);
2044 
2045  return success();
2046 }
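// Illustrative caller sketch (not part of this file; `vectorizeAll`, the use
// of func::FuncOp and the surrounding pass plumbing are assumptions). The
// entry point above is typically driven from a pass or from the transform
// dialect, roughly as follows:
//
//   static void vectorizeAll(RewriterBase &rewriter, func::FuncOp funcOp) {
//     // Collect candidates first; vectorization replaces/erases ops, so the
//     // IR must not be mutated while walking it.
//     SmallVector<Operation *> candidates;
//     funcOp.walk(
//         [&](linalg::LinalgOp op) { candidates.push_back(op.getOperation()); });
//     for (Operation *target : candidates) {
//       rewriter.setInsertionPoint(target);
//       // Empty vector sizes: derive the canonical vector shape from the
//       // static loop ranges; no scalable dims requested.
//       (void)linalg::vectorize(rewriter, target, /*inputVectorSizes=*/{},
//                               /*inputScalableVecDims=*/{});
//     }
//   }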
2047 
2048 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2049  memref::CopyOp copyOp) {
2050  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2051  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2052  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2053  return failure();
2054 
2055  auto srcElementType = getElementTypeOrSelf(srcType);
2056  auto dstElementType = getElementTypeOrSelf(dstType);
2057  if (!VectorType::isValidElementType(srcElementType) ||
2058  !VectorType::isValidElementType(dstElementType))
2059  return failure();
2060 
2061  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2062  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2063 
2064  Location loc = copyOp->getLoc();
2065  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2066  SmallVector<Value> indices(srcType.getRank(), zero);
2067 
2068  Value readValue = rewriter.create<vector::TransferReadOp>(
2069  loc, readType, copyOp.getSource(), indices,
2070  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2071  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2072  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2073  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2074  }
2075  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2076  loc, readValue, copyOp.getTarget(), indices,
2077  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2078  rewriter.replaceOp(copyOp, writeValue->getResults());
2079  return success();
2080 }
2081 
2082 //----------------------------------------------------------------------------//
2083 // Misc. vectorization patterns.
2084 //----------------------------------------------------------------------------//
2085 
2086 /// Helper function that retrieves the value of an IntegerAttr.
2087 static int64_t getIntFromAttr(Attribute attr) {
2088  return cast<IntegerAttr>(attr).getInt();
2089 }
2090 
2091 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2092 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2093 /// not supported.
2094 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2095  ArrayRef<OpFoldResult> ofrs) {
2096  SmallVector<Value> result;
2097  for (auto o : ofrs) {
2098  if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2099  result.push_back(val);
2100  } else {
2101  result.push_back(rewriter.create<arith::ConstantIndexOp>(
2102  loc, getIntFromAttr(o.template get<Attribute>())));
2103  }
2104  }
2105  return result;
2106 }
2107 
2108 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2109 /// InsertSliceOp. For now, only constant padding values are supported.
2110 /// If there is enough static type information, TransferReadOps and
2111 /// TransferWriteOps may be generated instead of InsertSliceOps.
2112 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2113  GenericPadOpVectorizationPattern(MLIRContext *context,
2114  PatternBenefit benefit = 1)
2115  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2116  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2117  /// each dimension size is statically know in the source type or the result
2118  /// type (or both).
2119  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2120  tensor::PadOp padOp, Value dest) {
2121  auto sourceType = padOp.getSourceType();
2122  auto resultType = padOp.getResultType();
2123  if (!VectorType::isValidElementType(sourceType.getElementType()))
2124  return failure();
2125 
2126  // Copy cannot be vectorized if pad value is non-constant and source shape
2127  // is dynamic. In case of a dynamic source shape, padding must be appended
2128  // by TransferReadOp, but TransferReadOp supports only constant padding.
2129  auto padValue = padOp.getConstantPaddingValue();
2130  if (!padValue) {
2131  if (!sourceType.hasStaticShape())
2132  return failure();
2133  // Create dummy padding value.
2134  auto elemType = sourceType.getElementType();
2135  padValue = rewriter.create<arith::ConstantOp>(
2136  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2137  }
2138 
2139  SmallVector<int64_t> vecShape;
2140  SmallVector<bool> readInBounds;
2141  SmallVector<bool> writeInBounds;
2142  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2143  if (!sourceType.isDynamicDim(i)) {
2144  vecShape.push_back(sourceType.getDimSize(i));
2145  // Source shape is statically known: neither read nor write is
2146  // out-of-bounds.
2147  readInBounds.push_back(true);
2148  writeInBounds.push_back(true);
2149  } else if (!resultType.isDynamicDim(i)) {
2150  // Source shape is not statically known, but result shape is.
2151  // Vectorize with size of result shape. This may be larger than the
2152  // source size.
2153  vecShape.push_back(resultType.getDimSize(i));
2154  // Read may be out-of-bounds because the result size could be larger
2155  // than the source size.
2156  readInBounds.push_back(false);
2157  // Write is out-of-bounds if low padding > 0.
2158  writeInBounds.push_back(
2159  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2160  static_cast<int64_t>(0));
2161  } else {
2162  // Neither source nor result dim of padOp is static. Cannot vectorize
2163  // the copy.
2164  return failure();
2165  }
2166  }
2167  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2168 
2169  // Generate TransferReadOp.
2170  SmallVector<Value> readIndices(
2171  vecType.getRank(),
2172  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2173  auto read = rewriter.create<vector::TransferReadOp>(
2174  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2175  ArrayRef<bool>{readInBounds});
2176 
2177  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2178  // entire tensor, write directly to the FillOp's operand.
2179  if (llvm::equal(vecShape, resultType.getShape()) &&
2180  llvm::all_of(writeInBounds, [](bool b) { return b; }))
2181  if (auto fill = dest.getDefiningOp<FillOp>())
2182  dest = fill.output();
2183 
2184  // Generate TransferWriteOp.
2185  auto writeIndices =
2186  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2187  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2188  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2189 
2190  return success();
2191  }
2192 };
2193 
2194 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2195 /// given operation type OpTy.
2196 template <typename OpTy>
2197 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2198  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2199 
2200  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2201  PatternRewriter &rewriter) const final {
2202  bool changed = false;
2203  // Insert users in vector, because some users may be replaced/removed.
2204  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2205  if (auto op = dyn_cast<OpTy>(user))
2206  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2207  return success(changed);
2208  }
2209 
2210 protected:
2211  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2212  tensor::PadOp padOp, OpTy op) const = 0;
2213 };
2214 
2215 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2216 /// ```
2217 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2218 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2219 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2220 /// ```
2221 /// is rewritten to:
2222 /// ```
2223 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2224 /// {in_bounds = [true, true]}
2225 /// : tensor<?x?xf32>, vector<17x5xf32>
2226 /// ```
2227 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2228 /// sure that the original padding value %cst was never used.
2229 ///
2230 /// This rewrite is possible if:
2231 /// - `xferOp` has no out-of-bounds dims or mask.
2232 /// - Low padding is static 0.
2233 /// - Single, scalar padding value.
2234 struct PadOpVectorizationWithTransferReadPattern
2235  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2236  using VectorizePadOpUserPattern<
2237  vector::TransferReadOp>::VectorizePadOpUserPattern;
2238 
2239  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2240  vector::TransferReadOp xferOp) const override {
2241  // Low padding must be static 0.
2242  if (!padOp.hasZeroLowPad())
2243  return failure();
2244  // Pad value must be a constant.
2245  auto padValue = padOp.getConstantPaddingValue();
2246  if (!padValue)
2247  return failure();
2248  // Padding value of existing `xferOp` is unused.
2249  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2250  return failure();
2251 
2252  rewriter.modifyOpInPlace(xferOp, [&]() {
2253  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2254  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2255  rewriter.getBoolArrayAttr(inBounds));
2256  xferOp.getSourceMutable().assign(padOp.getSource());
2257  xferOp.getPaddingMutable().assign(padValue);
2258  });
2259 
2260  return success();
2261  }
2262 };
2263 
2264 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2265 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2266 /// value, where the same amount of padding is immediately removed again after
2267 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2268 /// tensor value and apply out-of-bounds masking. E.g.:
2269 /// ```
2270 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2271 /// : tensor<...> to tensor<?x?xf32>
2272 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2273 /// %2 = vector.transfer_write %vec, %1[...]
2274 /// : vector<17x5xf32>, tensor<17x5xf32>
2275 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2276 /// : tensor<17x5xf32> to tensor<?x?xf32>
2277 /// ```
2278 /// is rewritten to:
2279 /// ```
2280 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2281 /// : tensor<...> to tensor<?x?xf32>
2282 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2283 /// tensor<?x?xf32>
2284 /// ```
2285 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2286 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2287 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2288 /// from %r's old dimensions.
2289 ///
2290 /// This rewrite is possible if:
2291 /// - Low padding is static 0.
2292 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2293 /// ExtractSliceOp trims the same amount of padding that was added
2294 /// beforehand.
2295 /// - Single, scalar padding value.
2296 struct PadOpVectorizationWithTransferWritePattern
2297  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2298  using VectorizePadOpUserPattern<
2299  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2300 
2301  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2302  vector::TransferWriteOp xferOp) const override {
2303  // TODO: support 0-d corner case.
2304  if (xferOp.getTransferRank() == 0)
2305  return failure();
2306 
2307  // Low padding must be static 0.
2308  if (!padOp.hasZeroLowPad())
2309  return failure();
2310  // Pad value must be a constant.
2311  auto padValue = padOp.getConstantPaddingValue();
2312  if (!padValue)
2313  return failure();
2314  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2315  if (!xferOp->hasOneUse())
2316  return failure();
2317  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2318  if (!trimPadding)
2319  return failure();
2320  // Only static zero offsets supported when trimming padding.
2321  if (!trimPadding.hasZeroOffset())
2322  return failure();
2323  // trimPadding must remove the amount of padding that was added earlier.
2324  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2325  return failure();
2326 
2327  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2328  rewriter.setInsertionPoint(xferOp);
2329 
2330  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2331  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2332  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2333  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2334  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2335  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2336 
2337  return success();
2338  }
2339 
2340  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2341  /// i.e., same dimensions.
2342  ///
2343  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2344  /// dimensions, this function tries to infer the (static) tensor size by
2345  /// looking at the defining op and utilizing op-specific knowledge.
2346  ///
2347  /// This is a conservative analysis. In case equal tensor sizes cannot be
2348  /// proven statically, this analysis returns `false` even though the tensor
2349  /// sizes may turn out to be equal at runtime.
2350  bool hasSameTensorSize(Value beforePadding,
2351  tensor::ExtractSliceOp afterTrimming) const {
2352  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2353  // result and CastOp operand.
2354  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2355  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2356  return true;
2357 
2358  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2359  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2360  // Only RankedTensorType supported.
2361  if (!t1 || !t2)
2362  return false;
2363  // Rank of both values must be the same.
2364  if (t1.getRank() != t2.getRank())
2365  return false;
2366 
2367  // All static dimensions must be the same. Mixed cases (e.g., dimension
2368  // static in `t1` but dynamic in `t2`) are not supported.
2369  for (unsigned i = 0; i < t1.getRank(); ++i) {
2370  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2371  return false;
2372  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2373  return false;
2374  }
2375 
2376  // Nothing more to check if all dimensions are static.
2377  if (t1.getNumDynamicDims() == 0)
2378  return true;
2379 
2380  // All dynamic sizes must be the same. The only supported case at the
2381  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2382  // thereof).
2383 
2384  // Apart from CastOp, only ExtractSliceOp is supported.
2385  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2386  if (!beforeSlice)
2387  return false;
2388 
2389  assert(static_cast<size_t>(t1.getRank()) ==
2390  beforeSlice.getMixedSizes().size());
2391  assert(static_cast<size_t>(t2.getRank()) ==
2392  afterTrimming.getMixedSizes().size());
2393 
2394  for (unsigned i = 0; i < t1.getRank(); ++i) {
2395  // Skip static dimensions.
2396  if (!t1.isDynamicDim(i))
2397  continue;
2398  auto size1 = beforeSlice.getMixedSizes()[i];
2399  auto size2 = afterTrimming.getMixedSizes()[i];
2400 
2401  // Case 1: Same value or same constant int.
2402  if (isEqualConstantIntOrValue(size1, size2))
2403  continue;
2404 
2405  // Other cases: Take a deeper look at defining ops of values.
2406  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2407  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2408  if (!v1 || !v2)
2409  return false;
2410 
2411  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2412  // CSE is run.)
2413  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2414  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2415  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2416  minOp1.getOperands() == minOp2.getOperands())
2417  continue;
2418 
2419  // Add additional cases as needed.
2420  }
2421 
2422  // All tests passed.
2423  return true;
2424  }
2425 };
2426 
2427 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2428 /// ```
2429 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2430 /// %r = tensor.insert_slice %0
2431 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2432 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2433 /// ```
2434 /// is rewritten to:
2435 /// ```
2436 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2437 /// : tensor<?x?xf32>, vector<17x5xf32>
2438 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2439 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2440 /// ```
2441 ///
2442 /// This rewrite is possible if:
2443 /// - Low padding is static 0.
2444 /// - `padOp` result shape is static.
2445 /// - The entire padded tensor is inserted.
2446 /// (Implies that sizes of `insertOp` are all static.)
2447 /// - Only unit strides in `insertOp`.
2448 /// - Single, scalar padding value.
2449 /// - `padOp` result not used as destination.
2450 struct PadOpVectorizationWithInsertSlicePattern
2451  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2452  using VectorizePadOpUserPattern<
2453  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2454 
2455  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2456  tensor::InsertSliceOp insertOp) const override {
2457  // Low padding must be static 0.
2458  if (!padOp.hasZeroLowPad())
2459  return failure();
2460  // Only unit stride supported.
2461  if (!insertOp.hasUnitStride())
2462  return failure();
2463  // Pad value must be a constant.
2464  auto padValue = padOp.getConstantPaddingValue();
2465  if (!padValue)
2466  return failure();
2467  // Dynamic shapes not supported.
2468  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2469  return failure();
2470  // Pad result not used as destination.
2471  if (insertOp.getDest() == padOp.getResult())
2472  return failure();
2473 
2474  auto vecType = VectorType::get(padOp.getType().getShape(),
2475  padOp.getType().getElementType());
2476  unsigned vecRank = vecType.getRank();
2477  unsigned tensorRank = insertOp.getType().getRank();
2478 
2479  // Check if sizes match: Insert the entire tensor into most minor dims.
2480  // (No permutations allowed.)
2481  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2482  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2483  if (!llvm::all_of(
2484  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2485  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2486  }))
2487  return failure();
2488 
2489  // Insert the TransferReadOp and TransferWriteOp at the position of the
2490  // InsertSliceOp.
2491  rewriter.setInsertionPoint(insertOp);
2492 
2493  // Generate TransferReadOp: Read entire source tensor and add high
2494  // padding.
2495  SmallVector<Value> readIndices(
2496  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2497  auto read = rewriter.create<vector::TransferReadOp>(
2498  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2499 
2500  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2501  // specified offsets. Write is fully in-bounds because a InsertSliceOp's
2502  // source must fit into the destination at the specified offsets.
2503  auto writeIndices =
2504  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2505  SmallVector<bool> inBounds(vecRank, true);
2506  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2507  insertOp, read, insertOp.getDest(), writeIndices,
2508  ArrayRef<bool>{inBounds});
2509 
2510  return success();
2511  }
2512 };
2513 
2514 void mlir::linalg::populatePadOpVectorizationPatterns(
2515  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2516  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2517  baseBenefit);
2518  // Try these specialized patterns first before resorting to the generic one.
2519  patterns.add<PadOpVectorizationWithTransferReadPattern,
2520  PadOpVectorizationWithTransferWritePattern,
2521  PadOpVectorizationWithInsertSlicePattern>(
2522  patterns.getContext(), baseBenefit.getBenefit() + 1);
2523 }
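// Illustrative usage sketch (not part of this file; the surrounding pass
// boilerplate is assumed): the patterns registered above are intended to be
// run through the greedy pattern rewrite driver, roughly:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();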
2524 
2525 //----------------------------------------------------------------------------//
2526 // Forwarding patterns
2527 //----------------------------------------------------------------------------//
2528 
2529 /// Check whether there is any interleaved use of any `values` between
2530 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2531 /// is in a different block.
2532 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2533  ValueRange values) {
2534  if (firstOp->getBlock() != secondOp->getBlock() ||
2535  !firstOp->isBeforeInBlock(secondOp)) {
2536  LDBG("interleavedUses precondition failed, firstOp: "
2537  << *firstOp << ", second op: " << *secondOp << "\n");
2538  return true;
2539  }
2540  for (auto v : values) {
2541  for (auto &u : v.getUses()) {
2542  Operation *owner = u.getOwner();
2543  if (owner == firstOp || owner == secondOp)
2544  continue;
2545  // TODO: this is too conservative, use dominance info in the future.
2546  if (owner->getBlock() == firstOp->getBlock() &&
2547  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2548  continue;
2549  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2550  << ", second op: " << *secondOp << "\n");
2551  return true;
2552  }
2553  }
2554  return false;
2555 }
2556 
2557 /// Return the unique subview use of `v` if it is indeed unique, null
2558 /// otherwise.
2559 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2560  memref::SubViewOp subViewOp;
2561  for (auto &u : v.getUses()) {
2562  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2563  if (subViewOp)
2564  return memref::SubViewOp();
2565  subViewOp = newSubViewOp;
2566  }
2567  }
2568  return subViewOp;
2569 }
2570 
2571 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2572 /// when available.
2573 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2574  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2575 
2576  // TODO: support mask.
2577  if (xferOp.getMask())
2578  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2579 
2580  // Transfer into `view`.
2581  Value viewOrAlloc = xferOp.getSource();
2582  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2583  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2584  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2585 
2586  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2587  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2588  if (!subViewOp)
2589  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2590  Value subView = subViewOp.getResult();
2591 
2592  // Find the copy into `subView` without interleaved uses.
2593  memref::CopyOp copyOp;
2594  for (auto &u : subView.getUses()) {
2595  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2596  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2597  if (newCopyOp.getTarget() != subView)
2598  continue;
2599  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2600  continue;
2601  copyOp = newCopyOp;
2602  break;
2603  }
2604  }
2605  if (!copyOp)
2606  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2607 
2608  // Find the fill into `viewOrAlloc` without interleaved uses before the
2609  // copy.
2610  FillOp maybeFillOp;
2611  for (auto &u : viewOrAlloc.getUses()) {
2612  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2613  assert(isa<MemRefType>(newFillOp.output().getType()));
2614  if (newFillOp.output() != viewOrAlloc)
2615  continue;
2616  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2617  continue;
2618  maybeFillOp = newFillOp;
2619  break;
2620  }
2621  }
2622  // Ensure padding matches.
2623  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2624  return rewriter.notifyMatchFailure(xferOp,
2625  "padding value does not match fill");
2626 
2627  // `in` is the subview that memref.copy reads. Replace it.
2628  Value in = copyOp.getSource();
2629 
2630  // memref.copy + linalg.fill can be used to create a padded local buffer.
2631  // The `masked` attribute is only valid on this padded buffer.
2632  // When forwarding to vector.transfer_read, the attribute must be reset
2633  // conservatively.
2634  Value res = rewriter.create<vector::TransferReadOp>(
2635  xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2636  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2637  // in_bounds is explicitly reset
2638  /*inBoundsAttr=*/ArrayAttr());
2639 
2640  if (maybeFillOp)
2641  rewriter.eraseOp(maybeFillOp);
2642  rewriter.eraseOp(copyOp);
2643  rewriter.replaceOp(xferOp, res);
2644 
2645  return success();
2646 }
2647 
2648 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2649 /// when available.
2650 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2651  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2652  // TODO: support mask.
2653  if (xferOp.getMask())
2654  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2655 
2656  // Transfer into `viewOrAlloc`.
2657  Value viewOrAlloc = xferOp.getSource();
2658  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2659  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2660  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2661 
2662  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2663  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2664  if (!subViewOp)
2665  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2666  Value subView = subViewOp.getResult();
2667 
2668  // Find the copy from `subView` without interleaved uses.
2669  memref::CopyOp copyOp;
2670  for (auto &u : subViewOp.getResult().getUses()) {
2671  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2672  if (newCopyOp.getSource() != subView)
2673  continue;
2674  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2675  continue;
2676  copyOp = newCopyOp;
2677  break;
2678  }
2679  }
2680  if (!copyOp)
2681  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2682 
2683  // `out` is the subview copied into that we replace.
2684  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2685  Value out = copyOp.getTarget();
2686 
2687  // Forward vector.transfer into copy.
2688  // memref.copy + linalg.fill can be used to create a padded local buffer.
2689  // The `masked` attribute is only valid on this padded buffer.
2690  // When forwarding to vector.transfer_write, the attribute must be reset
2691  // conservatively.
2692  rewriter.create<vector::TransferWriteOp>(
2693  xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2694  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2695  // in_bounds is explicitly reset
2696  /*inBoundsAttr=*/ArrayAttr());
2697 
2698  rewriter.eraseOp(copyOp);
2699  rewriter.eraseOp(xferOp);
2700 
2701  return success();
2702 }
2703 
2704 //===----------------------------------------------------------------------===//
2705 // Convolution vectorization patterns
2706 //===----------------------------------------------------------------------===//
2707 
2708 template <int N>
2709 static void bindShapeDims(ShapedType shapedType) {}
2710 
2711 template <int N, typename IntTy, typename... IntTy2>
2712 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2713  val = shapedType.getShape()[N];
2714  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2715 }
2716 
2717 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2718 template <typename... IntTy>
2719 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2720  bindShapeDims<0>(shapedType, vals...);
2721 }
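// For example (illustrative): given a ShapedType with shape [4, 8, 16],
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c);
// leaves n = 4, w = 8 and c = 16.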
2722 
2723 namespace {
2724 bool isCastOfBlockArgument(Operation *op) {
2725  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2726  isa<BlockArgument>(op->getOperand(0));
2727 }
2728 
2729 bool isSupportedPoolKind(vector::CombiningKind kind) {
2730  switch (kind) {
2731  case vector::CombiningKind::ADD:
2732  case vector::CombiningKind::MAXNUMF:
2733  case vector::CombiningKind::MAXIMUMF:
2734  case vector::CombiningKind::MAXSI:
2735  case vector::CombiningKind::MAXUI:
2736  case vector::CombiningKind::MINNUMF:
2737  case vector::CombiningKind::MINIMUMF:
2738  case vector::CombiningKind::MINSI:
2739  case vector::CombiningKind::MINUI:
2740  return true;
2741  default:
2742  return false;
2743  }
2744 }
2745 
2746 /// Generate a vector implementation for either:
2747 /// ```
2748 /// Op def: ( w, kw )
2749 /// Iters: ({Par(), Red()})
2750 /// Layout: {{w + kw}, {kw}, {w}}
2751 /// ```
2752 /// kw is unrolled.
2753 ///
2754 /// or
2755 ///
2756 /// ```
2757 /// Op def: ( n, w, c, kw, f )
2758 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2759 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2760 /// ```
2761 /// kw is unrolled, w is unrolled iff dilationW > 1.
2762 ///
2763 /// or
2764 ///
2765 /// ```
2766 /// Op def: ( n, c, w, f, kw )
2767 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2768 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2769 /// ```
2770 /// kw is unrolled, w is unrolled iff dilationW > 1.
2771 ///
2772 /// or
2773 ///
2774 /// ```
2775 /// Op def: ( n, w, c, kw )
2776 /// Iters: ({Par(), Par(), Par(), Red()})
2777 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2778 /// ```
2779 /// kw is unrolled, w is unrolled iff dilationW > 1.
2780 struct Conv1DGenerator
2781  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2782  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2783  int dilationW)
2784  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2785  strideW(strideW), dilationW(dilationW) {
 2786  // Determine whether `linalgOp` can be handled by this generator.
2787  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2788  return;
2789  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2790  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2791  resShaped = linalgOp.getDpsInitOperand(0)->get();
2792  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2793  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2794  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2795  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2796  return;
2797  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2798  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2799  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2800  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2801  return;
2802 
2803  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2804  if (!reduceOp)
2805  return;
2806  redOp = reduceOp->getName().getIdentifier();
2807 
2808  if (!setOperKind(reduceOp))
2809  return;
2810  auto maybeKind = getCombinerOpKind(reduceOp);
2811  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2812  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2813  return;
2814  }
2815 
2816  auto rhsRank = rhsShapedType.getRank();
2817  switch (oper) {
2818  case Conv:
2819  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2820  return;
2821  break;
2822  case Pool:
2823  if (rhsRank != 1)
2824  return;
2825  break;
2826  }
2827  // The op is now known to be valid.
2828  valid = true;
2829  }
2830 
2831  /// Generate a vector implementation for:
2832  /// ```
2833  /// Op def: ( w, kw )
2834  /// Iters: ({Par(), Red()})
2835  /// Layout: {{w + kw}, {kw}, {w}}
2836  /// ```
2837  /// kw is always unrolled.
2838  ///
2839  /// or
2840  ///
2841  /// ```
2842  /// Op def: ( n, w, c, kw, f )
2843  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2844  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2845  /// ```
2846  /// kw is always unrolled.
 2847  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
2848  /// > 1.
2849  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2850  if (!valid)
2851  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2852 
2853  int64_t nSize, wSize, cSize, kwSize, fSize;
2854  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
2855  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
2856  switch (conv1DOpOrder) {
2857  case Conv1DOpOrder::W:
2858  // Initialize unused dimensions
2859  nSize = fSize = cSize = 0;
2860  // out{W}
2861  bindShapeDims(resShapedType, wSize);
2862  // kernel{kw}
2863  bindShapeDims(rhsShapedType, kwSize);
2864  lhsShape = {// iw = ow + kw - 1
2865  // (i.e. 16 convolved with 3 -> 14)
2866  (wSize + kwSize - 1)};
2867  rhsShape = {kwSize};
2868  resShape = {wSize};
2869  break;
2870  case Conv1DOpOrder::Nwc:
2871  // out{n, w, f}
2872  bindShapeDims(resShapedType, nSize, wSize, fSize);
2873  switch (oper) {
2874  case Conv:
2875  // kernel{kw, c, f}
2876  bindShapeDims(rhsShapedType, kwSize, cSize);
2877  break;
2878  case Pool:
2879  // kernel{kw}
2880  bindShapeDims(rhsShapedType, kwSize);
2881  cSize = fSize;
2882  break;
2883  }
2884  lhsShape = {nSize,
2885  // iw = ow * sw + kw * dw - 1
2886  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2887  // Perform the proper inclusive -> exclusive -> inclusive.
2888  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2889  1,
2890  cSize};
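      // For instance, with wSize = 14, kwSize = 3, strideW = 1 and
      // dilationW = 1 (the "16 convolved with 3 -> 14" example above), this
      // evaluates to ((14 - 1) * 1 + 1) + ((3 - 1) * 1 + 1) - 1 = 16.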
2891  switch (oper) {
2892  case Conv:
2893  rhsShape = {kwSize, cSize, fSize};
2894  break;
2895  case Pool:
2896  rhsShape = {kwSize};
2897  break;
2898  }
2899  resShape = {nSize, wSize, fSize};
2900  break;
2901  case Conv1DOpOrder::Ncw:
2902  // out{n, f, w}
2903  bindShapeDims(resShapedType, nSize, fSize, wSize);
2904  switch (oper) {
2905  case Conv:
2906  // kernel{f, c, kw}
2907  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
2908  break;
2909  case Pool:
2910  // kernel{kw}
2911  bindShapeDims(rhsShapedType, kwSize);
2912  cSize = fSize;
2913  break;
2914  }
2915  lhsShape = {nSize, cSize,
2916  // iw = ow * sw + kw * dw - 1
2917  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2918  // Perform the proper inclusive -> exclusive -> inclusive.
2919  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2920  1};
2921  switch (oper) {
2922  case Conv:
2923  rhsShape = {fSize, cSize, kwSize};
2924  break;
2925  case Pool:
2926  rhsShape = {kwSize};
2927  break;
2928  }
2929  resShape = {nSize, fSize, wSize};
2930  break;
2931  }
2932 
2933  vector::TransferWriteOp write;
2934  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2935 
2936  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2937  // When strideW == 1, we can batch the contiguous loads and avoid
2938  // unrolling
2939  int64_t wSizeStep = strideW == 1 ? wSize : 1;
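    // E.g. with wSize = 8: strideW == 1 gives wSizeStep == 8 (one wide slice
    // per kw), whereas strideW > 1 gives wSizeStep == 1 (8 unit-width slices
    // per kw, i.e. full unrolling along w).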
2940 
2941  Type lhsEltType = lhsShapedType.getElementType();
2942  Type rhsEltType = rhsShapedType.getElementType();
2943  Type resEltType = resShapedType.getElementType();
2944  auto lhsType = VectorType::get(lhsShape, lhsEltType);
2945  auto rhsType = VectorType::get(rhsShape, rhsEltType);
2946  auto resType = VectorType::get(resShape, resEltType);
2947  // Zero padding with the corresponding dimensions for lhs, rhs and res.
2948  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
2949  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
2950  SmallVector<Value> resPadding(resShape.size(), zero);
2951 
2952  // Read the whole lhs, rhs and res in one shot (with zero padding).
2953  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2954  lhsPadding);
2955  // This is needed only for Conv.
2956  Value rhs = nullptr;
2957  if (oper == Conv)
2958  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2959  rhsPadding);
2960  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
2961  resPadding);
2962 
2963  // The base vectorization case for channeled convolution is input:
2964  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
 2965  // vectorization case, we pre-transpose the input, weight and output.
2966  switch (conv1DOpOrder) {
2967  case Conv1DOpOrder::W:
2968  case Conv1DOpOrder::Nwc:
2969  // Base case, so no transposes necessary.
2970  break;
2971  case Conv1DOpOrder::Ncw: {
2972  // To match base vectorization case, we pre-transpose current case.
2973  // ncw -> nwc
2974  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
2975  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
2976  // fcw -> wcf
2977  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
2978 
2979  // This is needed only for Conv.
2980  if (oper == Conv)
2981  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
2982  // nfw -> nwf
2983  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
2984  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
2985  break;
2986  }
2987  }
2988 
2989  //===------------------------------------------------------------------===//
2990  // Begin vector-only rewrite part
2991  //===------------------------------------------------------------------===//
2992  // Unroll along kw and read slices of lhs and rhs.
2993  SmallVector<Value> lhsVals, rhsVals, resVals;
2994  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
2995  kwSize, strideW, dilationW, wSizeStep,
2996  isSingleChanneled);
2997  // Do not do for pooling.
2998  if (oper == Conv)
2999  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3000  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3001  wSizeStep, isSingleChanneled);
3002 
3003  auto linearIndex = [&](int64_t kw, int64_t w) {
3004  return kw * (wSize / wSizeStep) + w;
3005  };
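    // With strideW == 1 (wSizeStep == wSize) this reduces to
    // linearIndex(kw, 0) == kw; with strideW > 1 (wSizeStep == 1) it is
    // kw * wSize + w, matching the order in which extractConvInputSlices
    // pushed the lhs slices.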
3006 
3007  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
 3008  // or an outer product for non-channeled convolution, or a simple arith
 3009  // operation for pooling.
3010  for (int64_t kw = 0; kw < kwSize; ++kw) {
3011  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3012  switch (oper) {
3013  case Conv:
3014  if (isSingleChanneled) {
3015  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3016  lhsVals[linearIndex(kw, w)],
3017  rhsVals[kw], resVals[w]);
3018  } else {
3019  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3020  lhsVals[linearIndex(kw, w)],
3021  rhsVals[kw], resVals[w]);
3022  }
3023  break;
3024  case Pool:
3025  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3026  resVals[w]);
3027  break;
3028  }
3029  }
3030  }
3031 
3032  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3033  isSingleChanneled);
3034  //===------------------------------------------------------------------===//
3035  // End vector-only rewrite part
3036  //===------------------------------------------------------------------===//
3037 
3038  // The base vectorization case for channeled convolution is output:
 3039  // {n,w,f}. To reuse the result from the base pattern vectorization case,
 3040  // we post-transpose the base case result.
3041  switch (conv1DOpOrder) {
3042  case Conv1DOpOrder::W:
3043  case Conv1DOpOrder::Nwc:
3044  // Base case, so no transposes necessary.
3045  break;
3046  case Conv1DOpOrder::Ncw: {
3047  // nwf -> nfw
3048  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3049  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3050  break;
3051  }
3052  }
3053 
3054  return rewriter
3055  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3056  .getOperation();
3057  }
3058 
3059  // Take a value and widen to have the same element type as `ty`.
3060  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3061  const Type srcElementType = getElementTypeOrSelf(val.getType());
3062  const Type dstElementType = getElementTypeOrSelf(ty);
3063  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3064  if (srcElementType == dstElementType)
3065  return val;
3066 
3067  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3068  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3069  const Type dstType =
3070  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3071 
3072  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3073  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3074  }
3075 
3076  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3077  srcWidth < dstWidth)
3078  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3079 
3080  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3081  srcWidth < dstWidth)
3082  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3083 
3084  assert(false && "unhandled promotion case");
3085  return nullptr;
3086  }
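  // E.g. promoting a vector<4x8xi8> value to match an i32 result type yields
  // an arith.extsi to vector<4x8xi32>, while an i8 -> f32 promotion goes
  // through arith.sitofp (shapes here are purely illustrative).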
3087 
3088  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3089  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3090  Value lhs, Value rhs, Value res) {
3091  vector::IteratorType par = vector::IteratorType::parallel;
3092  vector::IteratorType red = vector::IteratorType::reduction;
3093  AffineExpr n, w, f, c;
3094  bindDims(ctx, n, w, f, c);
3095  lhs = promote(rewriter, loc, lhs, res.getType());
3096  rhs = promote(rewriter, loc, rhs, res.getType());
3097  return rewriter.create<vector::ContractionOp>(
3098  loc, lhs, rhs, res,
3099  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3100  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3101  }
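  // E.g. with illustrative shapes lhs: vector<1x4x3xf32>,
  // rhs: vector<3x8xf32> and res: vector<1x4x8xf32>, the contraction reduces
  // over c == 3 and accumulates into res, with iterator types
  // [parallel, parallel, parallel, reduction].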
3102 
3103  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3104  // convolution.
3105  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3106  Value lhs, Value rhs, Value res) {
3107  return rewriter.create<vector::OuterProductOp>(
3108  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3109  }
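  // With a 1-D lhs{w} and a scalar rhs (the kw-th filter element), the
  // vector.outerproduct above is effectively an axpy:
  // res[w] += lhs[w] * rhs.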
3110 
3111  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3112  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3113  Value res) {
3114  if (isPoolExt)
3115  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3116  return rewriter
3117  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3118  ->getResult(0);
3119  }
3120 
3121  /// Generate a vector implementation for:
3122  /// ```
3123  /// Op def: ( n, w, c, kw)
3124  /// Iters: ({Par(), Par(), Par(), Red()})
3125  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3126  /// ```
3127  /// kw is always unrolled.
 3128  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3129  /// > 1.
3130  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3131  bool channelDimScalableFlag,
3132  bool flatten) {
3133  if (!valid)
3134  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3135 
3136  bool scalableChDim = false;
3137  bool useMasking = false;
3138  int64_t nSize, wSize, cSize, kwSize;
3139  // kernel{kw, c}
3140  bindShapeDims(rhsShapedType, kwSize, cSize);
3141  if (ShapedType::isDynamic(cSize)) {
3142  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3143  cSize = channelDimVecSize;
3144  // Scalable vectors are only used when both conditions are met:
3145  // 1. channel dim is dynamic
3146  // 2. channelDimScalableFlag is set
3147  scalableChDim = channelDimScalableFlag;
3148  useMasking = true;
3149  }
3150 
3151  assert(!(useMasking && flatten) &&
3152  "Unsupported flattened conv with dynamic shapes");
3153 
3154  // out{n, w, c}
3155  bindShapeDims(resShapedType, nSize, wSize);
3156 
3157  vector::TransferWriteOp write;
3158  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3159 
3160  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3161  // When strideW == 1, we can batch the contiguous loads and avoid
3162  // unrolling
3163  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3164 
3165  Type lhsEltType = lhsShapedType.getElementType();
3166  Type rhsEltType = rhsShapedType.getElementType();
3167  Type resEltType = resShapedType.getElementType();
3168  VectorType lhsType = VectorType::get(
3169  {nSize,
3170  // iw = ow * sw + kw * dw - 1
3171  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3172  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3173  cSize},
3174  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3175  VectorType rhsType =
3176  VectorType::get({kwSize, cSize}, rhsEltType,
3177  /*scalableDims=*/{false, scalableChDim});
3178  VectorType resType =
3179  VectorType::get({nSize, wSize, cSize}, resEltType,
3180  /*scalableDims=*/{false, false, scalableChDim});
3181 
3182  // Masks the input xfer Op along the channel dim, iff the corresponding
3183  // scalable flag is set.
3184  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3185  ArrayRef<bool> scalableDims,
3186  Operation *opToMask) {
3187  if (!useMasking)
3188  return opToMask;
3189  auto maskType =
3190  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3191 
 3192  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
 3193  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3194 
3195  Value maskOp =
3196  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3197 
3198  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3199  };
3200 
3201  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3202  // 0].
3203  Value lhs = rewriter.create<vector::TransferReadOp>(
3204  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3205  auto maybeMaskedLhs = maybeMaskXferOp(
3206  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3207 
3208  // Read rhs slice of size {kw, c} @ [0, 0].
3209  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3210  ValueRange{zero, zero});
3211  auto maybeMaskedRhs = maybeMaskXferOp(
3212  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3213 
3214  // Read res slice of size {n, w, c} @ [0, 0, 0].
3215  Value res = rewriter.create<vector::TransferReadOp>(
3216  loc, resType, resShaped, ValueRange{zero, zero, zero});
3217  auto maybeMaskedRes = maybeMaskXferOp(
3218  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3219 
3220  //===------------------------------------------------------------------===//
3221  // Begin vector-only rewrite part
3222  //===------------------------------------------------------------------===//
3223  // Unroll along kw and read slices of lhs and rhs.
3224  SmallVector<Value> lhsVals, rhsVals, resVals;
3225  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3226  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3227 
3228  // Extract lhs slice of size {n, wSizeStep, c}
3229  // @ [0, sw * w + dw * kw, 0].
3230  for (int64_t kw = 0; kw < kwSize; ++kw) {
3231  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3232  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3233  loc, maybeMaskedLhs->getResult(0),
3234  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3235  inOutSliceSizes, inOutStrides));
3236  }
3237  }
3238  // Extract rhs slice of size {c} @ [kw].
3239  for (int64_t kw = 0; kw < kwSize; ++kw) {
3240  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3241  loc, maybeMaskedRhs->getResult(0),
3242  /*offsets=*/ArrayRef<int64_t>{kw}));
3243  }
3244  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3245  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3246  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3247  loc, maybeMaskedRes->getResult(0),
3248  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3249  inOutStrides));
3250  }
3251 
3252  auto linearIndex = [&](int64_t kw, int64_t w) {
3253  return kw * (wSize / wSizeStep) + w;
3254  };
3255 
3256  // Note - the scalable flags are ignored as flattening combined with
3257  // scalable vectorization is not supported.
3258  auto inOutFlattenSliceSizes =
3259  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3260  auto lhsTypeAfterFlattening =
3261  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3262  auto resTypeAfterFlattening =
3263  VectorType::get(inOutFlattenSliceSizes, resEltType);
3264 
3265  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3266  for (int64_t kw = 0; kw < kwSize; ++kw) {
3267  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3268  Value lhsVal = lhsVals[linearIndex(kw, w)];
3269  Value resVal = resVals[w];
3270  if (flatten) {
3271  // Flatten the input and output vectors (collapse the channel
3272  // dimension)
3273  lhsVal = rewriter.create<vector::ShapeCastOp>(
3274  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3275  resVal = rewriter.create<vector::ShapeCastOp>(
3276  loc, resTypeAfterFlattening, resVals[w]);
3277  }
3278  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3279  rhsVals[kw], resVal, flatten);
3280  if (flatten) {
3281  // Un-flatten the output vector (restore the channel dimension)
3282  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3283  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3284  }
3285  }
3286  }
3287 
 3288  // It's possible we failed to create the FMA.
3289  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3290  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3291  for (auto &collection :
3292  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3293  for (Value v : collection)
3294  rewriter.eraseOp(v.getDefiningOp());
3295  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3296  }
3297 
3298  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3299  // This does not depend on kw.
3300  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3301  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3302  loc, resVals[w], maybeMaskedRes->getResult(0),
3303  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3304  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3305  }
3306  //===------------------------------------------------------------------===//
3307  // End vector-only rewrite part
3308  //===------------------------------------------------------------------===//
3309 
3310  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3311  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3312  loc, maybeMaskedRes->getResult(0), resShaped,
3313  ValueRange{zero, zero, zero});
3314  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3315  resOut);
3316  }
3317 
3318  /// Lower:
3319  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3320  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3321  /// to MulAcc.
3322  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3323  Value lhs, Value rhs, Value res,
3324  bool flatten) {
3325  auto rhsTy = cast<ShapedType>(rhs.getType());
3326  auto resTy = cast<ShapedType>(res.getType());
3327 
 3328  // TODO(suderman): Change this to use a vector.fma intrinsic.
3329  lhs = promote(rewriter, loc, lhs, resTy);
3330 
3331  if (flatten) {
 3332  // NOTE: The following logic won't work for scalable vectors. For this
3333  // reason, "flattening" is not supported when shapes are dynamic (this
3334  // should be captured by one of the pre-conditions).
3335 
3336  // There are two options for handling the filter:
3337  // * shape_cast(broadcast(filter))
3338  // * broadcast(shuffle(filter))
3339  // Opt for the option without shape_cast to simplify the codegen.
3340  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3341  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3342 
 3343  SmallVector<int64_t, 16> indices;
 3344  for (int i = 0; i < resSize / rhsSize; ++i) {
 3345  for (int j = 0; j < rhsSize; ++j)
 3346  indices.push_back(j);
 3347  }
 3348 
 3349  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3350  }
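      // E.g. with rhsSize == 4 (c) and resSize == 12 (w * c == 3 * 4), the
      // shuffle above replicates the filter as
      // [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], i.e. once per output position
      // along w (illustrative sizes).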
3351  // Broadcast the filter to match the output vector
3352  rhs = rewriter.create<vector::BroadcastOp>(
3353  loc, resTy.clone(rhsTy.getElementType()), rhs);
3354 
3355  rhs = promote(rewriter, loc, rhs, resTy);
3356 
3357  if (!lhs || !rhs)
3358  return nullptr;
3359 
3360  if (isa<FloatType>(resTy.getElementType()))
3361  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3362 
3363  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3364  return rewriter.create<arith::AddIOp>(loc, mul, res);
3365  }
3366 
3367  /// Entry point for non-channeled convolution:
3368  /// {{w + kw}, {kw}, {w}}
3369  FailureOr<Operation *> generateNonChanneledConv() {
3370  AffineExpr w, kw;
3371  bindDims(ctx, w, kw);
3372  if (!iters({Par(), Red()}))
3373  return rewriter.notifyMatchFailure(op,
3374  "failed to match conv::W 1-par 1-red");
3375 
3376  // No transposition needed.
3377  if (layout({/*lhsIndex*/ {w + kw},
3378  /*rhsIndex*/ {kw},
3379  /*resIndex*/ {w}}))
3380  return conv(Conv1DOpOrder::W);
3381 
3382  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3383  }
3384 
3385  /// Entry point that transposes into the common form:
3386  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3387  FailureOr<Operation *> generateNwcConv() {
3388  AffineExpr n, w, f, kw, c;
3389  bindDims(ctx, n, w, f, kw, c);
3390  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3391  return rewriter.notifyMatchFailure(
3392  op, "failed to match conv::Nwc 3-par 2-red");
3393 
3394  // No transposition needed.
3395  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3396  /*rhsIndex*/ {kw, c, f},
3397  /*resIndex*/ {n, w, f}}))
3398  return conv(Conv1DOpOrder::Nwc);
3399 
3400  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3401  }
3402 
3403  /// Entry point that transposes into the common form:
3404  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3405  FailureOr<Operation *> generateNcwConv() {
3406  AffineExpr n, w, f, kw, c;
3407  bindDims(ctx, n, f, w, c, kw);
3408  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3409  return rewriter.notifyMatchFailure(
3410  op, "failed to match conv::Ncw 3-par 2-red");
3411 
3412  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3413  /*rhsIndex*/ {f, c, kw},
3414  /*resIndex*/ {n, f, w}}))
3415  return conv(Conv1DOpOrder::Ncw);
3416 
3417  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3418  }
3419 
3420  /// Entry point that transposes into the common form:
3421  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3422  FailureOr<Operation *> generateNwcPooling() {
3423  AffineExpr n, w, c, kw;
3424  bindDims(ctx, n, w, c, kw);
3425  if (!iters({Par(), Par(), Par(), Red()}))
3426  return rewriter.notifyMatchFailure(op,
3427  "failed to match pooling 3-par 1-red");
3428 
3429  // No transposition needed.
3430  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3431  /*rhsIndex*/ {kw},
3432  /*resIndex*/ {n, w, c}}))
3433  return conv(Conv1DOpOrder::Nwc);
3434 
3435  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3436  }
3437 
3438  /// Entry point that transposes into the common form:
3439  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3440  FailureOr<Operation *> generateNcwPooling() {
3441  AffineExpr n, w, c, kw;
3442  bindDims(ctx, n, c, w, kw);
3443  if (!iters({Par(), Par(), Par(), Red()}))
3444  return rewriter.notifyMatchFailure(op,
3445  "failed to match pooling 3-par 1-red");
3446 
3447  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3448  /*rhsIndex*/ {kw},
3449  /*resIndex*/ {n, c, w}}))
3450  return conv(Conv1DOpOrder::Ncw);
3451 
3452  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3453  }
3454 
3455  /// Entry point that transposes into the common form:
3456  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3457  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3458  bool vecChDimScalableFlag = false,
3459  bool flatten = false) {
3460  AffineExpr n, w, c, kw;
3461  bindDims(ctx, n, w, c, kw);
3462  if (!iters({Par(), Par(), Par(), Red()}))
3463  return rewriter.notifyMatchFailure(
3464  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3465 
3466  // No transposition needed.
3467  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3468  /*rhsIndex*/ {kw, c},
3469  /*resIndex*/ {n, w, c}}))
3470  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3471 
3472  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3473  }
3474 
3475 private:
3476  enum OperKind { Conv, Pool };
3477  bool valid = false;
3478  OperKind oper = Conv;
3479  StringAttr redOp;
3480  StringAttr poolExtOp;
3481  bool isPoolExt = false;
3482  int strideW, dilationW;
3483  Value lhsShaped, rhsShaped, resShaped;
3484  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3485 
3486  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3487  // Returns true iff it is a valid conv/pooling op.
 3488  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
 3489  // reduction + yield) and rhs is not used, then it is the body of a pooling.
3490  // If conv, check for single `mul` predecessor. The `mul` operands must be
3491  // block arguments or extension of block arguments.
3492  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
3493  // must be block arguments or extension of block arguments.
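  // For example (illustrative region bodies, not taken from a specific test):
  // a pooling body
  //   %0 = arith.addf %in, %out : f32
  //   linalg.yield %0 : f32
  // has a reduction with two block-argument operands (case 2 below), whereas
  // a convolution body
  //   %0 = arith.mulf %in, %filter : f32
  //   %1 = arith.addf %out, %0 : f32
  //   linalg.yield %1 : f32
  // has a single block-argument operand on the reduction and a `mul` feeder
  // (case 1 below).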
3494  bool setOperKind(Operation *reduceOp) {
3495  int numBlockArguments =
3496  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3497  switch (numBlockArguments) {
3498  case 1: {
 3499  // It will be a convolution if the feeder is a MulOp.
 3500  // Otherwise, it may be a pooling op.
3501  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3502  llvm::IsaPred<BlockArgument>);
3503  Operation *feedOp = (*feedValIt).getDefiningOp();
3504  if (isCastOfBlockArgument(feedOp)) {
3505  oper = Pool;
3506  isPoolExt = true;
3507  poolExtOp = feedOp->getName().getIdentifier();
3508  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3509  llvm::all_of(feedOp->getOperands(), [](Value v) {
3510  if (isa<BlockArgument>(v))
3511  return true;
3512  if (Operation *op = v.getDefiningOp())
3513  return isCastOfBlockArgument(op);
3514  return false;
3515  }))) {
3516  return false;
3517  }
3518  return true;
3519  }
3520  case 2:
3521  // Must be pooling
3522  oper = Pool;
3523  isPoolExt = false;
3524  return true;
3525  default:
3526  return false;
3527  }
3528  }
3529 };
3530 } // namespace
3531 
3532 /// Helper function to vectorize a LinalgOp with convolution semantics.
3533 // TODO: extend the generic vectorization to support windows and drop this.
 3534 static FailureOr<Operation *> vectorizeConvolution(
 3535  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3536  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3537  // The ConvolutionOpInterface gives us guarantees of existence for
3538  // strides/dilations. However, we do not need to rely on those, we can
3539  // simply use them if present, otherwise use the default and let the generic
3540  // conv. matcher in the ConvGenerator succeed or fail.
3541  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3542  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3543  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3544  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3545  Conv1DGenerator e(rewriter, op, stride, dilation);
3546  auto res = e.generateNonChanneledConv();
3547  if (succeeded(res))
3548  return res;
3549  res = e.generateNwcConv();
3550  if (succeeded(res))
3551  return res;
3552  res = e.generateNcwConv();
3553  if (succeeded(res))
3554  return res;
3555  res = e.generateNwcPooling();
3556  if (succeeded(res))
3557  return res;
3558  res = e.generateNcwPooling();
3559  if (succeeded(res))
3560  return res;
3561 
3562  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3563  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3564  // masked/scalable) is the channel dim (i.e. the trailing dim).
3565  uint64_t vecChDimSize = ShapedType::kDynamic;
3566  bool vecChDimScalableFlag = false;
3567  if (!inputVecSizes.empty()) {
3568  // Only use the input vector size corresponding to the channel dim. Other
3569  // vector dims will be inferred from the Ops.
3570  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3571  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3572  "Not a 1D depthwise conv!");
3573  size_t chDimIdx =
 3574  TypeSwitch<Operation *, size_t>(op)
 3575  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3576  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3577 
3578  vecChDimSize = inputVecSizes[chDimIdx];
3579  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3580  }
3581  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3582  flatten1DDepthwiseConv);
3583 }
 3584 
 3585 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
 3586  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 3587 
 3588  LogicalResult matchAndRewrite(LinalgOp op,
 3589  PatternRewriter &rewriter) const override {
3590  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3591  if (failed(resultOrFail))
3592  return failure();
3593  Operation *newOp = *resultOrFail;
3594  if (newOp->getNumResults() == 0) {
3595  rewriter.eraseOp(op.getOperation());
3596  return success();
3597  }
3598  assert(newOp->getNumResults() == 1 && "expected single result");
3599  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3600  return success();
3601  }
3602 };
3603 
 3604 void mlir::linalg::populateConvolutionVectorizationPatterns(
 3605  RewritePatternSet &patterns, PatternBenefit benefit) {
3606  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3607 }
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether /p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:237
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:132
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:292
MLIRContext * getContext() const
Definition: AffineMap.cpp:327
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:318
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:579
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
unsigned getNumInputs() const
Definition: AffineMap.cpp:387
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:139
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:540
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:609
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:289
OpListType & getOperations()
Definition: Block.h:134
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
IntegerType getI32Type()
Definition: Builders.cpp:83
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:277
DenseIntElementsAttr getIndexVectorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:150
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1391
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:215
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2236
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking=true)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:650
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:777
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:753
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:63
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:683
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:623
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation vector.
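To make the permutation helpers concrete, here is a worked round-trip (a sketch; the header path is an assumption and permutationRoundTrip is an invented name):

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// applyPermutation produces result[i] = input[permutation[i]]; applying the
// inverted permutation afterwards restores the original order.
static void permutationRoundTrip() {
  SmallVector<int64_t> shape = {4, 8, 16};
  SmallVector<int64_t> perm = {2, 0, 1};

  SmallVector<int64_t> permuted =
      applyPermutation(ArrayRef<int64_t>(shape), perm);
  // permuted == {16, 4, 8}

  SmallVector<int64_t> inverse = invertPermutationVector(perm);
  applyPermutationToVector(permuted, inverse);
  // permuted == {4, 8, 16} again
}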
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
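For readers less familiar with the LogicalResult idiom used throughout the listing, a minimal sketch (checkRank and process are invented names):

#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Helpers return success()/failure(); callers branch with failed()/succeeded()
// rather than raw booleans, mirroring the vectorization hooks above.
static LogicalResult checkRank(int64_t rank) {
  return rank >= 1 ? success() : failure();
}

static LogicalResult process(int64_t rank) {
  if (failed(checkRank(rank)))
    return failure();
  // ... proceed with the actual rewrite ...
  return success();
}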
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
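A bare-bones OpRewritePattern skeleton in the same shape as the pad patterns above (ExamplePadPattern is illustrative only and not part of this file):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Matches a single op kind; returns failure() to leave the IR untouched so
// other patterns can still apply, or uses `rewriter` to build replacements.
struct ExamplePadPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (!padOp->hasOneUse())
      return failure();
    // A real pattern would build replacement ops via `rewriter` here and
    // finish with rewriter.replaceOp(...).
    return failure();
  }
};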
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
Definition: Transforms.h:1352
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.