1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58  ArrayRef<int64_t> inputVecSizes = {},
59  ArrayRef<bool> inputVecScalableFlags = {},
60  bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if none or more than 1 instances exist.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66  OpType res;
67  block.walk([&](OpType op) {
68  if (res) {
69  res = nullptr;
70  return WalkResult::interrupt();
71  }
72  res = op;
73  return WalkResult::advance();
74  });
75  return res;
76 }
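// Illustrative usage (hypothetical caller, not taken from this file): find the
// unique combiner in a linalg op body, e.g.
//
//   if (auto maxOp = getSingleOpOfType<arith::MaximumFOp>(*linalgOp.getBlock()))
//     ...;  // `maxOp` is the only arith.maximumf in the block, null otherwise.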
77 
78 /// Helper function to extract the input slices after filter is unrolled along
79 /// kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82  int64_t nSize, int64_t wSize, int64_t cSize,
83  int64_t kwSize, int strideW, int dilationW,
84  int64_t wSizeStep, bool isSingleChanneled) {
85  SmallVector<Value> result;
86  if (isSingleChanneled) {
87  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88  // convolution.
89  SmallVector<int64_t> sizes{wSizeStep};
90  SmallVector<int64_t> strides{1};
91  for (int64_t kw = 0; kw < kwSize; ++kw) {
92  for (int64_t w = 0; w < wSize; w += wSizeStep) {
93  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95  }
96  }
97  } else {
98  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99  // for channeled convolution.
100  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101  SmallVector<int64_t> strides{1, 1, 1};
102  for (int64_t kw = 0; kw < kwSize; ++kw) {
103  for (int64_t w = 0; w < wSize; w += wSizeStep) {
104  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105  loc, input,
106  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107  sizes, strides));
108  }
109  }
110  }
111  return result;
112 }
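// Illustrative result (hypothetical sizes, not taken from a test): for a
// channeled convolution with nSize = 1, wSize = wSizeStep = 4, cSize = 3,
// kwSize = 2 and strideW = dilationW = 1, two slices are extracted from the
// input vector, roughly:
//
//   %s0 = vector.extract_strided_slice %input
//           {offsets = [0, 0, 0], sizes = [1, 4, 3], strides = [1, 1, 1]}
//           : vector<1x5x3xf32> to vector<1x4x3xf32>
//   %s1 = vector.extract_strided_slice %input
//           {offsets = [0, 1, 0], sizes = [1, 4, 3], strides = [1, 1, 1]}
//           : vector<1x5x3xf32> to vector<1x4x3xf32>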
113 
114 /// Helper function to extract the filter slices after filter is unrolled along
115 /// kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117  Location loc, Value filter,
118  int64_t kwSize) {
119  SmallVector<Value> result;
120  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121  // non-channeled convolution] @ [kw].
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  result.push_back(rewriter.create<vector::ExtractOp>(
124  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125  }
126  return result;
127 }
128 
129 /// Helper function to extract the result slices after filter is unrolled along
130 /// kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133  int64_t nSize, int64_t wSize, int64_t fSize,
134  int64_t wSizeStep, bool isSingleChanneled) {
135  SmallVector<Value> result;
136  if (isSingleChanneled) {
137  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138  SmallVector<int64_t> sizes{wSizeStep};
139  SmallVector<int64_t> strides{1};
140  for (int64_t w = 0; w < wSize; w += wSizeStep) {
141  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143  }
144  } else {
145  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146  // convolution.
147  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148  SmallVector<int64_t> strides{1, 1, 1};
149  for (int64_t w = 0; w < wSize; w += wSizeStep) {
150  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152  }
153  }
154  return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159  Value res, int64_t wSize, int64_t wSizeStep,
160  SmallVectorImpl<Value> &resVals,
161  bool isSingleChanneled) {
162 
163  if (isSingleChanneled) {
164  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165  // This does not depend on kw.
166  SmallVector<int64_t> strides{1};
167  for (int64_t w = 0; w < wSize; w += wSizeStep) {
168  res = rewriter.create<vector::InsertStridedSliceOp>(
169  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170  }
171  } else {
172  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173  // convolution. This does not depend on kw.
174  SmallVector<int64_t> strides{1, 1, 1};
175  for (int64_t w = 0; w < wSize; w += wSizeStep) {
176  res = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178  strides);
179  }
180  }
181  return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189  /// Initializes the vectorization state, including the computation of the
190  /// canonical vector shape for vectorization.
191  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192  ArrayRef<int64_t> inputVectorSizes,
193  ArrayRef<bool> inputScalableVecDims);
194 
195  /// Returns the canonical vector shape used to vectorize the iteration space.
196  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198  /// Returns the vector dimensions that are scalable in the canonical vector
199  /// shape.
200  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202  /// Returns a vector type of the provided `elementType` with the canonical
203  /// vector shape and the corresponding fixed/scalable dimensions bit. If
204  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205  /// accordingly.
206  VectorType getCanonicalVecType(
207  Type elementType,
208  std::optional<AffineMap> dimPermutation = std::nullopt) const {
209  SmallVector<int64_t> vectorShape;
210  SmallVector<bool> scalableDims;
211  if (dimPermutation.has_value()) {
212  vectorShape =
213  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214  scalableDims =
215  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216  } else {
217  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219  }
220 
221  return VectorType::get(vectorShape, elementType, scalableDims);
222  }
223 
224  /// Masks an operation with the canonical vector mask if the operation needs
225  /// masking. Returns the masked operation or the original operation if masking
226  /// is not needed. If provided, the canonical mask for this operation is
227  /// permuted using `maybeMaskingMap`.
228  Operation *
229  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
231 
232 private:
233  /// Initializes the iteration space static sizes using the Linalg op
234  /// information. This may become more complicated in the future.
235  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237  }
238 
239  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240  /// all the static and dynamic dimensions of the iteration space to be
241  /// vectorized and store them in `iterSpaceValueSizes`.
242  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243  LinalgOp linalgOp);
244 
245  /// Create or retrieve an existing mask value to mask `opToMask` in the
246  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
247  /// mask is permuted using that permutation map. If a new mask is created, it will be
248  /// cached for future users.
249  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250  LinalgOp linalgOp,
251  std::optional<AffineMap> maybeMaskingMap);
252 
253  // Holds the compile-time static sizes of the iteration space to vectorize.
254  // Dynamic dimensions are represented using ShapedType::kDynamic.
255  SmallVector<int64_t> iterSpaceStaticSizes;
256 
257  /// Holds the value sizes of the iteration space to vectorize. Static
258  /// dimensions are represented by 'arith.constant' and dynamic
259  /// dimensions by 'tensor/memref.dim'.
260  SmallVector<Value> iterSpaceValueSizes;
261 
262  /// Holds the canonical vector shape used to vectorize the iteration space.
263  SmallVector<int64_t> canonicalVecShape;
264 
265  /// Holds the vector dimensions that are scalable in the canonical vector
266  /// shape.
267  SmallVector<bool> scalableVecDims;
268 
269  /// Holds the active masks for permutations of the canonical vector iteration
270  /// space.
271  DenseMap<AffineMap, Value> activeMaskCache;
272 
273  /// Global vectorization guard for the incoming rewriter. It's initialized
274  /// when the vectorization state is initialized.
275  OpBuilder::InsertionGuard rewriterGuard;
276 };
277 
278 LogicalResult
279 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
280  LinalgOp linalgOp) {
281  // TODO: Support 0-d vectors.
282  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
283  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
284  // Create constant index op for static dimensions.
285  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
286  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
287  continue;
288  }
289 
290  // Find an operand defined on this dimension of the iteration space to
291  // extract the runtime dimension size.
292  Value operand;
293  unsigned operandDimPos;
294  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
295  operandDimPos)))
296  return failure();
297 
298  Value dynamicDim = linalgOp.hasPureTensorSemantics()
299  ? (Value)rewriter.create<tensor::DimOp>(
300  linalgOp.getLoc(), operand, operandDimPos)
301  : (Value)rewriter.create<memref::DimOp>(
302  linalgOp.getLoc(), operand, operandDimPos);
303  iterSpaceValueSizes.push_back(dynamicDim);
304  }
305 
306  return success();
307 }
308 
309 /// Initializes the vectorization state, including the computation of the
310 /// canonical vector shape for vectorization.
311 // TODO: Move this to the constructor when we can remove the failure cases.
312 LogicalResult
313 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
314  ArrayRef<int64_t> inputVectorSizes,
315  ArrayRef<bool> inputScalableVecDims) {
316  // Initialize the insertion point.
317  rewriter.setInsertionPoint(linalgOp);
318 
319  if (!inputVectorSizes.empty()) {
320  // Get the canonical vector shape from the input vector sizes provided. This
321  // path should be taken to vectorize code with dynamic shapes and when using
322  // vector sizes greater than the iteration space sizes.
323  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
324  scalableVecDims.append(inputScalableVecDims.begin(),
325  inputScalableVecDims.end());
326  } else {
327  // Compute the canonical vector shape from the operation shape. If there are
328  // dynamic shapes, the operation won't be vectorized. We assume all the
329  // vector dimensions are fixed.
330  canonicalVecShape = linalgOp.getStaticLoopRanges();
331  scalableVecDims.append(linalgOp.getNumLoops(), false);
332  }
333 
334  LDBG("Canonical vector shape: ");
335  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
336  LLVM_DEBUG(llvm::dbgs() << "\n");
337  LDBG("Scalable vector dims: ");
338  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
339  LLVM_DEBUG(llvm::dbgs() << "\n");
340 
341  if (ShapedType::isDynamicShape(canonicalVecShape))
342  return failure();
343 
344  // Initialize iteration space static sizes.
345  initIterSpaceStaticSizes(linalgOp);
346 
347  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
348  // all the static and dynamic dimensions of the iteration space, needed to
349  // compute a mask during vectorization.
350  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
351  return failure();
352 
353  return success();
354 }
355 
356 /// Create or retrieve an existing mask value to mask `opToMask` in the
357 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
358 /// is permuted using that permutation map. If a new mask is created, it will be cached for
359 /// future users.
360 Value VectorizationState::getOrCreateMaskFor(
361  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362  std::optional<AffineMap> maybeMaskingMap) {
363  // No mask is needed if the operation is not maskable.
364  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365  if (!maskableOp)
366  return Value();
367 
368  assert(!maskableOp.isMasked() &&
369  "Masking an operation that is already masked");
370 
371  // If no masking map was provided, use an identity map with the loop dims.
372  assert((!maybeMaskingMap || *maybeMaskingMap) &&
373  "Unexpected null mask permutation map");
374  AffineMap maskingMap =
375  maybeMaskingMap ? *maybeMaskingMap
376  : AffineMap::getMultiDimIdentityMap(
377  linalgOp.getNumLoops(), rewriter.getContext());
378 
379  LDBG("Masking map: " << maskingMap << "\n");
380 
381  // Return the active mask for the masking map of this operation if it was
382  // already created.
383  auto activeMaskIt = activeMaskCache.find(maskingMap);
384  if (activeMaskIt != activeMaskCache.end()) {
385  Value mask = activeMaskIt->second;
386  LDBG("Reusing mask: " << mask << "\n");
387  return mask;
388  }
389 
390  // Compute permuted projection of the iteration space to be masked and the
391  // corresponding mask shape. If the resulting iteration space dimensions are
392  // static and identical to the mask shape, masking is not needed for this
393  // operation.
394  // TODO: Improve this check. Only projected permutation indexing maps are
395  // supported.
396  SmallVector<int64_t> permutedStaticSizes =
397  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
398  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
399  auto maskShape = maskType.getShape();
400 
401  LDBG("Mask shape: ");
402  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
403  LLVM_DEBUG(llvm::dbgs() << "\n");
404 
405  if (permutedStaticSizes == maskShape) {
406  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
407  activeMaskCache[maskingMap] = Value();
408  return Value();
409  }
410 
411  // Permute the iteration space value sizes to compute the mask upper bounds.
412  SmallVector<Value> upperBounds =
413  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
414  assert(!maskShape.empty() && !upperBounds.empty() &&
415  "Masked 0-d vectors are not supported yet");
416 
417  // Create the mask based on the dimension values.
418  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
419  maskType, upperBounds);
420  LDBG("Creating new mask: " << mask << "\n");
421  activeMaskCache[maskingMap] = mask;
422  return mask;
423 }
424 
425 /// Masks an operation with the canonical vector mask if the operation needs
426 /// masking. Returns the masked operation or the original operation if masking
427 /// is not needed. If provided, the canonical mask for this operation is
428 /// permuted using `maybeMaskingMap`.
429 Operation *
430 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
431  LinalgOp linalgOp,
432  std::optional<AffineMap> maybeMaskingMap) {
433  LDBG("Trying to mask: " << *opToMask << "\n");
434 
435  // Create or retrieve mask for this operation.
436  Value mask =
437  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
438 
439  if (!mask) {
440  LDBG("No mask required\n");
441  return opToMask;
442  }
443 
444  // Wrap the operation with a new `vector.mask` and update D-U chain.
445  assert(opToMask && "Expected a valid operation to mask");
446  auto maskOp = cast<vector::MaskOp>(
447  mlir::vector::maskOperation(rewriter, opToMask, mask));
448  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
449 
450  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
451  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
452  maskOpTerminator);
453 
454  LDBG("Masked operation: " << *maskOp << "\n");
455  return maskOp;
456 }
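// Illustrative shape of the masked IR (schematic only; exact types and
// operands depend on the op being masked): a transfer_read over a
// dynamically-sized tensor becomes
//
//   %mask = vector.create_mask %dim0, %dim1 : vector<4x8xi1>
//   %read = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//       : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>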
457 
458 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
459 /// projectedPermutation, compress the unused dimensions to serve as a
460 /// permutation_map for a vector transfer operation.
461 /// For example, given a linalg op such as:
462 ///
463 /// ```
464 /// %0 = linalg.generic {
465 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
466 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
467 /// }
468 /// ins(%0 : tensor<2x3x4xf32>)
469 /// outs(%1 : tensor<5x6xf32>)
470 /// ```
471 ///
472 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
473 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
474 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
475 static AffineMap reindexIndexingMap(AffineMap map) {
476  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
477  "expected projected permutation");
478  auto res = compressUnusedDims(map);
479  assert(res.getNumDims() == res.getNumResults() &&
480  "expected reindexed map with same number of dims and results");
481  return res;
482 }
483 
484 /// Helper enum to represent conv1d input traversal order.
485 enum class Conv1DOpOrder {
486  W, // Corresponds to non-channeled 1D convolution operation.
487  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
488  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
489 };
490 
491 /// Helper data structure to represent the result of vectorization.
492 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
493 enum VectorizationStatus {
494  /// Op failed to vectorize.
495  Failure = 0,
496  /// Op vectorized and custom function took care of replacement logic
497  NoReplace,
498  /// Op vectorized into a new Op whose results will replace original Op's
499  /// results.
500  NewOp
501  // TODO: support values if Op vectorized to Many-Ops whose results we need to
502  // aggregate for replacement.
503 };
504 struct VectorizationResult {
505  /// Return status from vectorizing the current op.
506  enum VectorizationStatus status = VectorizationStatus::Failure;
507  /// New vectorized operation to replace the current op.
508  /// Replacement behavior is specified by `status`.
509  Operation *newOp;
510 };
511 
512 std::optional<vector::CombiningKind>
513 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
514  using ::mlir::vector::CombiningKind;
515 
516  if (!combinerOp)
517  return std::nullopt;
518  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
519  .Case<arith::AddIOp, arith::AddFOp>(
520  [&](auto op) { return CombiningKind::ADD; })
521  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
522  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
523  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
524  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
525  .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
526  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
527  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
528  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
529  .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
530  .Case<arith::MulIOp, arith::MulFOp>(
531  [&](auto op) { return CombiningKind::MUL; })
532  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
533  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
534  .Default([&](auto op) { return std::nullopt; });
535 }
536 
537 /// Check whether `outputOperand` is a reduction with a single combiner
538 /// operation. Return the combiner operation of the reduction. Return
539 /// nullptr otherwise. Multiple reduction operations would impose an
540 /// ordering between reduction dimensions and is currently unsupported in
541 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
542 /// max(min(X))
543 // TODO: use in LinalgOp verification, there is a circular dependency atm.
544 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
545  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
546  unsigned outputPos =
547  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
548  // Only single combiner operations are supported for now.
549  SmallVector<Operation *, 4> combinerOps;
550  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
551  combinerOps.size() != 1)
552  return nullptr;
553 
554  // Return the combiner operation.
555  return combinerOps[0];
556 }
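// Illustrative example (schematic IR, indexing maps elided): for
//
//   %0 = linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//          ins(%in : tensor<4x8xf32>) outs(%acc : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %s = arith.addf %a, %b : f32
//     linalg.yield %s : f32
//   } -> tensor<4xf32>
//
// matchLinalgReduction on the output operand returns the single arith.addf
// combiner, which getCombinerOpKind then maps to vector::CombiningKind::ADD.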
557 
558 /// Broadcast `value` to a vector of `shape` if possible. Return value
559 /// otherwise.
560 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
561  auto dstVecType = dyn_cast<VectorType>(dstType);
562  // If no shape to broadcast to, just return `value`.
563  if (dstVecType.getRank() == 0)
564  return value;
565  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
566  vector::BroadcastableToResult::Success)
567  return value;
568  Location loc = b.getInsertionPoint()->getLoc();
569  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
570 }
571 
572 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
573 /// assumes that `reductionOp` has two operands and one of them is the reduction
574 /// initial value.
575 // Note: this is a true builder that notifies the OpBuilder listener.
576 // TODO: Consider moving as a static helper on the ReduceOp.
577 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
578  Value valueToReduce, Value acc,
579  ArrayRef<bool> dimsToMask) {
580  auto maybeKind = getCombinerOpKind(reduceOp);
581  assert(maybeKind && "Failed precondition: could not get reduction kind");
582  return b.create<vector::MultiDimReductionOp>(
583  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
584 }
585 
586 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
587  return llvm::to_vector(
588  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
589 }
590 
591 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
592 /// reduction iterator.
593 static bool hasReductionIterator(LinalgOp &op) {
594  return isa<linalg::ReduceOp>(op) ||
595  (isa<linalg::GenericOp>(op) &&
596  llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
597 }
598 
599 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
600 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
601 /// currently being vectorized. If `dest` has null rank, build a memref.store.
602 /// Return the produced value or null if no value is produced.
603 // Note: this is a true builder that notifies the OpBuilder listener.
604 // TODO: Consider moving as a static helper on the ReduceOp.
605 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
606  OpOperand *outputOperand,
607  VectorizationState &state) {
608  Location loc = value.getLoc();
609  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
610  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
611 
612  // Compute the vector type of the value to store. This type should be an
613  // identity or projection of the canonical vector type without any permutation
614  // applied, given that any permutation in a transfer write happens as part of
615  // the write itself.
616  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
617  opOperandMap.getContext(), opOperandMap.getNumInputs(),
618  [&](AffineDimExpr dimExpr) -> bool {
619  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
620  });
621  auto vectorType = state.getCanonicalVecType(
622  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
623 
624  Operation *write;
625  if (vectorType.getRank() > 0) {
626  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
627  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
628  rewriter.create<arith::ConstantIndexOp>(loc, 0));
629  value = broadcastIfNeeded(rewriter, value, vectorType);
630  assert(value.getType() == vectorType && "Incorrect type");
631  write = rewriter.create<vector::TransferWriteOp>(
632  loc, value, outputOperand->get(), indices, writeMap);
633  } else {
634  // 0-d case is still special: do not invert the reindexing writeMap.
635  if (!isa<VectorType>(value.getType()))
636  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
637  assert(value.getType() == vectorType && "Incorrect type");
638  write = rewriter.create<vector::TransferWriteOp>(
639  loc, value, outputOperand->get(), ValueRange{});
640  }
641 
642  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
643 
644  // If masked, set in-bounds to true. Masking guarantees that the access will
645  // be in-bounds.
646  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
647  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
648  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
649  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
650  }
651 
652  LDBG("vectorized op: " << *write << "\n");
653  if (!write->getResults().empty())
654  return write->getResult(0);
655  return Value();
656 }
657 
658 // Custom vectorization precondition function type. This is intended to be used
659 // with CustomVectorizationHook. Returns success if the corresponding custom
660 // hook can vectorize the op.
661 using CustomVectorizationPrecondition =
662  std::function<LogicalResult(Operation *, bool)>;
663 
664 // Custom vectorization function type. Produce a vector form of Operation*
665 // assuming all its vectorized operands are already in the IRMapping.
666 // Return nullptr if the Operation cannot be vectorized.
667 using CustomVectorizationHook =
668  std::function<VectorizationResult(Operation *, const IRMapping &)>;
669 
670 /// Helper function to vectorize the terminator of a `linalgOp`. New result
671 /// vector values are appended to `newResults`. Return
672 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
673 /// should not try to map produced operations and instead return the results
674 /// using the `newResults` vector making them available to the vectorization
675 /// algorithm for RAUW. This function is meant to be used as a
676 /// CustomVectorizationHook.
677 static VectorizationResult
678 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
679  const IRMapping &bvm, VectorizationState &state,
680  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
681  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
682  if (!yieldOp)
683  return VectorizationResult{VectorizationStatus::Failure, nullptr};
684  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
685  // TODO: Scan for an opportunity for reuse.
686  // TODO: use a map.
687  Value vectorValue = bvm.lookup(output.value());
688  Value newResult =
689  buildVectorWrite(rewriter, vectorValue,
690  linalgOp.getDpsInitOperand(output.index()), state);
691  if (newResult)
692  newResults.push_back(newResult);
693  }
694 
695  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
696 }
697 
698 /// Helper function to vectorize the index operations of a `linalgOp`. Return
699 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
700 /// should map the produced operations. This function is meant to be used as a
701 /// CustomVectorizationHook.
702 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
703  VectorizationState &state,
704  Operation *op,
705  LinalgOp linalgOp) {
706  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
707  if (!indexOp)
708  return VectorizationResult{VectorizationStatus::Failure, nullptr};
709  auto loc = indexOp.getLoc();
710  // Compute the static loop sizes of the index op.
711  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
712  auto dim = indexOp.getDim();
713  // Compute a one-dimensional index vector for the index op dimension.
714  auto indexVectorType =
715  VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
716  state.getScalableVecDims()[dim]);
717  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
718  // Return the one-dimensional index vector if it lives in the trailing
719  // dimension of the iteration space since the vectorization algorithm in this
720  // case can handle the broadcast.
721  if (dim == targetShape.size() - 1)
722  return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
723  // Otherwise permute the targetShape to move the index dimension last,
724  // broadcast the one-dimensional index vector to the permuted shape, and
725  // finally transpose the broadcasted index vector to undo the permutation.
726  auto permPattern =
727  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
728  std::swap(permPattern[dim], permPattern.back());
729  auto permMap =
730  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
731 
732  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
733  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
734  indexSteps);
735  SmallVector<int64_t> transposition =
736  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
737  std::swap(transposition.back(), transposition[dim]);
738  auto transposeOp =
739  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
740  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
741 }
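// Sketch of the emitted IR for a non-trailing dim (illustrative; assumes a
// static 2x4 iteration space and `linalg.index 0`):
//
//   %step  = vector.step : vector<2xindex>
//   %bcast = vector.broadcast %step : vector<2xindex> to vector<4x2xindex>
//   %idx   = vector.transpose %bcast, [1, 0]
//              : vector<4x2xindex> to vector<2x4xindex>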
742 
743 /// Helper function to check if the tensor.extract can be vectorized by the
744 /// custom hook vectorizeTensorExtract.
745 static LogicalResult
746 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
747  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
748  if (!extractOp)
749  return failure();
750 
751  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
752  return failure();
753 
754  // Check the index type, but only for non 0-d tensors (for which we do need
755  // access indices).
756  if (not extractOp.getIndices().empty()) {
757  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
758  return failure();
759  }
760 
761  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
762  return !VectorType::isValidElementType(type);
763  })) {
764  return failure();
765  }
766 
767  return success();
768 }
769 
770 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
771 /// generated from `tensor.extract`. The offset is calculated as follows
772 /// (example using scalar values):
773 ///
774 /// offset = extractOp.indices[0]
775 /// for (i = 1; i < numIndices; i++)
776 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
777 ///
778 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
779 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
780 static Value calculateGatherOffset(RewriterBase &rewriter,
781  VectorizationState &state,
782  tensor::ExtractOp extractOp,
783  const IRMapping &bvm) {
784  // The vector of indices for GatherOp should be shaped as the output vector.
785  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
786  auto loc = extractOp.getLoc();
787 
788  Value offset = broadcastIfNeeded(
789  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
790 
791  const size_t numIndices = extractOp.getIndices().size();
792  for (size_t i = 1; i < numIndices; i++) {
793  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
794 
795  auto dimSize = broadcastIfNeeded(
796  rewriter,
797  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
798  indexVecType);
799 
800  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
801 
802  auto extractOpIndex = broadcastIfNeeded(
803  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
804 
805  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
806  }
807 
808  return offset;
809 }
810 
811 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
812 
813 /// Checks whether \p val can be used for calculating a loop invariant index.
814 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
815 
816  auto targetShape = linalgOp.getStaticLoopRanges();
817  assert(llvm::count_if(targetShape,
818  [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
819  "n-D vectors are not yet supported");
820 
821  // Blocks outside _this_ linalg.generic are effectively loop invariant.
822  // However, analysing block arguments for _this_ linalg.generic Op is a bit
823  // tricky. Just bail out in the latter case.
824  // TODO: We could try analysing the corresponding affine map here.
825  auto *block = linalgOp.getBlock();
826  if (isa<BlockArgument>(val))
827  return llvm::all_of(block->getArguments(),
828  [&val](Value v) { return (v != val); });
829 
830  Operation *defOp = val.getDefiningOp();
831  assert(defOp && "This is neither a block argument nor an operation result");
832 
833  // IndexOp is loop invariant as long as its result remains constant across
834  // iterations. Given the assumptions on the loop ranges above, only the
835  // trailing loop dim ever changes.
836  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
837  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
838  return (indexOp.getDim() != trailingLoopDim);
839 
840  auto *ancestor = block->findAncestorOpInBlock(*defOp);
841 
842  // Values defined outside `linalgOp` are loop invariant.
843  if (!ancestor)
844  return true;
845 
846  // Values defined inside `linalgOp`, which are constant, are loop invariant.
847  if (isa<arith::ConstantOp>(ancestor))
848  return true;
849 
850  bool result = true;
851  for (auto op : ancestor->getOperands())
852  result &= isLoopInvariantIdx(linalgOp, op);
853 
854  return result;
855 }
856 
857 /// Check whether \p val could be used for calculating the trailing index for a
858 /// contiguous load operation.
859 ///
860 /// There are currently 3 types of values that are allowed here:
861 /// 1. loop-invariant values,
862 /// 2. values that increment by 1 with every loop iteration,
863 /// 3. results of basic arithmetic operations (linear and continuous)
864 /// involving 1., 2. and 3.
865 /// This method returns True if indeed only such values are used in calculating
866 /// \p val.
867 ///
868 /// Additionally, the trailing index for a contiguous load operation should
869 /// increment by 1 with every loop iteration, i.e. be based on:
870 /// * `linalg.index <dim>` ,
871 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
872 /// updated to `true` when such an op is found.
873 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
874  bool &foundIndexOp) {
875 
876  auto targetShape = linalgOp.getStaticLoopRanges();
877  assert(((llvm::count_if(targetShape,
878  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
879  "n-D vectors are not yet supported");
880 
881  // Blocks outside _this_ linalg.generic are effectively loop invariant.
882  // However, analysing block arguments for _this_ linalg.generic Op is a bit
883  // tricky. Just bail out in the latter case.
884  // TODO: We could try analysing the corresponding affine map here.
885  auto *block = linalgOp.getBlock();
886  if (isa<BlockArgument>(val))
887  return llvm::all_of(block->getArguments(),
888  [&val](Value v) { return (v != val); });
889 
890  Operation *defOp = val.getDefiningOp();
891  assert(defOp && "This is neither a block argument nor an operation result");
892 
893  // Given the assumption on the loop ranges above, only the trailing loop
894  // index is not constant.
895  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
896  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
897  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
898  return true;
899  }
900 
901  auto *ancestor = block->findAncestorOpInBlock(*defOp);
902 
903  if (!ancestor)
904  return false;
905 
906  // Conservatively reject Ops that could lead to indices with stride other
907  // than 1.
908  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
909  return false;
910 
911  bool result = false;
912  for (auto op : ancestor->getOperands())
913  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
914 
915  return result;
916 }
917 
918 /// Infer the memory access pattern for the input ExtractOp
919 ///
920 /// Based on the operation shapes and indices (usually based on the iteration
921 /// space of the parent `linalgOp` operation), decides whether the input
922 /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
923 /// gather load.
924 ///
925 /// Note that it is always safe to use gather load operations for contiguous
926 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
927 /// that `extractOp` is a gather load.
928 static VectorMemoryAccessKind
929 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
930  LinalgOp &linalgOp) {
931 
932  auto targetShape = linalgOp.getStaticLoopRanges();
933  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
934 
935  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
936  if (inputShape.getShape().empty())
937  return VectorMemoryAccessKind::ScalarBroadcast;
938 
939  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
940  // gather load.
941  // TODO: Relax this condition.
942  if (linalgOp.hasDynamicShape())
943  return VectorMemoryAccessKind::Gather;
944 
945  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
946  // otherwise.
947  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
948  return dimSize > 1;
949  }) == 1);
950 
951  // 1. Assume that it's a gather load when reading non-1D vector.
952  if (!isOutput1DVector)
953  return VectorMemoryAccessKind::Gather;
954 
955  bool leadingIdxsLoopInvariant = true;
956 
957  // 2. Analyze the leading indices of `extractOp`.
958  // Look at the way each index is calculated and decide whether it is suitable
959  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
960  // gather load.
961  auto indices = extractOp.getIndices();
962  auto leadIndices = indices.drop_back(1);
963 
964  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
965  if (inputShape.getShape()[i] == 1)
966  continue;
967 
968  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
969  }
970 
971  if (!leadingIdxsLoopInvariant) {
972  LDBG("Found gather load: " << extractOp);
973  return VectorMemoryAccessKind::Gather;
974  }
975 
976  // 3. Analyze the trailing index for `extractOp`.
977  // At this point we know that the leading indices are loop invariant. This
978  // means that this is potentially a scalar or a contiguous load. We can decide
979  // based on the trailing idx.
980  auto extractOpTrailingIdx = indices.back();
981 
982  // 3a. Scalar broadcast load
983  // If the trailing index is loop invariant then this is a scalar load.
984  if (leadingIdxsLoopInvariant &&
985  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
986  LDBG("Found scalar broadcast load: " << extractOp);
987 
988  return VectorMemoryAccessKind::ScalarBroadcast;
989  }
990 
991  // 3b. Contiguous loads
992  // The trailing `extractOp` index should increment with every loop iteration.
993  // This effectively means that it must be based on the trailing loop index.
994  // This is what the following bool captures.
995  bool foundIndexOp = false;
996  bool isContiguousLoad =
997  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
998  isContiguousLoad &= foundIndexOp;
999 
1000  if (isContiguousLoad) {
1001  LDBG("Found contiguous load: " << extractOp);
1002  return VectorMemoryAccessKind::Contiguous;
1003  }
1004 
1005  // 4. Fallback case - gather load.
1006  LDBG("Found gather load: " << extractOp);
1007  return VectorMemoryAccessKind::Gather;
1008 }
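// Illustrative classification (assuming a 1x4 iteration space, a constant %c0
// and %i produced by `linalg.index 1`, i.e. the trailing dim):
//
//   tensor.extract %src[%c0, %i]  -> Contiguous      (trailing idx steps by 1)
//   tensor.extract %src[%c0, %c0] -> ScalarBroadcast (all indices invariant)
//   tensor.extract %src[%i, %c0]  -> Gather          (fallback)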
1009 
1010 /// Helper function to vectorize the tensor.extract operations. Returns
1011 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1012 /// should map the produced operations. This function is meant to be used as a
1013 /// CustomVectorizationHook.
1014 static VectorizationResult
1015 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1016  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1017  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1018  if (!extractOp)
1019  return VectorizationResult{VectorizationStatus::Failure, nullptr};
1020  auto loc = extractOp.getLoc();
1021 
1022  // Compute the static loop sizes of the extract op.
1023  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1024  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1025  loc,
1026  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1027  /*value=*/true));
1028  auto passThruConstantOp =
1029  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1030 
1031  // Base indices are currently set to 0. We will need to re-visit if more
1032  // generic scenarios are to be supported.
1033  SmallVector<Value> baseIndices(
1034  extractOp.getIndices().size(),
1035  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1036 
1037  VectorMemoryAccessKind memAccessKind =
1038  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1039 
1040  // 1. Handle gather access
1041  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1042  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1043 
1044  // Generate the gather load
1045  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1046  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1047  maskConstantOp, passThruConstantOp);
1048  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1049 
1050  LDBG("Vectorised as gather load: " << extractOp << "\n");
1051  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1052  }
1053 
1054  // 2. Handle:
1055  // a. scalar loads + broadcast,
1056  // b. contiguous loads.
1057  // Both cases use vector.transfer_read.
1058 
1059  // Collect indices for `vector.transfer_read`. At this point, the indices will
1060  // either be scalars or would have been broadcast to vectors matching the
1061  // result type. For indices that are vectors, there are two options:
1062  // * for non-trailing indices, all elements are identical (contiguous
1063  // loads are identified by looking for non-trailing indices that are
1064  // invariant with respect to the corresponding linalg.generic), or
1065  // * for trailing indices, the index vector will contain values with stride
1066  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1067  // needed.
1068  // This means that
1069  // * for scalar indices - just re-use it,
1070  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1071  // (0th) element and use that.
1072  SmallVector<Value> transferReadIdxs;
1073  auto zero = rewriter.create<arith::ConstantOp>(
1074  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1075  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1076  Value idx = bvm.lookup(extractOp.getIndices()[i]);
1077  if (idx.getType().isIndex()) {
1078  transferReadIdxs.push_back(idx);
1079  continue;
1080  }
1081 
1082  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1083  loc,
1084  VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1085  resultType.getScalableDims().back()),
1086  idx);
1087  transferReadIdxs.push_back(
1088  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1089  }
1090 
1091  // `tensor.extract_element` is always in-bounds, hence the following holds.
1092  auto dstRank = resultType.getRank();
1093  auto srcRank = extractOp.getTensor().getType().getRank();
1094  SmallVector<bool> inBounds(dstRank, true);
1095 
1096  // 2a. Handle scalar broadcast access.
1097  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1098  MLIRContext *ctx = rewriter.getContext();
1099  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1100  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1101 
1102  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1103  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1104  permutationMap, inBounds);
1105 
1106  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1107  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1108  }
1109 
1110  // 2b. Handle contiguous access.
1111  auto permutationMap = AffineMap::getMinorIdentityMap(
1112  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1113 
1114  int32_t rankDiff = dstRank - srcRank;
1115  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1116  // dims so that the ranks match. This is done by extending the map with 0s.
1117  // For example, for dstRank = 3, srcRank = 2, the following map created
1118  // above:
1119  // (d0, d1) --> (d0, d1)
1120  // is extended as:
1121  // (d0, d1) --> (0, d0, d1)
1122  while (rankDiff > 0) {
1123  permutationMap = permutationMap.insertResult(
1124  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1125  rankDiff--;
1126  }
1127 
1128  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1129  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1130  inBounds);
1131 
1132  LDBG("Vectorised as contiguous load: " << extractOp);
1133  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1134 }
1135 
1136 /// Emit reduction operations if the shape of the value to reduce is different
1137 /// from the result shape.
1138 // Note: this is a true builder that notifies the OpBuilder listener.
1139 // TODO: Consider moving as a static helper on the ReduceOp.
1140 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1141  Value reduceValue, Value initialValue,
1142  const IRMapping &bvm) {
1143  Value reduceVec = bvm.lookup(reduceValue);
1144  Value outputVec = bvm.lookup(initialValue);
1145  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1146  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1147  // Reduce only if needed as the value may already have been reduced for
1148  // contraction vectorization.
1149  if (!reduceType ||
1150  (outputType && reduceType.getShape() == outputType.getShape()))
1151  return nullptr;
1152  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1153  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1154 }
1155 
1156 /// Generic vectorization for a single operation `op`, given already vectorized
1157 /// operands carried by `bvm`. Vectorization occurs as follows:
1158 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1159 /// result on success.
1160 /// 2. Clone any constant in the current scope without vectorization: each
1161 /// consumer of the constant will later determine the shape to which the
1162 /// constant needs to be broadcast to.
1163 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1164 /// of the `customVectorizationHooks` to cover such cases.
1165 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1166 /// operand of maximal rank. Other operands have smaller rank and are
1167 /// broadcast accordingly. It is assumed this broadcast is always legal,
1168 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1169 ///
1170 /// This function assumes all operands of `op` have been vectorized and are in
1171 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1172 /// a topologically-sorted list of ops.
1173 /// This function does not update `bvm` but returns a VectorizationStatus that
1174 /// instructs the caller what `bvm` update needs to occur.
1175 static VectorizationResult
1176 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1177  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1178  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1179  LDBG("vectorize op " << *op << "\n");
1180 
1181  // 1. Try to apply any CustomVectorizationHook.
1182  if (!customVectorizationHooks.empty()) {
1183  for (auto &customFunc : customVectorizationHooks) {
1184  VectorizationResult result = customFunc(op, bvm);
1185  if (result.status == VectorizationStatus::Failure)
1186  continue;
1187  return result;
1188  }
1189  }
1190 
1191  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1192  // Clone so that the constant is not confined to the linalgOp block.
1193  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1194  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1195 
1196  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1197  if (!OpTrait::hasElementwiseMappableTraits(op))
1198  return VectorizationResult{VectorizationStatus::Failure, nullptr};
1199 
1200  // 4. Check if the operation is a reduction.
1201  SmallVector<std::pair<Value, Value>> reductionOperands;
1202  for (Value operand : op->getOperands()) {
1203  auto blockArg = dyn_cast<BlockArgument>(operand);
1204  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1205  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1206  continue;
1207  SmallVector<Operation *> reductionOps;
1208  Value reduceValue = matchReduction(
1209  linalgOp.getRegionOutputArgs(),
1210  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1211  if (!reduceValue)
1212  continue;
1213  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1214  }
1215  if (!reductionOperands.empty()) {
1216  assert(reductionOperands.size() == 1);
1217  Operation *reduceOp =
1218  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1219  reductionOperands[0].second, bvm);
1220  if (reduceOp)
1221  return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1222  }
1223 
1224  // 5. Generic vectorization path for ElementwiseMappable ops.
1225  // a. Get the first max ranked shape.
1226  VectorType firstMaxRankedType;
1227  for (Value operand : op->getOperands()) {
1228  auto vecOperand = bvm.lookup(operand);
1229  assert(vecOperand && "Vector operand couldn't be found");
1230 
1231  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1232  if (vecType && (!firstMaxRankedType ||
1233  firstMaxRankedType.getRank() < vecType.getRank()))
1234  firstMaxRankedType = vecType;
1235  }
1236  // b. Broadcast each op if needed.
1237  SmallVector<Value> vecOperands;
1238  for (Value scalarOperand : op->getOperands()) {
1239  Value vecOperand = bvm.lookup(scalarOperand);
1240  assert(vecOperand && "Vector operand couldn't be found");
1241 
1242  if (firstMaxRankedType) {
1243  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1244  getElementTypeOrSelf(vecOperand.getType()),
1245  firstMaxRankedType.getScalableDims());
1246  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1247  } else {
1248  vecOperands.push_back(vecOperand);
1249  }
1250  }
1251  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1252  SmallVector<Type> resultTypes;
1253  for (Type resultType : op->getResultTypes()) {
1254  resultTypes.push_back(
1255  firstMaxRankedType
1256  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1257  firstMaxRankedType.getScalableDims())
1258  : resultType);
1259  }
1260  // d. Build and return the new op.
1261  return VectorizationResult{
1262  VectorizationStatus::NewOp,
1263  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1264  resultTypes, op->getAttrs())};
1265 }
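// For example (illustrative): with a 4x8 canonical vector shape, a scalar
// `arith.addf %a, %b : f32` in the body is re-created as
// `arith.addf %va, %vb : vector<4x8xf32>` once both operands have been
// broadcast to the maximal-rank operand type.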
1266 
1267 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1268 /// vector form. Generic vectorization proceeds as follows:
1269 /// 1. Verify the `linalgOp` has one non-empty region.
1270 /// 2. Values defined above the region are mapped to themselves and will be
1271 /// broadcasted on a per-need basis by their consumers.
1272 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1273 /// load).
1274 /// TODO: Reuse opportunities for RAR dependencies.
1275 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1276 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1277 /// iteration indices.
1278 /// 5. Iteratively call vectorizeOneOp on the region operations.
1279 ///
1280 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1281 /// performed to the maximal common vector size implied by the `linalgOp`
1282 /// iteration space. This eager broadcasting is introduced in the
1283 /// permutation_map of the vector.transfer_read operations. The eager
1284 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1285 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1286 /// the absence of good canonicalizations, the amount of work increases.
1287 /// This is not deemed a problem as we expect canonicalizations and foldings to
1288 /// aggressively clean up the useless work.
1289 static LogicalResult
1290 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1291  LinalgOp linalgOp,
1292  SmallVectorImpl<Value> &newResults) {
1293  LDBG("Vectorizing operation as linalg generic\n");
1294  Block *block = linalgOp.getBlock();
1295 
1296  // 2. Values defined above the region can only be broadcast for now. Make them
1297  // map to themselves.
1298  IRMapping bvm;
1299  SetVector<Value> valuesSet;
1300  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1301  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1302 
1303  if (linalgOp.getNumDpsInits() == 0)
1304  return failure();
1305 
1306  // 3. Turn all BBArgs into vector.transfer_read / load.
1307  Location loc = linalgOp.getLoc();
1308  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1309  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1310  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1311  if (linalgOp.isScalar(opOperand)) {
1312  bvm.map(bbarg, opOperand->get());
1313  continue;
1314  }
1315 
1316  // 3.a. Convert the indexing map for this input/output to a transfer read
1317  // permutation map and masking map.
1318  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1319 
1320  // Remove zeros from indexing map to use it as masking map.
1321  SmallVector<int64_t> zeroPos;
1322  auto results = indexingMap.getResults();
1323  for (const auto &result : llvm::enumerate(results)) {
1324  if (isa<AffineConstantExpr>(result.value())) {
1325  zeroPos.push_back(result.index());
1326  }
1327  }
1328  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1329 
1330  AffineMap readMap;
1331  VectorType readType;
1332  Type elemType = getElementTypeOrSelf(opOperand->get());
1333  if (linalgOp.isDpsInput(opOperand)) {
1334  // 3.a.i. For input reads we use the canonical vector shape.
1335  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1336  readType = state.getCanonicalVecType(elemType);
1337  } else {
1338  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1339  // reductions), the vector shape is computed by mapping the canonical
1340  // vector shape to the output domain and back to the canonical domain.
1341  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1342  readType =
1343  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1344  }
1345 
1346  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1347 
1348  // Make sure that the in_bounds attribute corresponding to a broadcast dim
1349  // is `true`
1350  SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
1351  SmallVector<bool> inBounds(readType.getRank(), false);
1352 
1353  for (auto idx : broadcastedDims)
1354  inBounds[idx] = true;
1355 
1356  Operation *read = rewriter.create<vector::TransferReadOp>(
1357  loc, readType, opOperand->get(), indices, readMap,
1358  ArrayRef<bool>(inBounds));
1359  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1360  Value readValue = read->getResult(0);
1361 
1362  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1363  // will be in-bounds.
1364  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1365  SmallVector<bool> inBounds(readType.getRank(), true);
1366  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1367  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1368  }
1369 
1370  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1371  // TODO: remove this.
1372  if (readType.getRank() == 0)
1373  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1374 
1375  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1376  << "\n");
1377  bvm.map(bbarg, readValue);
1378  bvm.map(opOperand->get(), readValue);
1379  }
1380 
1381  SmallVector<CustomVectorizationHook> hooks;
1382  // 4a. Register CustomVectorizationHook for yieldOp.
1383  CustomVectorizationHook vectorizeYield =
1384  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1385  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1386  };
1387  hooks.push_back(vectorizeYield);
1388 
1389  // 4b. Register CustomVectorizationHook for indexOp.
1390  CustomVectorizationHook vectorizeIndex =
1391  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1392  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1393  };
1394  hooks.push_back(vectorizeIndex);
1395 
1396  // 4c. Register CustomVectorizationHook for extractOp.
1397  CustomVectorizationHook vectorizeExtract =
1398  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1399  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1400  };
1401  hooks.push_back(vectorizeExtract);
1402 
1403  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1404  for (Operation &op : block->getOperations()) {
1405  VectorizationResult result =
1406  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1407  if (result.status == VectorizationStatus::Failure) {
1408  LDBG("failed to vectorize: " << op << "\n");
1409  return failure();
1410  }
1411  if (result.status == VectorizationStatus::NewOp) {
1412  Operation *maybeMaskedOp =
1413  state.maskOperation(rewriter, result.newOp, linalgOp);
1414  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1415  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1416  }
1417  }
1418 
1419  return success();
1420 }
1421 
1422 /// Given a tensor::PackOp, return the `dest` shape before any packing
1423 /// permutations.
1424 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1425  ArrayRef<int64_t> destShape) {
1426  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1427 }
1428 
1429 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1430 /// create an empty destination tensor and create a TransferWriteOp from the
1431 /// input to the empty tensor. If the destination shape is not the same as the
1432 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1433 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1434 /// inBounds attribute of the transfer write op instead of masking.
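/// As an illustrative sketch (not from the original source; shapes and value
/// names are hypothetical), with destSizes = (%d0, %d1) and
/// inputVectorSizes = [8, 16] this produces roughly:
///
///   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %vec, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<?x?xf32>
///   } : vector<8x16xi1> -> tensor<?x?xf32>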
1435 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1436  Value input,
1437  SmallVector<OpFoldResult> destSizes,
1438  ArrayRef<int64_t> inputVectorSizes,
1439  bool useInBoundsInsteadOfMasking) {
1440 
1441  auto inputType = cast<VectorType>(input.getType());
1442  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1443  inputType.getElementType());
1444  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1445  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1446  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1447  SmallVector<bool> inBoundsVal(rank, true);
1448  if (useInBoundsInsteadOfMasking) {
1449  // Update the inBounds attribute.
1450  for (unsigned i = 0; i < rank; i++)
1451  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1452  !ShapedType::isDynamic(destShape[i]);
1453  }
1454  Operation *write = builder.create<vector::TransferWriteOp>(
1455  loc,
1456  /*vector=*/input,
1457  /*source=*/dest,
1458  /*indices=*/SmallVector<Value>(rank, zero),
1459  /*inBounds=*/inBoundsVal);
1460  assert(llvm::none_of(
1461  destShape.drop_front(inputVectorSizes.size()),
1462  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1463  "Only dims aligned with inputVectorSizes may be dynamic");
1464  if (useInBoundsInsteadOfMasking)
1465  return write;
1466  bool needMaskForWrite = !llvm::equal(
1467  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1468  if (needMaskForWrite) {
1469  SmallVector<int64_t> writeMaskShape;
1470  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1471  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1472  destShape.end());
1473  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1474  Value maskForWrite =
1475  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1476  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1477  }
1478  return write;
1479 }
1480 
1481 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1482 /// padding value and (3) input vector sizes into:
1483 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1484 /// As in the following example:
1485 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1486 /// into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1487 ///
1488 /// This pack would be vectorized to:
1489 ///
1490 /// %load = vector.mask %mask {
1491 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1492 /// {in_bounds = [true, true, true]} :
1493 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1494 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1495 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1496 /// to vector<32x4x2x1x16xf32>
1497 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1498 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1499 /// %write = vector.transfer_write %transpose,
1500 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1501 /// {in_bounds = [true, true, true, true, true]}
1502 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1503 ///
1504 /// If the (3) input vector sizes are not provided, the vector sizes are
1505 /// determined by the result tensor shape. Also, we update the inBounds
1506 /// attribute instead of masking.
1507 static LogicalResult
1508 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1509  ArrayRef<int64_t> inputVectorSizes,
1510  SmallVectorImpl<Value> &newResults) {
1511  OpBuilder::InsertionGuard g(rewriter);
1512  rewriter.setInsertionPoint(packOp);
1513 
1514  Location loc = packOp.getLoc();
1515  auto padValue = packOp.getPaddingValue();
1516  if (!padValue) {
1517  padValue = rewriter.create<arith::ConstantOp>(
1518  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1519  }
1520  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1521  LogicalResult status =
1522  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1523  .reifyResultShapes(rewriter, reifiedReturnShapes);
1524  (void)status; // prevent unused variable warning on non-assert builds.
1525  assert(succeeded(status) && "failed to reify result shapes");
1526 
1527  // If the input vector sizes are not provided, then the vector sizes are
1528  // determined by the result tensor shape. In case the vector sizes aren't
1529  // provided, we update the inBounds attribute instead of masking.
1530  bool useInBoundsInsteadOfMasking = false;
1531  if (inputVectorSizes.empty()) {
1532  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1533  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1534  useInBoundsInsteadOfMasking = true;
1535  }
1536 
1537  // Create masked TransferReadOp.
1538  SmallVector<int64_t> inputShape(inputVectorSizes);
1539  auto innerTiles = packOp.getStaticInnerTiles();
1540  auto innerDimsPos = packOp.getInnerDimsPos();
1541  auto outerDimsPerm = packOp.getOuterDimsPerm();
1542  if (!outerDimsPerm.empty())
1543  applyPermutationToVector(inputShape,
1544  invertPermutationVector(outerDimsPerm));
1545  for (auto [idx, size] : enumerate(innerTiles))
1546  inputShape[innerDimsPos[idx]] *= size;
1547  auto maskedRead = vector::createReadOrMaskedRead(
1548  rewriter, loc, packOp.getSource(), inputShape, padValue,
1549  useInBoundsInsteadOfMasking);
1550 
1551  // Create ShapeCastOp.
1552  SmallVector<int64_t> destShape(inputVectorSizes);
1553  destShape.append(innerTiles.begin(), innerTiles.end());
1554  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1555  packOp.getDestType().getElementType());
1556  auto shapeCastOp =
1557  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1558 
1559  // Create TransposeOp.
1560  auto destPermutation =
1561  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1562  auto transposeOp = rewriter.create<vector::TransposeOp>(
1563  loc, shapeCastOp.getResult(), destPermutation);
1564 
1565  // Create TransferWriteOp.
1566  Operation *write = createWriteOrMaskedWrite(
1567  rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1568  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1569  newResults.push_back(write->getResult(0));
1570  return success();
1571 }
1572 
1573 /// Vectorize a `tensor::UnPackOp` into these 4 ops:
1574 /// vector::TransferReadOp - Reads a vector from the source tensor.
1575 /// vector::TransposeOp - Transposes the read vector.
1576 /// vector::ShapeCastOp - Reshapes the data based on the target shape.
1577 /// vector::TransferWriteOp - Writes the result vector back to the
1578 /// destination tensor.
1579 /// If the vector sizes are not provided:
1580 /// * the vector sizes are determined by the input operand and attributes,
1581 /// * update the inBounds attribute instead of masking.
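/// As an illustrative sketch (hypothetical shapes, all static, so no masking
/// is required):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
///     into %dst : tensor<2x2x8x8xf32> -> tensor<16x16xf32>
///
/// is vectorized, roughly, to:
///
///   %read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
///     : tensor<2x2x8x8xf32>, vector<2x2x8x8xf32>
///   %transpose = vector.transpose %read, [0, 2, 1, 3]
///     : vector<2x2x8x8xf32> to vector<2x8x2x8xf32>
///   %shape_cast = vector.shape_cast %transpose
///     : vector<2x8x2x8xf32> to vector<16x16xf32>
///   %write = vector.transfer_write %shape_cast, %empty[%c0, %c0]
///     : vector<16x16xf32>, tensor<16x16xf32>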
1582 static LogicalResult
1583 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1584  ArrayRef<int64_t> inputVectorSizes,
1585  SmallVectorImpl<Value> &newResults) {
1586 
1587  OpBuilder::InsertionGuard g(rewriter);
1588  rewriter.setInsertionPoint(unpackOp);
1589 
1590  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1591 
1592  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1593  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1594  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1595  bool useInBoundsInsteadOfMasking = false;
1596  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1597 
1598  auto destSize = unpackOp.getDestRank();
1599 
1600  if (!inputVectorSizes.empty())
1601  assert(inputVectorSizes.size() == destSize &&
1602  "Incorrect number of input vector sizes");
1603 
1604  // vectorSizes is the shape of the vector that will be used to do the final
1605  // write on the destination tensor. It is set like this: let's say the
1606  // source tensor has rank 'M' and the dest tensor rank 'N', where N <= M.
1607  // Thus:
1608  // 1. vectorSizes = sourceShape.take_front(N)
1609  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1610  // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
1611  // the corresponding innerTiles attribute value.
1612  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1613  if (vectorSizes.empty()) {
1614  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1615  if (!outerDimsPerm.empty())
1616  applyPermutationToVector(vectorSizes, outerDimsPerm);
1617  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1618  vectorSizes[pos] *= innerTiles[i];
1619 
1620  useInBoundsInsteadOfMasking = true;
1621  }
1622 
1623  // readVectorSizes is the shape of the vector used to read and mask the
1624  // source tensor. It is set like this: let's say the vectorSizes (VS) array
1625  // has size 'N', the sourceShape (SS) has rank 'M' where M >= N, and the
1626  // InnerTileSizes (IT) have size M-N.
1627  // Thus:
1628  // - initially: readVectorSizes = vectorSizes
1629  // - Divide all the readVectorSizes locations pointed to by innerDimPos
1630  // by the corresponding innerTileSize attribute value.
1631  // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1632  // - Append the remaining shape from SS.
1633  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1634  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1635  // 128] and outer_dims_perm is [1, 0] then read shape is:
1636  // ReadVectorSizes(initial): [512, 128]
1637  // Final Value(after innerDim Adjustment): [512/32, 128/16]
1638  // = [16, 8]
1639  // After applying outer_dims_perm: [8, 16]
1640  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1641 
1642  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1643 
1644  for (auto [index, size] : enumerate(innerTiles)) {
1645  readVectorSizes[innerDimPos[index]] =
1646  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1647  }
1648  if (!outerDimsPerm.empty()) {
1649  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1650  }
1651  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1652  sourceShape.end());
1653 
1654  ReifiedRankedShapedTypeDims reifiedRetShapes;
1655  LogicalResult status =
1656  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1657  .reifyResultShapes(rewriter, reifiedRetShapes);
1658  if (status.failed()) {
1659  LDBG("Unable to reify result shapes of " << unpackOp);
1660  return failure();
1661  }
1662  Location loc = unpackOp->getLoc();
1663 
1664  auto padValue = rewriter.create<arith::ConstantOp>(
1665  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1666 
1667  // Read result, mask if necessary. If transferReadOp shape is not equal
1668  // to shape of source, then a mask is necessary.
1669  Value readResult = vector::createReadOrMaskedRead(
1670  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1671  /*useInBoundsInsteadOfMasking=*/false);
1672 
1673  PackingMetadata packMetadata;
1674  SmallVector<int64_t> lastDimToInsertPosPerm =
1675  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1676  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1677  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1678  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1679  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1680  RankedTensorType stripMineTensorType =
1681  RankedTensorType::get(stripMineShape, stripMineElemType);
1682  // Transpose the appropriate rows to match output.
1683  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1684  loc, readResult, lastDimToInsertPosPerm);
1685 
1686  // Collapse the vector to the size required by result.
1687  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1688  stripMineTensorType, packMetadata.reassociations);
1689  mlir::VectorType vecCollapsedType =
1690  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1691  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1692  loc, vecCollapsedType, transposeOp->getResult(0));
1693 
1694  // writeVectorSizes must match the shape_cast result shape for dynamic sizes;
1695  // otherwise the verifier complains that the mask size is invalid.
1696  SmallVector<int64_t> writeVectorSizes(
1697  unpackOp.getDestType().hasStaticShape()
1698  ? vectorSizes
1699  : shapeCastOp.getResultVectorType().getShape());
1700  Operation *write = createWriteOrMaskedWrite(
1701  rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1702  writeVectorSizes, useInBoundsInsteadOfMasking);
1703  newResults.push_back(write->getResult(0));
1704  return success();
1705 }
1706 
1707 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1708 /// and (3) all-zero lowPad to
1709 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
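/// As an illustrative sketch (hypothetical shapes; input vector sizes assumed
/// to be [8, 8]):
///
///   %pad = tensor.pad %src low[0, 0] high[2, 3] ...
///     : tensor<6x5xf32> to tensor<8x8xf32>
///
/// becomes, roughly:
///
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<6x5xf32>, vector<8x8xf32>
///   } : vector<8x8xi1> -> vector<8x8xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///     {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>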
1710 static LogicalResult
1711 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1712  ArrayRef<int64_t> inputVectorSizes,
1713  SmallVectorImpl<Value> &newResults) {
1714  auto padValue = padOp.getConstantPaddingValue();
1715  Location loc = padOp.getLoc();
1716 
1717  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1718  OpBuilder::InsertionGuard g(rewriter);
1719  rewriter.setInsertionPoint(padOp);
1720 
1721  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1722  LogicalResult status =
1723  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1724  .reifyResultShapes(rewriter, reifiedReturnShapes);
1725  (void)status; // prevent unused variable warning on non-assert builds
1726  assert(succeeded(status) && "failed to reify result shapes");
1727  auto maskedRead = vector::createReadOrMaskedRead(
1728  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1729  /*useInBoundsInsteadOfMasking=*/false);
1730  Operation *write = createWriteOrMaskedWrite(
1731  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1732  /*useInBoundsInsteadOfMasking=*/false);
1733  newResults.push_back(write->getResult(0));
1734  return success();
1735 }
1736 
1737 // TODO: probably need some extra checks for reduction followed by consumer
1738 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1739 static LogicalResult reductionPreconditions(LinalgOp op) {
1740  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1741  LDBG("reduction precondition failed: no reduction iterator\n");
1742  return failure();
1743  }
1744  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1745  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1746  if (indexingMap.isPermutation())
1747  continue;
1748 
1749  Operation *reduceOp = matchLinalgReduction(&opOperand);
1750  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1751  LDBG("reduction precondition failed: reduction detection failed\n");
1752  return failure();
1753  }
1754  }
1755  return success();
1756 }
1757 
1758 static LogicalResult
1759 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1760  bool flatten1DDepthwiseConv) {
1761  if (flatten1DDepthwiseConv) {
1762  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1763  "supported\n");
1764  return failure();
1765  }
1766 
1767  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1768  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1769  return failure();
1770  }
1771 
1772  // Support dynamic shapes in 1D depthwise convolution, but only in the
1773  // _channel_ dimension.
1774  Value lhs = conv.getDpsInputOperand(0)->get();
1775  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1776  auto shapeWithoutCh = lhsShape.drop_back(1);
1777  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1778  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1779  "channel dim can be dynamic\n");
1780  return failure();
1781  }
1782 
1783  return success();
1784 }
1785 
1786 static LogicalResult
1787 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1788  bool flatten1DDepthwiseConv) {
1789  if (isa<ConvolutionOpInterface>(op.getOperation()))
1790  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1791 
1792  if (hasReductionIterator(op))
1793  return reductionPreconditions(op);
1794 
1795  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1796  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1797  if (!isElementwise(op) &&
1798  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1799  op.getOperation()))
1800  return failure();
1801 
1802  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1803  return success();
1804 }
1805 
1806 /// Need to check if the inner-tiles are static/constant.
1807 static LogicalResult
1808 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1809  ArrayRef<int64_t> inputVectorSizes) {
1810 
1811  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1812  return !getConstantIntValue(res).has_value();
1813  })) {
1814  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1815  return failure();
1816  }
1817  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1818  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1819  unpackOp.getDestType().hasStaticShape() &&
1820  unpackOp.getSourceType().hasStaticShape();
1821  if (!satisfyEmptyCond &&
1822  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1823  return failure();
1824 
1825  return success();
1826 }
1827 
1828 static LogicalResult vectorizeLinalgOpPrecondition(
1829  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1830  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1831  // tensor with dimension of 0 cannot be vectorized.
1832  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1833  return failure();
1834  // Check API contract for input vector sizes.
1835  if (!inputVectorSizes.empty() &&
1836  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1837  inputVectorSizes)))
1838  return failure();
1839 
1840  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1841  linalgOp, flatten1DDepthwiseConv))) {
1842  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1843  return failure();
1844  }
1845 
1846  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1847 
1848  // Register CustomVectorizationPrecondition for extractOp.
1849  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1850 
1851  // All types in the body should be a supported element type for VectorType.
1852  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1853  // Check if any custom hook can vectorize the inner op.
1854  if (llvm::any_of(
1855  customPreconditions,
1856  [&](const CustomVectorizationPrecondition &customPrecondition) {
1857  return succeeded(
1858  customPrecondition(&innerOp, vectorizeNDExtract));
1859  })) {
1860  continue;
1861  }
1862  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1863  return !VectorType::isValidElementType(type);
1864  })) {
1865  return failure();
1866  }
1867  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1868  return !VectorType::isValidElementType(type);
1869  })) {
1870  return failure();
1871  }
1872  }
1873  if (isElementwise(linalgOp))
1874  return success();
1875 
1876  // TODO: isaConvolutionOpInterface that can also infer from generic
1877  // features. But we will still need stride/dilation attributes that will be
1878  // annoying to reverse-engineer...
1879  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1880  return success();
1881  // TODO: the common vector shape is equal to the static loop sizes only when
1882  // all indexing maps are projected permutations. For convs and stencils the
1883  // logic will need to evolve.
1884  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1885  LDBG("precondition failed: not projected permutations\n");
1886  return failure();
1887  }
1888  if (failed(reductionPreconditions(linalgOp))) {
1889  LDBG("precondition failed: reduction preconditions\n");
1890  return failure();
1891  }
1892  return success();
1893 }
1894 
1895 static LogicalResult
1896 vectorizePackOpPrecondition(tensor::PackOp packOp,
1897  ArrayRef<int64_t> inputVectorSizes) {
1898  auto padValue = packOp.getPaddingValue();
1899  Attribute cstAttr;
1900  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1901  LDBG("pad value is not constant: " << packOp << "\n");
1902  return failure();
1903  }
1904  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1905  bool satisfyEmptyCond = true;
1906  if (inputVectorSizes.empty()) {
1907  if (!packOp.getDestType().hasStaticShape() ||
1908  !packOp.getSourceType().hasStaticShape())
1909  satisfyEmptyCond = false;
1910  }
1911 
1912  if (!satisfyEmptyCond &&
1913  failed(vector::isValidMaskedInputVector(
1914  resultTensorShape.take_front(packOp.getSourceRank()),
1915  inputVectorSizes)))
1916  return failure();
1917 
1918  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1919  return !getConstantIntValue(v).has_value();
1920  })) {
1921  LDBG("inner_tiles must be constant: " << packOp << "\n");
1922  return failure();
1923  }
1924 
1925  return success();
1926 }
1927 
1928 static LogicalResult
1929 vectorizePadOpPrecondition(tensor::PadOp padOp,
1930  ArrayRef<int64_t> inputVectorSizes) {
1931  auto padValue = padOp.getConstantPaddingValue();
1932  if (!padValue) {
1933  LDBG("pad value is not constant: " << padOp << "\n");
1934  return failure();
1935  }
1936 
1937  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1938  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1939  inputVectorSizes)))
1940  return failure();
1941 
1942  if (llvm::any_of(padOp.getLow(), [](Value v) {
1943  std::optional<int64_t> res = getConstantIntValue(v);
1944  return !res.has_value() || res.value() != 0;
1945  })) {
1946  LDBG("low pad must all be zero: " << padOp << "\n");
1947  return failure();
1948  }
1949 
1950  return success();
1951 }
1952 
1953 /// Preconditions for scalable vectors. This is quite restrictive - it models
1954 /// the fact that in practice we would only make selected dimensions scalable.
1955 static LogicalResult
1956 vectorizeScalableVectorPrecondition(Operation *op,
1957  ArrayRef<int64_t> inputVectorSizes,
1958  ArrayRef<bool> inputScalableVecDims) {
1959  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1960  "Number of input vector sizes and scalable dims doesn't match");
1961 
1962  size_t numOfScalableDims =
1963  llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
1964 
1965  if (numOfScalableDims == 0)
1966  return success();
1967 
1968  auto linalgOp = dyn_cast<LinalgOp>(op);
1969 
1970  // Cond 1: There's been no need for scalable vectorisation of
1971  // non-linalg Ops so far
1972  if (!linalgOp)
1973  return failure();
1974 
1975  // Cond 2: There's been no need for more than 2 scalable dims so far
1976  if (numOfScalableDims > 2)
1977  return failure();
1978 
1979  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
1980  // it matches one of the supported cases:
1981  // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
1982  // 2. exactly 2 dims are scalable and those are the _last two adjacent_
1983  // parallel dims
1984  // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
1985  // The 2nd restriction above means that only Matmul-like Ops are supported
1986  // when 2 dims are scalable, e.g. :
1987  // * iterators = [parallel, parallel, reduction]
1988  // * scalable flags = [true, true, false]
1989 
1990  // Find the first scalable flag
1991  bool seenParallel = false;
1992  auto iterators = linalgOp.getIteratorTypesArray();
1993  SmallVector<bool> scalableFlags(inputScalableVecDims);
1994  while (!scalableFlags.back()) {
1995  seenParallel |= (iterators.back() == utils::IteratorType::parallel);
1996 
1997  iterators.pop_back();
1998  scalableFlags.pop_back();
1999  }
2000 
2001  switch (iterators.back()) {
2002  case utils::IteratorType::reduction: {
2003  // Check 3. above is met.
2004  if (iterators.size() != inputVectorSizes.size()) {
2005  LDBG("Non-trailing reduction dim requested for scalable "
2006  "vectorization\n");
2007  return failure();
2008  }
2009  if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2010  LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2011  "is not supported\n");
2012  return failure();
2013  }
2014  break;
2015  }
2016  case utils::IteratorType::parallel: {
2017  // Check 1. and 2. above are met.
2018  if (seenParallel) {
2019  LDBG("Inner parallel dim not requested for scalable "
2020  "vectorization\n");
2021  return failure();
2022  }
2023  break;
2024  }
2025  }
2026 
2027  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2028  // supported, for which we expect the following config:
2029  // * iterators = [parallel, parallel, reduction]
2030  // * scalable flags = [true, true, false]
2031  if (numOfScalableDims == 2) {
2032  // Disallow below case which breaks 3. above:
2033  // * iterators = [..., parallel, reduction]
2034  // * scalable flags = [..., true, true]
2035  if (iterators.back() == utils::IteratorType::reduction) {
2036  LDBG("Higher dim than the trailing reduction dim requested for scalable "
2037  "vectorization\n");
2038  return failure();
2039  }
2040  scalableFlags.pop_back();
2041  iterators.pop_back();
2042 
2043  if (!scalableFlags.back() ||
2044  (iterators.back() != utils::IteratorType::parallel))
2045  return failure();
2046  }
2047 
2048  // Cond 4: Only the following ops are supported in the
2049  // presence of scalable vectors
2050  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2051  isa<linalg::MatmulTransposeAOp>(op) ||
2052  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2053  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2054 }
2055 
2056 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2057  Operation *op, ArrayRef<int64_t> inputVectorSizes,
2058  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2059  bool flatten1DDepthwiseConv) {
2060  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2061  inputScalableVecDims)))
2062  return failure();
2063 
2064  return TypeSwitch<Operation *, LogicalResult>(op)
2065  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2066  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2067  vectorizeNDExtract,
2068  flatten1DDepthwiseConv);
2069  })
2070  .Case<tensor::PadOp>([&](auto padOp) {
2071  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2072  })
2073  .Case<tensor::PackOp>([&](auto packOp) {
2074  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2075  })
2076  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2077  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2078  })
2079  .Default([](auto) { return failure(); });
2080 }
2081 
2082 /// Converts affine.apply Ops to arithmetic operations.
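/// For example (illustrative only):
///   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
/// is expanded to:
///   %0 = arith.addi %i, %j : index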
2083 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2084  OpBuilder::InsertionGuard g(rewriter);
2085  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2086 
2087  for (auto op : make_early_inc_range(toReplace)) {
2088  rewriter.setInsertionPoint(op);
2089  auto expanded = affine::expandAffineExpr(
2090  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2091  op.getOperands().take_front(op.getAffineMap().getNumDims()),
2092  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2093  rewriter.replaceOp(op, expanded);
2094  }
2095 }
2096 
2097 /// Emit a suitable vector form for an operation. If provided,
2098 /// `inputVectorSizes` are used to vectorize this operation.
2099 /// `inputVectorSizes` must match the rank of the iteration space of the
2100 /// operation and the input vector sizes must be greater than or equal to
2101 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2102 /// also allows the vectorization of operations with dynamic shapes.
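/// As an illustrative sketch (hypothetical IR, static shapes, no masking), an
/// elementwise op such as:
///
///   %0 = linalg.add ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///                   outs(%init : tensor<8xf32>) -> tensor<8xf32>
///
/// is vectorized, roughly, to:
///
///   %va = vector.transfer_read %a[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %vb = vector.transfer_read %b[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %vadd = arith.addf %va, %vb : vector<8xf32>
///   %0 = vector.transfer_write %vadd, %init[%c0]
///     : vector<8xf32>, tensor<8xf32>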
2103 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2104  ArrayRef<int64_t> inputVectorSizes,
2105  ArrayRef<bool> inputScalableVecDims,
2106  bool vectorizeNDExtract,
2107  bool flatten1DDepthwiseConv) {
2108  LDBG("Attempting to vectorize:\n" << *op << "\n");
2109  LDBG("Input vector sizes: ");
2110  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2111  LLVM_DEBUG(llvm::dbgs() << "\n");
2112  LDBG("Input scalable vector dims: ");
2113  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2114  LLVM_DEBUG(llvm::dbgs() << "\n");
2115 
2116  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2117  vectorizeNDExtract,
2118  flatten1DDepthwiseConv))) {
2119  LDBG("Vectorization pre-conditions failed\n");
2120  return failure();
2121  }
2122 
2123  // Initialize vectorization state.
2124  VectorizationState state(rewriter);
2125  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2126  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2127  inputScalableVecDims))) {
2128  LDBG("Vectorization state couldn't be initialized\n");
2129  return failure();
2130  }
2131  }
2132 
2133  SmallVector<Value> results;
2134  auto vectorizeResult =
2135  TypeSwitch<Operation *, LogicalResult>(op)
2136  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2137  // TODO: isaConvolutionOpInterface that can also infer from
2138  // generic features. Will require stride/dilation attributes
2139  // inference.
2140  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2141  FailureOr<Operation *> convOr = vectorizeConvolution(
2142  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2143  flatten1DDepthwiseConv);
2144  if (succeeded(convOr)) {
2145  llvm::append_range(results, (*convOr)->getResults());
2146  return success();
2147  }
2148 
2149  LDBG("Unsupported convolution can't be vectorized.\n");
2150  return failure();
2151  }
2152 
2153  LDBG("Vectorize generic by broadcasting to the canonical vector "
2154  "shape\n");
2155 
2156  // Pre-process before proceeding.
2157  convertAffineApply(rewriter, linalgOp);
2158 
2159  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2160  // to 'OpBuilder' when it is passed over to some methods like
2161  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2162  // erase an op within these methods, the actual rewriter won't be
2163  // notified and we will end up with read-after-free issues!
2164  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2165  })
2166  .Case<tensor::PadOp>([&](auto padOp) {
2167  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2168  results);
2169  })
2170  .Case<tensor::PackOp>([&](auto packOp) {
2171  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2172  results);
2173  })
2174  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2175  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2176  inputVectorSizes, results);
2177  })
2178  .Default([](auto) { return failure(); });
2179 
2180  if (failed(vectorizeResult)) {
2181  LDBG("Vectorization failed\n");
2182  return failure();
2183  }
2184 
2185  if (!results.empty())
2186  rewriter.replaceOp(op, results);
2187  else
2188  rewriter.eraseOp(op);
2189 
2190  return success();
2191 }
2192 
2193 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2194  memref::CopyOp copyOp) {
2195  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2196  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2197  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2198  return failure();
2199 
2200  auto srcElementType = getElementTypeOrSelf(srcType);
2201  auto dstElementType = getElementTypeOrSelf(dstType);
2202  if (!VectorType::isValidElementType(srcElementType) ||
2203  !VectorType::isValidElementType(dstElementType))
2204  return failure();
2205 
2206  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2207  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2208 
2209  Location loc = copyOp->getLoc();
2210  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2211  SmallVector<Value> indices(srcType.getRank(), zero);
2212 
2213  Value readValue = rewriter.create<vector::TransferReadOp>(
2214  loc, readType, copyOp.getSource(), indices,
2215  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2216  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2217  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2218  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2219  }
2220  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2221  loc, readValue, copyOp.getTarget(), indices,
2222  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2223  rewriter.replaceOp(copyOp, writeValue->getResults());
2224  return success();
2225 }
2226 
2227 //----------------------------------------------------------------------------//
2228 // Misc. vectorization patterns.
2229 //----------------------------------------------------------------------------//
2230 
2231 /// Helper function that retrieves the value of an IntegerAttr.
2232 static int64_t getIntFromAttr(Attribute attr) {
2233  return cast<IntegerAttr>(attr).getInt();
2234 }
2235 
2236 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2237 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2238 /// not supported.
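/// For example (illustrative): for ofrs = {%v, IntegerAttr(4)} the result is
/// {%v, %c4}, where %c4 = arith.constant 4 : index.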
2239 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2240  ArrayRef<OpFoldResult> ofrs) {
2241  SmallVector<Value> result;
2242  for (auto o : ofrs) {
2243  if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2244  result.push_back(val);
2245  } else {
2246  result.push_back(rewriter.create<arith::ConstantIndexOp>(
2247  loc, getIntFromAttr(o.template get<Attribute>())));
2248  }
2249  }
2250  return result;
2251 }
2252 
2253 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2254 /// InsertSliceOp. For now, only constant padding values are supported.
2255 /// If there is enough static type information, TransferReadOps and
2256 /// TransferWriteOps may be generated instead of InsertSliceOps.
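/// As an illustrative sketch (hypothetical shapes; %dest is the
/// fill-initialized tensor produced by the generalization), the source copy of
///
///   %pad = tensor.pad %src low[0, 0] high[1, 2] ...
///     : tensor<7x6xf32> to tensor<8x8xf32>
///
/// is vectorized, roughly, as:
///
///   %read = vector.transfer_read %src[%c0, %c0], %cst
///     {in_bounds = [true, true]} : tensor<7x6xf32>, vector<7x6xf32>
///   %write = vector.transfer_write %read, %dest[%c0, %c0]
///     {in_bounds = [true, true]} : vector<7x6xf32>, tensor<8x8xf32>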
2257 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2258  GenericPadOpVectorizationPattern(MLIRContext *context,
2259  PatternBenefit benefit = 1)
2260  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2261  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2262  /// each dimension size is statically known in the source type or the result
2263  /// type (or both).
2264  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2265  tensor::PadOp padOp, Value dest) {
2266  auto sourceType = padOp.getSourceType();
2267  auto resultType = padOp.getResultType();
2268  if (!VectorType::isValidElementType(sourceType.getElementType()))
2269  return failure();
2270 
2271  // Copy cannot be vectorized if pad value is non-constant and source shape
2272  // is dynamic. In case of a dynamic source shape, padding must be appended
2273  // by TransferReadOp, but TransferReadOp supports only constant padding.
2274  auto padValue = padOp.getConstantPaddingValue();
2275  if (!padValue) {
2276  if (!sourceType.hasStaticShape())
2277  return failure();
2278  // Create dummy padding value.
2279  auto elemType = sourceType.getElementType();
2280  padValue = rewriter.create<arith::ConstantOp>(
2281  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2282  }
2283 
2284  SmallVector<int64_t> vecShape;
2285  SmallVector<bool> readInBounds;
2286  SmallVector<bool> writeInBounds;
2287  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2288  if (!sourceType.isDynamicDim(i)) {
2289  vecShape.push_back(sourceType.getDimSize(i));
2290  // Source shape is statically known: neither read nor write is
2291  // out-of-bounds.
2292  readInBounds.push_back(true);
2293  writeInBounds.push_back(true);
2294  } else if (!resultType.isDynamicDim(i)) {
2295  // Source shape is not statically known, but result shape is.
2296  // Vectorize with size of result shape. This may be larger than the
2297  // source size.
2298  vecShape.push_back(resultType.getDimSize(i));
2299  // Read may be out-of-bounds because the result size could be larger
2300  // than the source size.
2301  readInBounds.push_back(false);
2302  // Write is out-of-bounds if low padding > 0.
2303  writeInBounds.push_back(
2304  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2305  static_cast<int64_t>(0));
2306  } else {
2307  // Neither source nor result dim of padOp is static. Cannot vectorize
2308  // the copy.
2309  return failure();
2310  }
2311  }
2312  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2313 
2314  // Generate TransferReadOp.
2315  SmallVector<Value> readIndices(
2316  vecType.getRank(),
2317  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2318  auto read = rewriter.create<vector::TransferReadOp>(
2319  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2320  ArrayRef<bool>{readInBounds});
2321 
2322  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2323  // entire tensor, write directly to the FillOp's operand.
2324  if (llvm::equal(vecShape, resultType.getShape()) &&
2325  llvm::all_of(writeInBounds, [](bool b) { return b; }))
2326  if (auto fill = dest.getDefiningOp<FillOp>())
2327  dest = fill.output();
2328 
2329  // Generate TransferWriteOp.
2330  auto writeIndices =
2331  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2332  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2333  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2334 
2335  return success();
2336  }
2337 };
2338 
2339 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2340 /// given operation type OpTy.
2341 template <typename OpTy>
2342 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2343  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2344 
2345  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2346  PatternRewriter &rewriter) const final {
2347  bool changed = false;
2348  // Insert users in vector, because some users may be replaced/removed.
2349  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2350  if (auto op = dyn_cast<OpTy>(user))
2351  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2352  return success(changed);
2353  }
2354 
2355 protected:
2356  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2357  tensor::PadOp padOp, OpTy op) const = 0;
2358 };
2359 
2360 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2361 /// ```
2362 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2363 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2364 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2365 /// ```
2366 /// is rewritten to:
2367 /// ```
2368 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2369 /// {in_bounds = [true, true]}
2370 /// : tensor<?x?xf32>, vector<17x5xf32>
2371 /// ```
2372 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2373 /// sure that the original padding value %cst was never used.
2374 ///
2375 /// This rewrite is possible if:
2376 /// - `xferOp` has no out-of-bounds dims or mask.
2377 /// - Low padding is static 0.
2378 /// - Single, scalar padding value.
2379 struct PadOpVectorizationWithTransferReadPattern
2380  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2381  using VectorizePadOpUserPattern<
2382  vector::TransferReadOp>::VectorizePadOpUserPattern;
2383 
2384  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2385  vector::TransferReadOp xferOp) const override {
2386  // Low padding must be static 0.
2387  if (!padOp.hasZeroLowPad())
2388  return failure();
2389  // Pad value must be a constant.
2390  auto padValue = padOp.getConstantPaddingValue();
2391  if (!padValue)
2392  return failure();
2393  // Padding value of existing `xferOp` is unused.
2394  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2395  return failure();
2396 
2397  rewriter.modifyOpInPlace(xferOp, [&]() {
2398  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2399  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2400  rewriter.getBoolArrayAttr(inBounds));
2401  xferOp.getSourceMutable().assign(padOp.getSource());
2402  xferOp.getPaddingMutable().assign(padValue);
2403  });
2404 
2405  return success();
2406  }
2407 };
2408 
2409 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2410 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2411 /// value, where the same amount of padding is immediately removed again after
2412 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2413 /// tensor value and apply out-of-bounds masking. E.g.:
2414 /// ```
2415 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2416 /// : tensor<...> to tensor<?x?xf32>
2417 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2418 /// %2 = vector.transfer_write %vec, %1[...]
2419 /// : vector<17x5xf32>, tensor<17x5xf32>
2420 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2421 /// : tensor<17x5xf32> to tensor<?x?xf32>
2422 /// ```
2423 /// is rewritten to:
2424 /// ```
2425 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2426 /// : tensor<...> to tensor<?x?xf32>
2427 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2428 /// tensor<?x?xf32>
2429 /// ```
2430 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2431 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2432 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2433 /// from %r's old dimensions.
2434 ///
2435 /// This rewrite is possible if:
2436 /// - Low padding is static 0.
2437 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2438 /// ExtractSliceOp trims the same amount of padding that was added
2439 /// beforehand.
2440 /// - Single, scalar padding value.
2441 struct PadOpVectorizationWithTransferWritePattern
2442  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2443  using VectorizePadOpUserPattern<
2444  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2445 
2446  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2447  vector::TransferWriteOp xferOp) const override {
2448  // TODO: support 0-d corner case.
2449  if (xferOp.getTransferRank() == 0)
2450  return failure();
2451 
2452  // Low padding must be static 0.
2453  if (!padOp.hasZeroLowPad())
2454  return failure();
2455  // Pad value must be a constant.
2456  auto padValue = padOp.getConstantPaddingValue();
2457  if (!padValue)
2458  return failure();
2459  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2460  if (!xferOp->hasOneUse())
2461  return failure();
2462  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2463  if (!trimPadding)
2464  return failure();
2465  // Only static zero offsets supported when trimming padding.
2466  if (!trimPadding.hasZeroOffset())
2467  return failure();
2468  // trimPadding must remove the amount of padding that was added earlier.
2469  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2470  return failure();
2471 
2472  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2473  rewriter.setInsertionPoint(xferOp);
2474 
2475  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2476  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2477  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2478  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2479  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2480  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2481 
2482  return success();
2483  }
2484 
2485  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2486  /// i.e., same dimensions.
2487  ///
2488  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2489  /// dimensions, this function tries to infer the (static) tensor size by
2490  /// looking at the defining op and utilizing op-specific knowledge.
2491  ///
2492  /// This is a conservative analysis. In case equal tensor sizes cannot be
2493  /// proven statically, this analysis returns `false` even though the tensor
2494  /// sizes may turn out to be equal at runtime.
2495  bool hasSameTensorSize(Value beforePadding,
2496  tensor::ExtractSliceOp afterTrimming) const {
2497  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2498  // result and CastOp operand.
2499  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2500  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2501  return true;
2502 
2503  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2504  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2505  // Only RankedTensorType supported.
2506  if (!t1 || !t2)
2507  return false;
2508  // Rank of both values must be the same.
2509  if (t1.getRank() != t2.getRank())
2510  return false;
2511 
2512  // All static dimensions must be the same. Mixed cases (e.g., dimension
2513  // static in `t1` but dynamic in `t2`) are not supported.
2514  for (unsigned i = 0; i < t1.getRank(); ++i) {
2515  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2516  return false;
2517  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2518  return false;
2519  }
2520 
2521  // Nothing more to check if all dimensions are static.
2522  if (t1.getNumDynamicDims() == 0)
2523  return true;
2524 
2525  // All dynamic sizes must be the same. The only supported case at the
2526  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2527  // thereof).
2528 
2529  // Apart from CastOp, only ExtractSliceOp is supported.
2530  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2531  if (!beforeSlice)
2532  return false;
2533 
2534  assert(static_cast<size_t>(t1.getRank()) ==
2535  beforeSlice.getMixedSizes().size());
2536  assert(static_cast<size_t>(t2.getRank()) ==
2537  afterTrimming.getMixedSizes().size());
2538 
2539  for (unsigned i = 0; i < t1.getRank(); ++i) {
2540  // Skip static dimensions.
2541  if (!t1.isDynamicDim(i))
2542  continue;
2543  auto size1 = beforeSlice.getMixedSizes()[i];
2544  auto size2 = afterTrimming.getMixedSizes()[i];
2545 
2546  // Case 1: Same value or same constant int.
2547  if (isEqualConstantIntOrValue(size1, size2))
2548  continue;
2549 
2550  // Other cases: Take a deeper look at defining ops of values.
2551  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2552  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2553  if (!v1 || !v2)
2554  return false;
2555 
2556  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2557  // CSE is run.)
2558  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2559  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2560  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2561  minOp1.getOperands() == minOp2.getOperands())
2562  continue;
2563 
2564  // Add additional cases as needed.
2565  }
2566 
2567  // All tests passed.
2568  return true;
2569  }
2570 };
2571 
2572 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2573 /// ```
2574 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2575 /// %r = tensor.insert_slice %0
2576 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2577 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2578 /// ```
2579 /// is rewritten to:
2580 /// ```
2581 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2582 /// : tensor<?x?xf32>, vector<17x5xf32>
2583 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2584 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2585 /// ```
2586 ///
2587 /// This rewrite is possible if:
2588 /// - Low padding is static 0.
2589 /// - `padOp` result shape is static.
2590 /// - The entire padded tensor is inserted.
2591 /// (Implies that sizes of `insertOp` are all static.)
2592 /// - Only unit strides in `insertOp`.
2593 /// - Single, scalar padding value.
2594 /// - `padOp` result not used as destination.
2595 struct PadOpVectorizationWithInsertSlicePattern
2596  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2597  using VectorizePadOpUserPattern<
2598  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2599 
2600  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2601  tensor::InsertSliceOp insertOp) const override {
2602  // Low padding must be static 0.
2603  if (!padOp.hasZeroLowPad())
2604  return failure();
2605  // Only unit stride supported.
2606  if (!insertOp.hasUnitStride())
2607  return failure();
2608  // Pad value must be a constant.
2609  auto padValue = padOp.getConstantPaddingValue();
2610  if (!padValue)
2611  return failure();
2612  // Dynamic shapes not supported.
2613  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2614  return failure();
2615  // Pad result not used as destination.
2616  if (insertOp.getDest() == padOp.getResult())
2617  return failure();
2618 
2619  auto vecType = VectorType::get(padOp.getType().getShape(),
2620  padOp.getType().getElementType());
2621  unsigned vecRank = vecType.getRank();
2622  unsigned tensorRank = insertOp.getType().getRank();
2623 
2624  // Check if sizes match: Insert the entire tensor into most minor dims.
2625  // (No permutations allowed.)
2626  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2627  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2628  if (!llvm::all_of(
2629  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2630  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2631  }))
2632  return failure();
2633 
2634  // Insert the TransferReadOp and TransferWriteOp at the position of the
2635  // InsertSliceOp.
2636  rewriter.setInsertionPoint(insertOp);
2637 
2638  // Generate TransferReadOp: Read entire source tensor and add high
2639  // padding.
2640  SmallVector<Value> readIndices(
2641  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2642  auto read = rewriter.create<vector::TransferReadOp>(
2643  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2644 
2645  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2646  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2647  // source must fit into the destination at the specified offsets.
2648  auto writeIndices =
2649  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2650  SmallVector<bool> inBounds(vecRank, true);
2651  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2652  insertOp, read, insertOp.getDest(), writeIndices,
2653  ArrayRef<bool>{inBounds});
2654 
2655  return success();
2656  }
2657 };
2658 
2659 void mlir::linalg::populatePadOpVectorizationPatterns(
2660  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2661  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2662  baseBenefit);
2663  // Try these specialized patterns first before resorting to the generic one.
2664  patterns.add<PadOpVectorizationWithTransferReadPattern,
2665  PadOpVectorizationWithTransferWritePattern,
2666  PadOpVectorizationWithInsertSlicePattern>(
2667  patterns.getContext(), baseBenefit.getBenefit() + 1);
2668 }
2669 
2670 //----------------------------------------------------------------------------//
2671 // Forwarding patterns
2672 //----------------------------------------------------------------------------//
2673 
2674 /// Check whether there is any interleaved use of any `values` between
2675 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2676 /// is in a different block.
2677 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2678  ValueRange values) {
2679  if (firstOp->getBlock() != secondOp->getBlock() ||
2680  !firstOp->isBeforeInBlock(secondOp)) {
2681  LDBG("interleavedUses precondition failed, firstOp: "
2682  << *firstOp << ", second op: " << *secondOp << "\n");
2683  return true;
2684  }
2685  for (auto v : values) {
2686  for (auto &u : v.getUses()) {
2687  Operation *owner = u.getOwner();
2688  if (owner == firstOp || owner == secondOp)
2689  continue;
2690  // TODO: this is too conservative, use dominance info in the future.
2691  if (owner->getBlock() == firstOp->getBlock() &&
2692  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2693  continue;
2694  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2695  << ", second op: " << *secondOp << "\n");
2696  return true;
2697  }
2698  }
2699  return false;
2700 }
2701 
2702 /// Return the unique subview use of `v` if it is indeed unique, null
2703 /// otherwise.
2704 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2705  memref::SubViewOp subViewOp;
2706  for (auto &u : v.getUses()) {
2707  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2708  if (subViewOp)
2709  return memref::SubViewOp();
2710  subViewOp = newSubViewOp;
2711  }
2712  }
2713  return subViewOp;
2714 }
2715 
2716 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2717 /// when available.
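/// As an illustrative sketch of the IR being matched (hypothetical; types and
/// offsets elided with "..."):
///
///   %sv = memref.subview %alloc [...] : memref<8x8xf32> to ...
///   linalg.fill ins(%pad : f32) outs(%alloc : memref<8x8xf32>)
///   memref.copy %in, %sv
///   %v = vector.transfer_read %alloc[%c0, %c0], %pad
///     : memref<8x8xf32>, vector<8x8xf32>
///
/// The fill and copy are erased and the read is forwarded to %in directly,
/// with all dims conservatively marked potentially out-of-bounds:
///
///   %v = vector.transfer_read %in[%c0, %c0], %pad
///     {in_bounds = [false, false]} : ..., vector<8x8xf32>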
2718 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2719  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2720 
2721  // TODO: support mask.
2722  if (xferOp.getMask())
2723  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2724 
2725  // Transfer into `view`.
2726  Value viewOrAlloc = xferOp.getSource();
2727  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2728  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2729  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2730 
2731  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2732  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2733  if (!subViewOp)
2734  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2735  Value subView = subViewOp.getResult();
2736 
2737  // Find the copy into `subView` without interleaved uses.
2738  memref::CopyOp copyOp;
2739  for (auto &u : subView.getUses()) {
2740  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2741  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2742  if (newCopyOp.getTarget() != subView)
2743  continue;
2744  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2745  continue;
2746  copyOp = newCopyOp;
2747  break;
2748  }
2749  }
2750  if (!copyOp)
2751  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2752 
2753  // Find the fill into `viewOrAlloc` without interleaved uses before the
2754  // copy.
2755  FillOp maybeFillOp;
2756  for (auto &u : viewOrAlloc.getUses()) {
2757  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2758  assert(isa<MemRefType>(newFillOp.output().getType()));
2759  if (newFillOp.output() != viewOrAlloc)
2760  continue;
2761  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2762  continue;
2763  maybeFillOp = newFillOp;
2764  break;
2765  }
2766  }
2767  // Ensure padding matches.
2768  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2769  return rewriter.notifyMatchFailure(xferOp,
2770  "padding value does not match fill");
2771 
2772  // `in` is the subview that memref.copy reads. Replace it.
2773  Value in = copyOp.getSource();
2774 
2775  // memref.copy + linalg.fill can be used to create a padded local buffer.
2776  // The `masked` attribute is only valid on this padded buffer.
2777  // When forwarding to vector.transfer_read, the attribute must be reset
2778  // conservatively.
2779  auto vectorType = xferOp.getVectorType();
2780  Value res = rewriter.create<vector::TransferReadOp>(
2781  xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2782  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2783  rewriter.getBoolArrayAttr(
2784  SmallVector<bool>(vectorType.getRank(), false)));
2785 
2786  if (maybeFillOp)
2787  rewriter.eraseOp(maybeFillOp);
2788  rewriter.eraseOp(copyOp);
2789  rewriter.replaceOp(xferOp, res);
2790 
2791  return success();
2792 }
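// Illustrative sketch (editorial addition, not part of the upstream source):
// the rewrite above forwards a transfer_read on a padded local buffer to a
// transfer_read on the original source of the copy. Roughly, IR of the form
//
//   %alloc = memref.alloc() : memref<16x16xf32>
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<16x16xf32>)
//   %sv = memref.subview %alloc[0, 0] [%a, %b] [1, 1]
//   memref.copy %in, %sv
//   %v = vector.transfer_read %alloc[%c0, %c0], %pad
//
// becomes a single `vector.transfer_read %in[%c0, %c0], %pad`, with the fill
// and the copy erased. The shapes and value names above are hypothetical.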
2793 
2794 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2795 /// when available.
2796 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2797  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2798  // TODO: support mask.
2799  if (xferOp.getMask())
2800  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2801 
2802  // Transfer into `viewOrAlloc`.
2803  Value viewOrAlloc = xferOp.getSource();
2804  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2805  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2806  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2807 
2808  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2809  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2810  if (!subViewOp)
2811  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2812  Value subView = subViewOp.getResult();
2813 
2814  // Find the copy from `subView` without interleaved uses.
2815  memref::CopyOp copyOp;
2816  for (auto &u : subViewOp.getResult().getUses()) {
2817  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2818  if (newCopyOp.getSource() != subView)
2819  continue;
2820  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2821  continue;
2822  copyOp = newCopyOp;
2823  break;
2824  }
2825  }
2826  if (!copyOp)
2827  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2828 
2829  // `out` is the subview copied into that we replace.
2830  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2831  Value out = copyOp.getTarget();
2832 
2833  // Forward vector.transfer into copy.
2834  // memref.copy + linalg.fill can be used to create a padded local buffer.
2835  // The `masked` attribute is only valid on this padded buffer.
2836  // When forwarding to vector.transfer_write, the attribute must be reset
2837  // conservatively.
2838  auto vector = xferOp.getVector();
2839  rewriter.create<vector::TransferWriteOp>(
2840  xferOp.getLoc(), vector, out, xferOp.getIndices(),
2841  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2842  rewriter.getBoolArrayAttr(
2843  SmallVector<bool>(vector.getType().getRank(), false)));
2844 
2845  rewriter.eraseOp(copyOp);
2846  rewriter.eraseOp(xferOp);
2847 
2848  return success();
2849 }
2850 
2851 //===----------------------------------------------------------------------===//
2852 // Convolution vectorization patterns
2853 //===----------------------------------------------------------------------===//
2854 
2855 template <int N>
2856 static void bindShapeDims(ShapedType shapedType) {}
2857 
2858 template <int N, typename IntTy, typename... IntTy2>
2859 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2860  val = shapedType.getShape()[N];
2861  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2862 }
2863 
2864 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2865 template <typename... IntTy>
2866 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2867  bindShapeDims<0>(shapedType, vals...);
2868 }
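// Usage sketch (editorial addition): bind the leading dimensions of a shaped
// type to integer variables, e.g. for a `memref<4x8x16xf32>`:
//
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c); // n = 4, w = 8, c = 16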
2869 
2870 namespace {
2871 bool isCastOfBlockArgument(Operation *op) {
2872  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2873  isa<BlockArgument>(op->getOperand(0));
2874 }
2875 
2876 bool isSupportedPoolKind(vector::CombiningKind kind) {
2877  switch (kind) {
2878  case vector::CombiningKind::ADD:
2879  case vector::CombiningKind::MAXNUMF:
2880  case vector::CombiningKind::MAXIMUMF:
2881  case vector::CombiningKind::MAXSI:
2882  case vector::CombiningKind::MAXUI:
2883  case vector::CombiningKind::MINNUMF:
2884  case vector::CombiningKind::MINIMUMF:
2885  case vector::CombiningKind::MINSI:
2886  case vector::CombiningKind::MINUI:
2887  return true;
2888  default:
2889  return false;
2890  }
2891 }
2892 
2893 /// Generate a vector implementation for either:
2894 /// ```
2895 /// Op def: ( w, kw )
2896 /// Iters: ({Par(), Red()})
2897 /// Layout: {{w + kw}, {kw}, {w}}
2898 /// ```
2899 /// kw is unrolled.
2900 ///
2901 /// or
2902 ///
2903 /// ```
2904 /// Op def: ( n, w, c, kw, f )
2905 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2906 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2907 /// ```
2908 /// kw is unrolled, w is unrolled iff dilationW > 1.
2909 ///
2910 /// or
2911 ///
2912 /// ```
2913 /// Op def: ( n, c, w, f, kw )
2914 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2915 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2916 /// ```
2917 /// kw is unrolled, w is unrolled iff dilationW > 1.
2918 ///
2919 /// or
2920 ///
2921 /// ```
2922 /// Op def: ( n, w, c, kw )
2923 /// Iters: ({Par(), Par(), Par(), Red()})
2924 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2925 /// ```
2926 /// kw is unrolled, w is unrolled iff dilationW > 1.
2927 struct Conv1DGenerator
2928  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2929  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2930  int dilationW)
2931  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2932  strideW(strideW), dilationW(dilationW) {
2933  // Determine whether `linalgOp` can be handled by this generator.
2934  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2935  return;
2936  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2937  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2938  resShaped = linalgOp.getDpsInitOperand(0)->get();
2939  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2940  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2941  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2942  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2943  return;
2944  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2945  // (non-channeled convolution -> LHS and RES both have a single dimension).
2946  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2947  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2948  return;
2949 
2950  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2951  if (!reduceOp)
2952  return;
2953  redOp = reduceOp->getName().getIdentifier();
2954 
2955  if (!setOperKind(reduceOp))
2956  return;
2957  auto maybeKind = getCombinerOpKind(reduceOp);
2958  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2959  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2960  return;
2961  }
2962 
2963  auto rhsRank = rhsShapedType.getRank();
2964  switch (oper) {
2965  case Conv:
2966  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2967  return;
2968  break;
2969  case Pool:
2970  if (rhsRank != 1)
2971  return;
2972  break;
2973  }
2974  // The op is now known to be valid.
2975  valid = true;
2976  }
2977 
2978  /// Generate a vector implementation for:
2979  /// ```
2980  /// Op def: ( w, kw )
2981  /// Iters: ({Par(), Red()})
2982  /// Layout: {{w + kw}, {kw}, {w}}
2983  /// ```
2984  /// kw is always unrolled.
2985  ///
2986  /// or
2987  ///
2988  /// ```
2989  /// Op def: ( n, w, c, kw, f )
2990  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2991  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2992  /// ```
2993  /// kw is always unrolled.
2994  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
2995  /// > 1.
2996  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2997  if (!valid)
2998  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2999 
3000  int64_t nSize, wSize, cSize, kwSize, fSize;
3001  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3002  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3003  switch (conv1DOpOrder) {
3004  case Conv1DOpOrder::W:
3005  // Initialize unused dimensions
3006  nSize = fSize = cSize = 0;
3007  // out{W}
3008  bindShapeDims(resShapedType, wSize);
3009  // kernel{kw}
3010  bindShapeDims(rhsShapedType, kwSize);
3011  lhsShape = {// iw = ow + kw - 1
3012  // (i.e. 16 convolved with 3 -> 14)
3013  (wSize + kwSize - 1)};
3014  rhsShape = {kwSize};
3015  resShape = {wSize};
3016  break;
3017  case Conv1DOpOrder::Nwc:
3018  // out{n, w, f}
3019  bindShapeDims(resShapedType, nSize, wSize, fSize);
3020  switch (oper) {
3021  case Conv:
3022  // kernel{kw, c, f}
3023  bindShapeDims(rhsShapedType, kwSize, cSize);
3024  break;
3025  case Pool:
3026  // kernel{kw}
3027  bindShapeDims(rhsShapedType, kwSize);
3028  cSize = fSize;
3029  break;
3030  }
3031  lhsShape = {nSize,
3032  // iw = ow * sw + kw * dw - 1
3033  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3034  // Perform the proper inclusive -> exclusive -> inclusive.
3035  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3036  1,
3037  cSize};
3038  switch (oper) {
3039  case Conv:
3040  rhsShape = {kwSize, cSize, fSize};
3041  break;
3042  case Pool:
3043  rhsShape = {kwSize};
3044  break;
3045  }
3046  resShape = {nSize, wSize, fSize};
3047  break;
3048  case Conv1DOpOrder::Ncw:
3049  // out{n, f, w}
3050  bindShapeDims(resShapedType, nSize, fSize, wSize);
3051  switch (oper) {
3052  case Conv:
3053  // kernel{f, c, kw}
3054  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3055  break;
3056  case Pool:
3057  // kernel{kw}
3058  bindShapeDims(rhsShapedType, kwSize);
3059  cSize = fSize;
3060  break;
3061  }
3062  lhsShape = {nSize, cSize,
3063  // iw = ow * sw + kw * dw - 1
3064  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3065  // Perform the proper inclusive -> exclusive -> inclusive.
3066  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3067  1};
3068  switch (oper) {
3069  case Conv:
3070  rhsShape = {fSize, cSize, kwSize};
3071  break;
3072  case Pool:
3073  rhsShape = {kwSize};
3074  break;
3075  }
3076  resShape = {nSize, fSize, wSize};
3077  break;
3078  }
3079 
3080  vector::TransferWriteOp write;
3081  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3082 
3083  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3084  // When strideW == 1, we can batch the contiguous loads and avoid
3085  // unrolling
3086  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3087 
3088  Type lhsEltType = lhsShapedType.getElementType();
3089  Type rhsEltType = rhsShapedType.getElementType();
3090  Type resEltType = resShapedType.getElementType();
3091  auto lhsType = VectorType::get(lhsShape, lhsEltType);
3092  auto rhsType = VectorType::get(rhsShape, rhsEltType);
3093  auto resType = VectorType::get(resShape, resEltType);
3094  // Zero padding with the corresponding dimensions for lhs, rhs and res.
3095  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3096  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3097  SmallVector<Value> resPadding(resShape.size(), zero);
3098 
3099  // Read the whole lhs, rhs and res in one shot (with zero padding).
3100  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3101  lhsPadding);
3102  // This is needed only for Conv.
3103  Value rhs = nullptr;
3104  if (oper == Conv)
3105  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3106  rhsPadding);
3107  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3108  resPadding);
3109 
3110  // The base vectorization case for channeled convolution is input:
3111  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse this base case,
3112  // we pre-transpose the input, weight, and output when needed.
3113  switch (conv1DOpOrder) {
3114  case Conv1DOpOrder::W:
3115  case Conv1DOpOrder::Nwc:
3116  // Base case, so no transposes necessary.
3117  break;
3118  case Conv1DOpOrder::Ncw: {
3119  // To match base vectorization case, we pre-transpose current case.
3120  // ncw -> nwc
3121  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3122  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3123  // fcw -> wcf
3124  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3125 
3126  // This is needed only for Conv.
3127  if (oper == Conv)
3128  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3129  // nfw -> nwf
3130  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3131  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3132  break;
3133  }
3134  }
3135 
3136  //===------------------------------------------------------------------===//
3137  // Begin vector-only rewrite part
3138  //===------------------------------------------------------------------===//
3139  // Unroll along kw and read slices of lhs and rhs.
3140  SmallVector<Value> lhsVals, rhsVals, resVals;
3141  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3142  kwSize, strideW, dilationW, wSizeStep,
3143  isSingleChanneled);
3144  // Do not do for pooling.
3145  if (oper == Conv)
3146  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3147  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3148  wSizeStep, isSingleChanneled);
3149 
3150  auto linearIndex = [&](int64_t kw, int64_t w) {
3151  return kw * (wSize / wSizeStep) + w;
3152  };
3153 
3154  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3155  // or an outer product for non-channeled convolution, or a simple arith
3156  // combining operation for pooling.
3157  for (int64_t kw = 0; kw < kwSize; ++kw) {
3158  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3159  switch (oper) {
3160  case Conv:
3161  if (isSingleChanneled) {
3162  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3163  lhsVals[linearIndex(kw, w)],
3164  rhsVals[kw], resVals[w]);
3165  } else {
3166  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3167  lhsVals[linearIndex(kw, w)],
3168  rhsVals[kw], resVals[w]);
3169  }
3170  break;
3171  case Pool:
3172  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3173  resVals[w]);
3174  break;
3175  }
3176  }
3177  }
3178 
3179  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3180  isSingleChanneled);
3181  //===------------------------------------------------------------------===//
3182  // End vector-only rewrite part
3183  //===------------------------------------------------------------------===//
3184 
3185  // The base vectorization case for channeled convolution produces output
3186  // {n,w,f}. To reuse the result from the base case, we post-transpose it
3187  // back to the original layout when needed.
3188  switch (conv1DOpOrder) {
3189  case Conv1DOpOrder::W:
3190  case Conv1DOpOrder::Nwc:
3191  // Base case, so no transposes necessary.
3192  break;
3193  case Conv1DOpOrder::Ncw: {
3194  // nwf -> nfw
3195  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3196  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3197  break;
3198  }
3199  }
3200 
3201  return rewriter
3202  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3203  .getOperation();
3204  }
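  // Shape of the generated IR (editorial sketch, hypothetical sizes): for an
  // NWC conv with n=1, w=4, c=3, f=8, kw=2, stride=1, dilation=1, the routine
  // above emits three vector.transfer_read ops (lhs: vector<1x5x3xf32>,
  // rhs: vector<2x3x8xf32>, res: vector<1x4x8xf32>), one lhs/rhs slice
  // extraction per kw, one vector.contract (or outer product / arith combining
  // op) per (kw, w) pair accumulating into the result slices, and a final
  // vector.transfer_write of the reassembled vector<1x4x8xf32>.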
3205 
3206  // Take a value and widen it to the same element type as `ty`.
3207  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3208  const Type srcElementType = getElementTypeOrSelf(val.getType());
3209  const Type dstElementType = getElementTypeOrSelf(ty);
3210  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3211  if (srcElementType == dstElementType)
3212  return val;
3213 
3214  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3215  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3216  const Type dstType =
3217  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3218 
3219  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3220  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3221  }
3222 
3223  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3224  srcWidth < dstWidth)
3225  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3226 
3227  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3228  srcWidth < dstWidth)
3229  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3230 
3231  assert(false && "unhandled promotion case");
3232  return nullptr;
3233  }
3234 
3235  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3236  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3237  Value lhs, Value rhs, Value res) {
3238  vector::IteratorType par = vector::IteratorType::parallel;
3239  vector::IteratorType red = vector::IteratorType::reduction;
3240  AffineExpr n, w, f, c;
3241  bindDims(ctx, n, w, f, c);
3242  lhs = promote(rewriter, loc, lhs, res.getType());
3243  rhs = promote(rewriter, loc, rhs, res.getType());
3244  return rewriter.create<vector::ContractionOp>(
3245  loc, lhs, rhs, res,
3246  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3247  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3248  }
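  // For reference (editorial sketch, concrete vector shapes are hypothetical),
  // the contraction built above corresponds to IR roughly of the form:
  //
  //   %r = vector.contract {
  //          indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                           affine_map<(n, w, f, c) -> (c, f)>,
  //                           affine_map<(n, w, f, c) -> (n, w, f)>],
  //          iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  //          kind = #vector.kind<add>}
  //        %lhs, %rhs, %res
  //        : vector<1x4x3xf32>, vector<3x8xf32> into vector<1x4x8xf32>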
3249 
3250  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3251  // convolution.
3252  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3253  Value lhs, Value rhs, Value res) {
3254  return rewriter.create<vector::OuterProductOp>(
3255  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3256  }
3257 
3258  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3259  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3260  Value res) {
3261  if (isPoolExt)
3262  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3263  return rewriter
3264  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3265  ->getResult(0);
3266  }
3267 
3268  /// Generate a vector implementation for:
3269  /// ```
3270  /// Op def: ( n, w, c, kw)
3271  /// Iters: ({Par(), Par(), Par(), Red()})
3272  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3273  /// ```
3274  /// kw is always unrolled.
3275  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3276  /// > 1.
3277  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3278  bool channelDimScalableFlag,
3279  bool flatten) {
3280  if (!valid)
3281  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3282 
3283  bool scalableChDim = false;
3284  bool useMasking = false;
3285  int64_t nSize, wSize, cSize, kwSize;
3286  // kernel{kw, c}
3287  bindShapeDims(rhsShapedType, kwSize, cSize);
3288  if (ShapedType::isDynamic(cSize)) {
3289  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3290  cSize = channelDimVecSize;
3291  // Scalable vectors are only used when both conditions are met:
3292  // 1. channel dim is dynamic
3293  // 2. channelDimScalableFlag is set
3294  scalableChDim = channelDimScalableFlag;
3295  useMasking = true;
3296  }
3297 
3298  assert(!(useMasking && flatten) &&
3299  "Unsupported flattened conv with dynamic shapes");
3300 
3301  // out{n, w, c}
3302  bindShapeDims(resShapedType, nSize, wSize);
3303 
3304  vector::TransferWriteOp write;
3305  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3306 
3307  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3308  // When strideW == 1, we can batch the contiguous loads and avoid
3309  // unrolling
3310  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3311 
3312  Type lhsEltType = lhsShapedType.getElementType();
3313  Type rhsEltType = rhsShapedType.getElementType();
3314  Type resEltType = resShapedType.getElementType();
3315  VectorType lhsType = VectorType::get(
3316  {nSize,
3317  // iw = ow * sw + kw * dw - 1
3318  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3319  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3320  cSize},
3321  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3322  VectorType rhsType =
3323  VectorType::get({kwSize, cSize}, rhsEltType,
3324  /*scalableDims=*/{false, scalableChDim});
3325  VectorType resType =
3326  VectorType::get({nSize, wSize, cSize}, resEltType,
3327  /*scalableDims=*/{false, false, scalableChDim});
3328 
3329  // Masks the input xfer Op along the channel dim, iff the corresponding
3330  // scalable flag is set.
3331  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3332  ArrayRef<bool> scalableDims,
3333  Operation *opToMask) {
3334  if (!useMasking)
3335  return opToMask;
3336  auto maskType =
3337  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3338 
3339  SmallVector<bool> inBounds(maskShape.size(), true);
3340  auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3341  xferOp->setAttr(xferOp.getInBoundsAttrName(),
3342  rewriter.getBoolArrayAttr(inBounds));
3343 
3344  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3345  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3346 
3347  Value maskOp =
3348  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3349 
3350  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3351  };
3352 
3353  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3354  // 0].
3355  Value lhs = rewriter.create<vector::TransferReadOp>(
3356  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3357  auto maybeMaskedLhs = maybeMaskXferOp(
3358  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3359 
3360  // Read rhs slice of size {kw, c} @ [0, 0].
3361  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3362  ValueRange{zero, zero});
3363  auto maybeMaskedRhs = maybeMaskXferOp(
3364  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3365 
3366  // Read res slice of size {n, w, c} @ [0, 0, 0].
3367  Value res = rewriter.create<vector::TransferReadOp>(
3368  loc, resType, resShaped, ValueRange{zero, zero, zero});
3369  auto maybeMaskedRes = maybeMaskXferOp(
3370  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3371 
3372  //===------------------------------------------------------------------===//
3373  // Begin vector-only rewrite part
3374  //===------------------------------------------------------------------===//
3375  // Unroll along kw and read slices of lhs and rhs.
3376  SmallVector<Value> lhsVals, rhsVals, resVals;
3377  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3378  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3379 
3380  // Extract lhs slice of size {n, wSizeStep, c}
3381  // @ [0, sw * w + dw * kw, 0].
3382  for (int64_t kw = 0; kw < kwSize; ++kw) {
3383  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3384  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3385  loc, maybeMaskedLhs->getResult(0),
3386  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3387  inOutSliceSizes, inOutStrides));
3388  }
3389  }
3390  // Extract rhs slice of size {c} @ [kw].
3391  for (int64_t kw = 0; kw < kwSize; ++kw) {
3392  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3393  loc, maybeMaskedRhs->getResult(0),
3394  /*offsets=*/ArrayRef<int64_t>{kw}));
3395  }
3396  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3397  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3398  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3399  loc, maybeMaskedRes->getResult(0),
3400  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3401  inOutStrides));
3402  }
3403 
3404  auto linearIndex = [&](int64_t kw, int64_t w) {
3405  return kw * (wSize / wSizeStep) + w;
3406  };
3407 
3408  // Note - the scalable flags are ignored as flattening combined with
3409  // scalable vectorization is not supported.
3410  auto inOutFlattenSliceSizes =
3411  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3412  auto lhsTypeAfterFlattening =
3413  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3414  auto resTypeAfterFlattening =
3415  VectorType::get(inOutFlattenSliceSizes, resEltType);
3416 
3417  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3418  for (int64_t kw = 0; kw < kwSize; ++kw) {
3419  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3420  Value lhsVal = lhsVals[linearIndex(kw, w)];
3421  Value resVal = resVals[w];
3422  if (flatten) {
3423  // Flatten the input and output vectors (collapse the channel
3424  // dimension)
3425  lhsVal = rewriter.create<vector::ShapeCastOp>(
3426  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3427  resVal = rewriter.create<vector::ShapeCastOp>(
3428  loc, resTypeAfterFlattening, resVals[w]);
3429  }
3430  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3431  rhsVals[kw], resVal, flatten);
3432  if (flatten) {
3433  // Un-flatten the output vector (restore the channel dimension)
3434  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3435  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3436  }
3437  }
3438  }
3439 
3440  // It's possible we failed to create the FMA.
3441  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3442  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3443  for (auto &collection :
3444  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3445  for (Value v : collection)
3446  rewriter.eraseOp(v.getDefiningOp());
3447  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3448  }
3449 
3450  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3451  // This does not depend on kw.
3452  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3453  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3454  loc, resVals[w], maybeMaskedRes->getResult(0),
3455  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3456  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3457  }
3458  //===------------------------------------------------------------------===//
3459  // End vector-only rewrite part
3460  //===------------------------------------------------------------------===//
3461 
3462  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3463  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3464  loc, maybeMaskedRes->getResult(0), resShaped,
3465  ValueRange{zero, zero, zero});
3466  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3467  resOut);
3468  }
3469 
3470  /// Lower:
3471  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3472  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3473  /// to MulAcc.
3474  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3475  Value lhs, Value rhs, Value res,
3476  bool flatten) {
3477  auto rhsTy = cast<ShapedType>(rhs.getType());
3478  auto resTy = cast<ShapedType>(res.getType());
3479 
3480  // TODO(suderman): Change this to use a vector.ima intrinsic.
3481  lhs = promote(rewriter, loc, lhs, resTy);
3482 
3483  if (flatten) {
3484  // NOTE: This following logic won't work for scalable vectors. For this
3485  // reason, "flattening" is not supported when shapes are dynamic (this
3486  // should be captured by one of the pre-conditions).
3487 
3488  // There are two options for handling the filter:
3489  // * shape_cast(broadcast(filter))
3490  // * broadcast(shuffle(filter))
3491  // Opt for the option without shape_cast to simplify the codegen.
3492  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3493  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3494 
3495  SmallVector<int64_t, 16> indices;
3496  for (int i = 0; i < resSize / rhsSize; ++i) {
3497  for (int j = 0; j < rhsSize; ++j)
3498  indices.push_back(j);
3499  }
3500 
3501  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3502  }
3503  // Broadcast the filter to match the output vector
3504  rhs = rewriter.create<vector::BroadcastOp>(
3505  loc, resTy.clone(rhsTy.getElementType()), rhs);
3506 
3507  rhs = promote(rewriter, loc, rhs, resTy);
3508 
3509  if (!lhs || !rhs)
3510  return nullptr;
3511 
3512  if (isa<FloatType>(resTy.getElementType()))
3513  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3514 
3515  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3516  return rewriter.create<arith::AddIOp>(loc, mul, res);
3517  }
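  // Worked example for the flattened case (editorial addition, hypothetical
  // sizes): with a filter of c = 4 elements and a flattened slice of
  // w * c = 8 elements, the shuffle above repeats the filter using
  // indices = [0, 1, 2, 3, 0, 1, 2, 3], so that after broadcasting and
  // promotion the element-wise FMA (or mul + add for integers) lines each
  // input element up with its channel's filter weight.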
3518 
3519  /// Entry point for non-channeled convolution:
3520  /// {{w + kw}, {kw}, {w}}
3521  FailureOr<Operation *> generateNonChanneledConv() {
3522  AffineExpr w, kw;
3523  bindDims(ctx, w, kw);
3524  if (!iters({Par(), Red()}))
3525  return rewriter.notifyMatchFailure(op,
3526  "failed to match conv::W 1-par 1-red");
3527 
3528  // No transposition needed.
3529  if (layout({/*lhsIndex*/ {w + kw},
3530  /*rhsIndex*/ {kw},
3531  /*resIndex*/ {w}}))
3532  return conv(Conv1DOpOrder::W);
3533 
3534  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3535  }
3536 
3537  /// Entry point that transposes into the common form:
3538  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3539  FailureOr<Operation *> generateNwcConv() {
3540  AffineExpr n, w, f, kw, c;
3541  bindDims(ctx, n, w, f, kw, c);
3542  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3543  return rewriter.notifyMatchFailure(
3544  op, "failed to match conv::Nwc 3-par 2-red");
3545 
3546  // No transposition needed.
3547  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3548  /*rhsIndex*/ {kw, c, f},
3549  /*resIndex*/ {n, w, f}}))
3550  return conv(Conv1DOpOrder::Nwc);
3551 
3552  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3553  }
3554 
3555  /// Entry point that transposes into the common form:
3556  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3557  FailureOr<Operation *> generateNcwConv() {
3558  AffineExpr n, w, f, kw, c;
3559  bindDims(ctx, n, f, w, c, kw);
3560  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3561  return rewriter.notifyMatchFailure(
3562  op, "failed to match conv::Ncw 3-par 2-red");
3563 
3564  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3565  /*rhsIndex*/ {f, c, kw},
3566  /*resIndex*/ {n, f, w}}))
3567  return conv(Conv1DOpOrder::Ncw);
3568 
3569  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3570  }
3571 
3572  /// Entry point that transposes into the common form:
3573  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3574  FailureOr<Operation *> generateNwcPooling() {
3575  AffineExpr n, w, c, kw;
3576  bindDims(ctx, n, w, c, kw);
3577  if (!iters({Par(), Par(), Par(), Red()}))
3578  return rewriter.notifyMatchFailure(op,
3579  "failed to match pooling 3-par 1-red");
3580 
3581  // No transposition needed.
3582  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3583  /*rhsIndex*/ {kw},
3584  /*resIndex*/ {n, w, c}}))
3585  return conv(Conv1DOpOrder::Nwc);
3586 
3587  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3588  }
3589 
3590  /// Entry point that transposes into the common form:
3591  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3592  FailureOr<Operation *> generateNcwPooling() {
3593  AffineExpr n, w, c, kw;
3594  bindDims(ctx, n, c, w, kw);
3595  if (!iters({Par(), Par(), Par(), Red()}))
3596  return rewriter.notifyMatchFailure(op,
3597  "failed to match pooling 3-par 1-red");
3598 
3599  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3600  /*rhsIndex*/ {kw},
3601  /*resIndex*/ {n, c, w}}))
3602  return conv(Conv1DOpOrder::Ncw);
3603 
3604  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3605  }
3606 
3607  /// Entry point that transposes into the common form:
3608  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3609  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3610  bool vecChDimScalableFlag = false,
3611  bool flatten = false) {
3612  AffineExpr n, w, c, kw;
3613  bindDims(ctx, n, w, c, kw);
3614  if (!iters({Par(), Par(), Par(), Red()}))
3615  return rewriter.notifyMatchFailure(
3616  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3617 
3618  // No transposition needed.
3619  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3620  /*rhsIndex*/ {kw, c},
3621  /*resIndex*/ {n, w, c}}))
3622  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3623 
3624  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3625  }
3626 
3627 private:
3628  enum OperKind { Conv, Pool };
3629  bool valid = false;
3630  OperKind oper = Conv;
3631  StringAttr redOp;
3632  StringAttr poolExtOp;
3633  bool isPoolExt = false;
3634  int strideW, dilationW;
3635  Value lhsShaped, rhsShaped, resShaped;
3636  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3637 
3638  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3639  // Returns true iff it is a valid conv/pooling op.
3640  // If the region has 2 ops (reduction + yield), or 3 ops (extension +
3641  // reduction + yield) and rhs is not used, then it is the body of a pooling.
3642  // If it is a conv, check for a single `mul` predecessor. The `mul` operands
3643  // must be block arguments or extensions of block arguments.
3644  // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
3645  // must be block arguments or extensions of block arguments.
3646  bool setOperKind(Operation *reduceOp) {
3647  int numBlockArguments =
3648  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3649  switch (numBlockArguments) {
3650  case 1: {
3651  // This is a convolution if the feeder is a MulOp.
3652  // Otherwise, it may be a pooling op.
3653  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3654  llvm::IsaPred<BlockArgument>);
3655  Operation *feedOp = (*feedValIt).getDefiningOp();
3656  if (isCastOfBlockArgument(feedOp)) {
3657  oper = Pool;
3658  isPoolExt = true;
3659  poolExtOp = feedOp->getName().getIdentifier();
3660  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3661  llvm::all_of(feedOp->getOperands(), [](Value v) {
3662  if (isa<BlockArgument>(v))
3663  return true;
3664  if (Operation *op = v.getDefiningOp())
3665  return isCastOfBlockArgument(op);
3666  return false;
3667  }))) {
3668  return false;
3669  }
3670  return true;
3671  }
3672  case 2:
3673  // Must be pooling
3674  oper = Pool;
3675  isPoolExt = false;
3676  return true;
3677  default:
3678  return false;
3679  }
3680  }
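  // For intuition (editorial sketch): a convolution body typically looks like
  //   %m = arith.mulf %in, %filter : f32
  //   %a = arith.addf %out, %m : f32
  //   linalg.yield %a : f32
  // (one non-block-argument operand feeding the reduction, defined by a mul),
  // whereas a max-pooling body looks like
  //   %a = arith.maximumf %out, %in : f32
  //   linalg.yield %a : f32
  // (both reduction operands are block arguments, so numBlockArguments == 2).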
3681 };
3682 } // namespace
3683 
3684 /// Helper function to vectorize a LinalgOp with convolution semantics.
3685 // TODO: extend the generic vectorization to support windows and drop this.
3686 static FailureOr<Operation *> vectorizeConvolution(
3687  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3688  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3689  // The ConvolutionOpInterface gives us guarantees of existence for
3690  // strides/dilations. However, we do not need to rely on those, we can
3691  // simply use them if present, otherwise use the default and let the generic
3692  // conv. matcher in the ConvGenerator succeed or fail.
3693  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3694  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3695  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3696  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3697  Conv1DGenerator e(rewriter, op, stride, dilation);
3698  auto res = e.generateNonChanneledConv();
3699  if (succeeded(res))
3700  return res;
3701  res = e.generateNwcConv();
3702  if (succeeded(res))
3703  return res;
3704  res = e.generateNcwConv();
3705  if (succeeded(res))
3706  return res;
3707  res = e.generateNwcPooling();
3708  if (succeeded(res))
3709  return res;
3710  res = e.generateNcwPooling();
3711  if (succeeded(res))
3712  return res;
3713 
3714  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3715  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3716  // masked/scalable) is the channel dim (i.e. the trailing dim).
3717  uint64_t vecChDimSize = ShapedType::kDynamic;
3718  bool vecChDimScalableFlag = false;
3719  if (!inputVecSizes.empty()) {
3720  // Only use the input vector size corresponding to the channel dim. Other
3721  // vector dims will be inferred from the Ops.
3722  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3723  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3724  "Not a 1D depthwise conv!");
3725  size_t chDimIdx =
3726  TypeSwitch<Operation *, size_t>(op)
3727  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3728  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3729 
3730  vecChDimSize = inputVecSizes[chDimIdx];
3731  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3732  }
3733  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3734  flatten1DDepthwiseConv);
3735 }
3736 
3737 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3738  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3739 
3740  LogicalResult matchAndRewrite(LinalgOp op,
3741  PatternRewriter &rewriter) const override {
3742  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3743  if (failed(resultOrFail))
3744  return failure();
3745  Operation *newOp = *resultOrFail;
3746  if (newOp->getNumResults() == 0) {
3747  rewriter.eraseOp(op.getOperation());
3748  return success();
3749  }
3750  assert(newOp->getNumResults() == 1 && "expected single result");
3751  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3752  return success();
3753  }
3754 };
3755 
3756 void mlir::linalg::populateConvolutionVectorizationPatterns(
3757  RewritePatternSet &patterns, PatternBenefit benefit) {
3758  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3759 }
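// Usage sketch (editorial addition): these patterns are typically applied
// through the greedy pattern rewrite driver from a pass; `MyConvVectorizePass`
// below is a hypothetical name.
//
//   void MyConvVectorizePass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateConvolutionVectorizationPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }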
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether /p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:595
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:142
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:625
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:290
OpListType & getOperations()
Definition: Block.h:135
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:406
IntegerType getI32Type()
Definition: Builders.cpp:95
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:343
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:85
IndexType getIndexType()
Definition: Builders.cpp:83
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:289
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:450
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:403
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:59
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
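These two results control IR traversals started with Operation::walk: returning interrupt() stops the walk early, while advance() continues it. A minimal, hedged sketch of the usual early-exit idiom (the predicate and the `root` op are purely illustrative):

// Stop walking as soon as an op with more than one result is found.
// `root` is a hypothetical op whose nested regions we inspect.
WalkResult result = root->walk([](Operation *op) {
  if (op->getNumResults() > 1)
    return WalkResult::interrupt(); // found a match; stop the traversal
  return WalkResult::advance();     // keep visiting nested operations
});
bool foundMultiResultOp = result.wasInterrupted();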
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:215
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
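The pattern-populating entries above (populatePadOpVectorizationPatterns and populateConvolutionVectorizationPatterns) are typically wired into a greedy rewrite driver. A minimal sketch, assuming a function-like `funcOp` and the greedy driver from mlir/Transforms; both are assumptions by the editor, not part of this file:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collect the pad/convolution vectorization patterns and apply them greedily.
static void applyVectorizationPatterns(Operation *funcOp) {
  MLIRContext *ctx = funcOp->getContext();
  RewritePatternSet patterns(ctx);
  linalg::populatePadOpVectorizationPatterns(patterns);
  linalg::populateConvolutionVectorizationPatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    funcOp->emitWarning("vectorization patterns did not converge");
}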
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
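vectorizeOpPrecondition is the gate that pairs with the vectorize entry above. A minimal sketch of that call sequence, assuming the caller already has a RewriterBase positioned at the candidate op; the hard-coded vector sizes are illustrative only:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

// Attempt to vectorize `candidate` with illustrative, hard-coded vector sizes.
// The rewriter's insertion point is assumed to be set by the caller.
static LogicalResult tryVectorize(RewriterBase &rewriter, Operation *candidate) {
  SmallVector<int64_t> vectorSizes = {8, 16};      // illustrative sizes
  SmallVector<bool> scalableDims = {false, false}; // no scalable dims here
  if (failed(linalg::vectorizeOpPrecondition(candidate, vectorSizes,
                                             scalableDims)))
    return failure(); // the op cannot be vectorized with these sizes
  return linalg::vectorize(rewriter, candidate, vectorSizes, scalableDims);
}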
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2375
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
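createReadOrMaskedRead bundles the transfer_read plus masking logic used throughout this file. A hedged sketch of a call site; the vector namespace, the header path, and the 4x8 tile shape are assumptions made for illustration:

#include "mlir/Dialect/Vector/Utils/VectorUtils.h" // assumed header

// Read an illustrative 4x8 tile from `source`, masking (rather than relying on
// in-bounds attributes) whenever the tile overruns the source shape.
static Value readTile(OpBuilder &builder, Location loc, Value source,
                      Value padValue) {
  SmallVector<int64_t> readShape = {4, 8}; // illustrative tile shape
  return vector::createReadOrMaskedRead(builder, loc, source, readShape,
                                         padValue,
                                         /*useInBoundsInsteadOfMasking=*/false);
}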
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:657
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:792
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. N-1], where N is the number of expressions.
Definition: AffineExpr.h:348
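bindDims is the usual way to obtain dimension expressions before composing them with AffineExpr arithmetic. A minimal sketch (the expression built is arbitrary):

#include "mlir/IR/AffineExpr.h"

// Build the affine expression d0 + d1 * 2 from freshly bound dim expressions.
static AffineExpr buildExampleExpr(MLIRContext *ctx) {
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1); // binds d0 to dimension 0 and d1 to dimension 1
  return d0 + d1 * 2;    // AffineExpr overloads +, *, etc.
}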
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:768
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:699
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:631
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
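matchPattern and m_Constant (both listed above) combine into the standard idiom for peeking at constant-defined values; getConstantIntValue is the analogous helper when starting from an OpFoldResult. A minimal sketch using the attribute-binding form of m_Constant:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include <optional>

// Return the constant integer that defines `value`, if there is one.
static std::optional<int64_t> getDefiningConstant(Value value) {
  IntegerAttr attr;
  if (matchPattern(value, m_Constant(&attr))) // binds the matched attribute
    return attr.getInt();
  return std::nullopt;
}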
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
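applyPermutation and invertPermutationVector (both listed above) compose to the identity: applying a permutation and then its inverse restores the original order. A small sketch, assuming these helpers live in mlir/Dialect/Utils/IndexingUtils.h (the header path is an assumption):

#include "mlir/Dialect/Utils/IndexingUtils.h" // assumed header
#include "llvm/ADT/STLExtras.h"
#include <cassert>

// Round-trip a value list through a permutation and its inverse.
static void permutationRoundTrip() {
  SmallVector<int64_t> perm = {2, 0, 1};
  SmallVector<int64_t> inv = invertPermutationVector(perm);
  SmallVector<int64_t> values = {10, 20, 30};
  SmallVector<int64_t> permuted =
      applyPermutation(ArrayRef<int64_t>(values), perm);
  SmallVector<int64_t> restored =
      applyPermutation(ArrayRef<int64_t>(permuted), inv);
  assert(llvm::equal(values, restored) && "inverse must undo the permutation");
}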
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
Definition: Transforms.h:1499
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.