1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinTypeInterfaces.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/OpDefinition.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/Support/LLVM.h"
34 #include "mlir/Transforms/RegionUtils.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/Sequence.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/ADT/iterator_range.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/MathExtras.h"
42 #include "llvm/Support/raw_ostream.h"
43 #include <optional>
44 #include <type_traits>
45 
46 using namespace mlir;
47 using namespace mlir::linalg;
48 
49 #define DEBUG_TYPE "linalg-vectorization"
50 
51 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
52 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
53 
54 /// Try to vectorize `convOp` as a convolution.
55 static FailureOr<Operation *>
56 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
57  bool flatten1DDepthwiseConv = false);
58 
59 /// Return the unique instance of OpType in `block` if it is indeed unique.
60 /// Return null if none or more than 1 instances exist.
61 template <typename OpType>
62 static OpType getSingleOpOfType(Block &block) {
63  OpType res;
64  block.walk([&](OpType op) {
65  if (res) {
66  res = nullptr;
67  return WalkResult::interrupt();
68  }
69  res = op;
70  return WalkResult::advance();
71  });
72  return res;
73 }
74 
75 /// Helper function to extract the input slices after filter is unrolled along
76 /// kw.
77 static SmallVector<Value>
78 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
79  int64_t nSize, int64_t wSize, int64_t cSize,
80  int64_t kwSize, int strideW, int dilationW,
81  int64_t wSizeStep, bool isSingleChanneled) {
82  SmallVector<Value> result;
83  if (isSingleChanneled) {
84  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
85  // convolution.
86  SmallVector<int64_t> sizes{wSizeStep};
87  SmallVector<int64_t> strides{1};
88  for (int64_t kw = 0; kw < kwSize; ++kw) {
89  for (int64_t w = 0; w < wSize; w += wSizeStep) {
90  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
91  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
92  }
93  }
94  } else {
95  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
96  // for channeled convolution.
97  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
98  SmallVector<int64_t> strides{1, 1, 1};
99  for (int64_t kw = 0; kw < kwSize; ++kw) {
100  for (int64_t w = 0; w < wSize; w += wSizeStep) {
101  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
102  loc, input,
103  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
104  sizes, strides));
105  }
106  }
107  }
108  return result;
109 }
110 
111 /// Helper function to extract the filter slices after filter is unrolled along
112 /// kw.
113 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
114  Location loc, Value filter,
115  int64_t kwSize) {
116  SmallVector<Value> result;
117  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
118   /// non-channeled convolution] @ [kw].
119  for (int64_t kw = 0; kw < kwSize; ++kw) {
120  result.push_back(rewriter.create<vector::ExtractOp>(
121  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
122  }
123  return result;
124 }
125 
126 /// Helper function to extract the result slices after filter is unrolled along
127 /// kw.
128 static SmallVector<Value>
129 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
130  int64_t nSize, int64_t wSize, int64_t fSize,
131  int64_t wSizeStep, bool isSingleChanneled) {
132  SmallVector<Value> result;
133  if (isSingleChanneled) {
134  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
135  SmallVector<int64_t> sizes{wSizeStep};
136  SmallVector<int64_t> strides{1};
137  for (int64_t w = 0; w < wSize; w += wSizeStep) {
138  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
139  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
140  }
141  } else {
142  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
143  // convolution.
144  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
145  SmallVector<int64_t> strides{1, 1, 1};
146  for (int64_t w = 0; w < wSize; w += wSizeStep) {
147  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
148  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
149  }
150  }
151  return result;
152 }
153 
154 /// Helper function to insert the computed result slices.
155 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
156  Value res, int64_t wSize, int64_t wSizeStep,
157  SmallVectorImpl<Value> &resVals,
158  bool isSingleChanneled) {
159 
160  if (isSingleChanneled) {
161  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
162  // This does not depend on kw.
163  SmallVector<int64_t> strides{1};
164  for (int64_t w = 0; w < wSize; w += wSizeStep) {
165  res = rewriter.create<vector::InsertStridedSliceOp>(
166  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
167  }
168  } else {
169  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
170  // convolution. This does not depend on kw.
171  SmallVector<int64_t> strides{1, 1, 1};
172  for (int64_t w = 0; w < wSize; w += wSizeStep) {
173  res = rewriter.create<vector::InsertStridedSliceOp>(
174  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
175  strides);
176  }
177  }
178  return res;
179 }
180 
181 /// Contains the vectorization state and related methods used across the
182 /// vectorization process of a given operation.
183 struct VectorizationState {
184  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
185 
186  /// Initializes the vectorization state, including the computation of the
187  /// canonical vector shape for vectorization.
188  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
189  ArrayRef<int64_t> inputVectorSizes,
190  ArrayRef<bool> inputScalableVecDims);
191 
192  /// Returns the canonical vector shape used to vectorize the iteration space.
193  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
194 
195  /// Returns a vector type of the provided `elementType` with the canonical
196  /// vector shape and the corresponding fixed/scalable dimensions bit. If
197  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
198  /// accordingly.
199   VectorType getCanonicalVecType(
200  Type elementType,
201  std::optional<AffineMap> dimPermutation = std::nullopt) const {
202     SmallVector<int64_t> vectorShape;
203  SmallVector<bool> scalableDims;
204  if (dimPermutation.has_value()) {
205  vectorShape =
206  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
207  scalableDims =
208  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
209  } else {
210  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
211  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
212  }
213 
214  return VectorType::get(vectorShape, elementType, scalableDims);
215  }
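  // Illustrative example (hypothetical values, not from the original source):
  // with canonicalVecShape = [8, 16, 4], scalableVecDims = [false, false, true]
  // and dimPermutation = (d0, d1, d2) -> (d2, d0), getCanonicalVecType returns
  // vector<[4]x8xf32> for an f32 element type; without a permutation it
  // returns vector<8x16x[4]xf32>.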
216 
217  /// Masks an operation with the canonical vector mask if the operation needs
218  /// masking. Returns the masked operation or the original operation if masking
219  /// is not needed. If provided, the canonical mask for this operation is
220  /// permuted using `maybeMaskingMap`.
221  Operation *
222  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
223  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
224 
225 private:
226  /// Initializes the iteration space static sizes using the Linalg op
227  /// information. This may become more complicated in the future.
228  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
229  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
230  }
231 
232  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
233  /// all the static and dynamic dimensions of the iteration space to be
234  /// vectorized and store them in `iterSpaceValueSizes`.
235  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
236  LinalgOp linalgOp);
237 
238  /// Create or retrieve an existing mask value to mask `opToMask` in the
239  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
240  /// mask is permuted using that map. If a new mask is created, it will be
241  /// cached for future users.
242  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
243  LinalgOp linalgOp,
244  std::optional<AffineMap> maybeMaskingMap);
245 
246  // Holds the compile-time static sizes of the iteration space to vectorize.
247  // Dynamic dimensions are represented using ShapedType::kDynamic.
248  SmallVector<int64_t> iterSpaceStaticSizes;
249 
250  /// Holds the value sizes of the iteration space to vectorize. Static
251  /// dimensions are represented by 'arith.constant' and dynamic
252  /// dimensions by 'tensor/memref.dim'.
253  SmallVector<Value> iterSpaceValueSizes;
254 
255  /// Holds the canonical vector shape used to vectorize the iteration space.
256  SmallVector<int64_t> canonicalVecShape;
257 
258  /// Holds the vector dimensions that are scalable in the canonical vector
259  /// shape.
260  SmallVector<bool> scalableVecDims;
261 
262  /// Holds the active masks for permutations of the canonical vector iteration
263  /// space.
264  DenseMap<AffineMap, Value> activeMaskCache;
265 
266  /// Global vectorization guard for the incoming rewriter. It's initialized
267  /// when the vectorization state is initialized.
268   OpBuilder::InsertionGuard rewriterGuard;
269 };
270 
271 LogicalResult
272 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
273  LinalgOp linalgOp) {
274  // TODO: Support 0-d vectors.
275  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
276  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
277  // Create constant index op for static dimensions.
278  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
279  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
280  continue;
281  }
282 
283  // Find an operand defined on this dimension of the iteration space to
284  // extract the runtime dimension size.
285  Value operand;
286  unsigned operandDimPos;
287  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
288  operandDimPos)))
289  return failure();
290 
291  Value dynamicDim = linalgOp.hasPureTensorSemantics()
292  ? (Value)rewriter.create<tensor::DimOp>(
293  linalgOp.getLoc(), operand, operandDimPos)
294  : (Value)rewriter.create<memref::DimOp>(
295  linalgOp.getLoc(), operand, operandDimPos);
296  iterSpaceValueSizes.push_back(dynamicDim);
297  }
298 
299  return success();
300 }
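// Illustrative example (hypothetical shapes, not from the original source):
// for a linalg op on tensors whose static loop ranges are [8, ?], the loop
// above materializes
//   %c8 = arith.constant 8 : index
//   %c1 = arith.constant 1 : index
//   %dim = tensor.dim %operand, %c1 : tensor<8x?xf32>
// and records [%c8, %dim] in `iterSpaceValueSizes`.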
301 
302 /// Initializes the vectorization state, including the computation of the
303 /// canonical vector shape for vectorization.
304 // TODO: Move this to the constructor when we can remove the failure cases.
305 LogicalResult
306 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
307  ArrayRef<int64_t> inputVectorSizes,
308  ArrayRef<bool> inputScalableVecDims) {
309  // Initialize the insertion point.
310  rewriter.setInsertionPoint(linalgOp);
311 
312  if (!inputVectorSizes.empty()) {
313  // Get the canonical vector shape from the input vector sizes provided. This
314  // path should be taken to vectorize code with dynamic shapes and when using
315  // vector sizes greater than the iteration space sizes.
316  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
317  scalableVecDims.append(inputScalableVecDims.begin(),
318  inputScalableVecDims.end());
319  } else {
320  // Compute the canonical vector shape from the operation shape. If there are
321  // dynamic shapes, the operation won't be vectorized. We assume all the
322  // vector dimensions are fixed.
323  canonicalVecShape = linalgOp.getStaticLoopRanges();
324  scalableVecDims.append(linalgOp.getNumLoops(), false);
325  }
326 
327  LDBG("Canonical vector shape: ");
328  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
329  LLVM_DEBUG(llvm::dbgs() << "\n");
330  LDBG("Scalable vector dims: ");
331  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
332  LLVM_DEBUG(llvm::dbgs() << "\n");
333 
334  if (ShapedType::isDynamicShape(canonicalVecShape))
335  return failure();
336 
337  // Initialize iteration space static sizes.
338  initIterSpaceStaticSizes(linalgOp);
339 
340  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
341  // all the static and dynamic dimensions of the iteration space, needed to
342  // compute a mask during vectorization.
343  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
344  return failure();
345 
346  return success();
347 }
348 
349 /// Create or retrieve an existing mask value to mask `opToMask` in the
350 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
351 /// using that permutation map. If a new mask is created, it will be cached for
352 /// future users.
353 Value VectorizationState::getOrCreateMaskFor(
354  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
355  std::optional<AffineMap> maybeMaskingMap) {
356  // No mask is needed if the operation is not maskable.
357  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
358  if (!maskableOp)
359  return Value();
360 
361  assert(!maskableOp.isMasked() &&
362  "Masking an operation that is already masked");
363 
364  // If no masking map was provided, use an identity map with the loop dims.
365  assert((!maybeMaskingMap || *maybeMaskingMap) &&
366  "Unexpected null mask permutation map");
367  AffineMap maskingMap =
368  maybeMaskingMap ? *maybeMaskingMap
369                       : AffineMap::getMultiDimIdentityMap(
370  linalgOp.getNumLoops(), rewriter.getContext());
371 
372  LDBG("Masking map: " << maskingMap << "\n");
373 
374  // Return the active mask for the masking map of this operation if it was
375  // already created.
376  auto activeMaskIt = activeMaskCache.find(maskingMap);
377  if (activeMaskIt != activeMaskCache.end()) {
378  Value mask = activeMaskIt->second;
379  LDBG("Reusing mask: " << mask << "\n");
380  return mask;
381  }
382 
383  // Compute permuted projection of the iteration space to be masked and the
384  // corresponding mask shape. If the resulting iteration space dimensions are
385  // static and identical to the mask shape, masking is not needed for this
386  // operation.
387  // TODO: Improve this check. Only projected permutation indexing maps are
388  // supported.
389  SmallVector<int64_t> permutedStaticSizes =
390  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
391  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
392  auto maskShape = maskType.getShape();
393 
394  LDBG("Mask shape: ");
395  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
396  LLVM_DEBUG(llvm::dbgs() << "\n");
397 
398  if (permutedStaticSizes == maskShape) {
399  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
400  activeMaskCache[maskingMap] = Value();
401  return Value();
402  }
403 
404  // Permute the iteration space value sizes to compute the mask upper bounds.
405  SmallVector<Value> upperBounds =
406  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
407  assert(!maskShape.empty() && !upperBounds.empty() &&
408  "Masked 0-d vectors are not supported yet");
409 
410  // Create the mask based on the dimension values.
411  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
412  maskType, upperBounds);
413  LDBG("Creating new mask: " << mask << "\n");
414  activeMaskCache[maskingMap] = mask;
415  return mask;
416 }
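// Illustrative example (hypothetical sizes, not from the original source):
// with a canonical vector shape of 4x8 and a dynamic %d0 x %d1 iteration
// space, the permuted static sizes differ from the mask shape, so the code
// above creates
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
// and caches it under the (identity) masking map.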
417 
418 /// Masks an operation with the canonical vector mask if the operation needs
419 /// masking. Returns the masked operation or the original operation if masking
420 /// is not needed. If provided, the canonical mask for this operation is
421 /// permuted using `maybeMaskingMap`.
422 Operation *
423 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
424  LinalgOp linalgOp,
425  std::optional<AffineMap> maybeMaskingMap) {
426  LDBG("Trying to mask: " << *opToMask << "\n");
427 
428  // Create or retrieve mask for this operation.
429  Value mask =
430  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
431 
432  if (!mask) {
433  LDBG("No mask required\n");
434  return opToMask;
435  }
436 
437  // Wrap the operation with a new `vector.mask` and update D-U chain.
438  assert(opToMask && "Expected a valid operation to mask");
439  auto maskOp = cast<vector::MaskOp>(
440  mlir::vector::maskOperation(rewriter, opToMask, mask));
441  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
442 
443  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
444  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
445  maskOpTerminator);
446 
447  LDBG("Masked operation: " << *maskOp << "\n");
448  return maskOp;
449 }
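// Illustrative example (hypothetical types, not from the original source):
// masking a transfer read wraps it as
//   %0 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %pad
//          : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
// and all uses of the original result, except the mask region terminator,
// are redirected to %0.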
450 
451 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
452 /// projectedPermutation, compress the unused dimensions to serve as a
453 /// permutation_map for a vector transfer operation.
454 /// For example, given a linalg op such as:
455 ///
456 /// ```
457 /// %0 = linalg.generic {
458 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
459 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
460 /// }
461 /// ins(%0 : tensor<2x3x4xf32>)
462 /// outs(%1 : tensor<5x6xf32>)
463 /// ```
464 ///
465 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
466 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
467 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
468 static AffineMap reindexIndexingMap(AffineMap map) {
469  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
470  "expected projected permutation");
471  auto res = compressUnusedDims(map);
472  assert(res.getNumDims() == res.getNumResults() &&
473  "expected reindexed map with same number of dims and results");
474  return res;
475 }
476 
477 /// Helper enum to represent conv1d input traversal order.
478 enum class Conv1DOpOrder {
479  W, // Corresponds to non-channeled 1D convolution operation.
480  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
481  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
482 };
483 
484 /// Helper data structure to represent the result of vectorization.
485 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
486 enum VectorizationStatus {
487   /// Op failed to vectorize.
488   Failure = 0,
489   /// Op vectorized and custom function took care of replacement logic
490   NoReplace,
491   /// Op vectorized into a new Op whose results will replace original Op's
492   /// results.
493   NewOp
494   // TODO: support values if Op vectorized to Many-Ops whose results we need to
495   // aggregate for replacement.
496 };
497 struct VectorizationResult {
498   /// Return status from vectorizing the current op.
499   enum VectorizationStatus status = VectorizationStatus::Failure;
500   /// New vectorized operation to replace the current op.
501   /// Replacement behavior is specified by `status`.
502   Operation *newOp;
503 };
504 
505 std::optional<vector::CombiningKind>
506 getCombinerOpKind(Operation *combinerOp) {
507  using ::mlir::vector::CombiningKind;
508 
509  if (!combinerOp)
510  return std::nullopt;
511   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
512  .Case<arith::AddIOp, arith::AddFOp>(
513  [&](auto op) { return CombiningKind::ADD; })
514  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
515  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
516  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
517  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
518  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
519  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
520  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
521  .Case<arith::MulIOp, arith::MulFOp>(
522  [&](auto op) { return CombiningKind::MUL; })
523  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
524  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
525  .Default([&](auto op) { return std::nullopt; });
526 }
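// Illustrative example (not from the original source): in a reduction body
// such as
//   %sum = arith.addf %in, %acc : f32
//   linalg.yield %sum : f32
// the combiner operation is the arith.addf, which getCombinerOpKind maps to
// vector::CombiningKind::ADD.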
527 
528 /// Check whether `outputOperand` is a reduction with a single combiner
529 /// operation. Return the combiner operation of the reduction. Return
530 /// nullptr otherwise. Multiple reduction operations would impose an
531 /// ordering between reduction dimensions and are currently unsupported in
532 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
533 /// max(min(X)).
534 // TODO: use in LinalgOp verification, there is a circular dependency atm.
535 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
536  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
537  unsigned outputPos =
538  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
539  // Only single combiner operations are supported for now.
540  SmallVector<Operation *, 4> combinerOps;
541  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
542  combinerOps.size() != 1)
543  return nullptr;
544 
545  // Return the combiner operation.
546  return combinerOps[0];
547 }
548 
549 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
550 /// otherwise.
551 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
552  auto dstVecType = dyn_cast<VectorType>(dstType);
553  // If no shape to broadcast to, just return `value`.
554  if (dstVecType.getRank() == 0)
555  return value;
556  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
557       vector::BroadcastableToResult::Success)
558  return value;
559  Location loc = b.getInsertionPoint()->getLoc();
560  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
561 }
562 
563 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
564 /// assumes that `reductionOp` has two operands and one of them is the reduction
565 /// initial value.
566 // Note: this is a true builder that notifies the OpBuilder listener.
567 // TODO: Consider moving as a static helper on the ReduceOp.
568 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
569  Value valueToReduce, Value acc,
570  ArrayRef<bool> dimsToMask) {
571  auto maybeKind = getCombinerOpKind(reduceOp);
572  assert(maybeKind && "Failed precondition: could not get reduction kind");
573  return b.create<vector::MultiDimReductionOp>(
574  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
575 }
576 
577 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
578  return llvm::to_vector(
579  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
580 }
581 
582 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
583 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
584 /// currently being vectorized. If `dest` has null rank, build a memref.store.
585 /// Return the produced value or null if no value is produced.
586 // Note: this is a true builder that notifies the OpBuilder listener.
587 // TODO: Consider moving as a static helper on the ReduceOp.
588 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
589  OpOperand *outputOperand,
590  VectorizationState &state) {
591  Location loc = value.getLoc();
592  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
593  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
594 
595  // Compute the vector type of the value to store. This type should be an
596  // identity or projection of the canonical vector type without any permutation
597  // applied, given that any permutation in a transfer write happens as part of
598  // the write itself.
599   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
600  opOperandMap.getContext(), opOperandMap.getNumInputs(),
601  [&](AffineDimExpr dimExpr) -> bool {
602  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
603  });
604  auto vectorType = state.getCanonicalVecType(
605  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
606 
607  Operation *write;
608  if (vectorType.getRank() > 0) {
609  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
610  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
611  rewriter.create<arith::ConstantIndexOp>(loc, 0));
612  value = broadcastIfNeeded(rewriter, value, vectorType);
613  assert(value.getType() == vectorType && "Incorrect type");
614  write = rewriter.create<vector::TransferWriteOp>(
615  loc, value, outputOperand->get(), indices, writeMap);
616  } else {
617  // 0-d case is still special: do not invert the reindexing writeMap.
618  if (!isa<VectorType>(value.getType()))
619  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
620  assert(value.getType() == vectorType && "Incorrect type");
621  write = rewriter.create<vector::TransferWriteOp>(
622  loc, value, outputOperand->get(), ValueRange{});
623  }
624 
625  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
626 
627  // If masked, set in-bounds to true. Masking guarantees that the access will
628  // be in-bounds.
629  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
630  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
631  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
632  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
633  }
634 
635  LDBG("vectorized op: " << *write << "\n");
636  if (!write->getResults().empty())
637  return write->getResult(0);
638  return Value();
639 }
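// Illustrative example (hypothetical shapes, not from the original source):
// writing a vector<4x8xf32> value into a tensor<4x8xf32> init operand yields
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//          : vector<4x8xf32>, tensor<4x8xf32>
// wrapped in a vector.mask (with in_bounds set to true) when the iteration
// space is dynamic.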
640 
641 // Custom vectorization precondition function type. This is intended to be used
642 // with CustomVectorizationHook. Returns success if the corresponding custom
643 // hook can vectorize the op.
644 using CustomVectorizationPrecondition =
645  std::function<LogicalResult(Operation *, bool)>;
646 
647 // Custom vectorization function type. Produce a vector form of Operation*
648 // assuming all its vectorized operands are already in the IRMapping.
649 // Return nullptr if the Operation cannot be vectorized.
650 using CustomVectorizationHook =
651  std::function<VectorizationResult(Operation *, const IRMapping &)>;
652 
653 /// Helper function to vectorize the terminator of a `linalgOp`. New result
654 /// vector values are appended to `newResults`. Return
655 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
656 /// should not try to map produced operations and instead return the results
657 /// using the `newResults` vector making them available to the vectorization
658 /// algorithm for RAUW. This function is meant to be used as a
659 /// CustomVectorizationHook.
660 static VectorizationResult
661 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
662  const IRMapping &bvm, VectorizationState &state,
663  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
664  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
665  if (!yieldOp)
666     return VectorizationResult{VectorizationStatus::Failure, nullptr};
667  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
668  // TODO: Scan for an opportunity for reuse.
669  // TODO: use a map.
670  Value vectorValue = bvm.lookup(output.value());
671  Value newResult =
672  buildVectorWrite(rewriter, vectorValue,
673  linalgOp.getDpsInitOperand(output.index()), state);
674  if (newResult)
675  newResults.push_back(newResult);
676  }
677 
678   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
679 }
680 
681 /// Helper function to vectorize the index operations of a `linalgOp`. Return
682 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
683 /// should map the produced operations. This function is meant to be used as a
684 /// CustomVectorizationHook.
685 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
686  VectorizationState &state,
687  Operation *op,
688  LinalgOp linalgOp) {
689  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
690  if (!indexOp)
691     return VectorizationResult{VectorizationStatus::Failure, nullptr};
692  auto loc = indexOp.getLoc();
693  // Compute the static loop sizes of the index op.
694  auto targetShape = state.getCanonicalVecShape();
695  // Compute a one-dimensional index vector for the index op dimension.
696  auto constantSeq =
697  llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
698  auto indexSteps = rewriter.create<arith::ConstantOp>(
699  loc, rewriter.getIndexVectorAttr(constantSeq));
700  // Return the one-dimensional index vector if it lives in the trailing
701  // dimension of the iteration space since the vectorization algorithm in this
702  // case can handle the broadcast.
703  if (indexOp.getDim() == targetShape.size() - 1)
704     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
705  // Otherwise permute the targetShape to move the index dimension last,
706  // broadcast the one-dimensional index vector to the permuted shape, and
707  // finally transpose the broadcasted index vector to undo the permutation.
708  auto permPattern =
709  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
710  std::swap(permPattern[indexOp.getDim()], permPattern.back());
711  auto permMap =
712  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
713 
714  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
715  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
716  indexSteps);
717  SmallVector<int64_t> transposition =
718  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
719  std::swap(transposition.back(), transposition[indexOp.getDim()]);
720  auto transposeOp =
721  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
722  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
723 }
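// Illustrative example (hypothetical sizes, not from the original source):
// for a 4x8 iteration space, `linalg.index 1` (the trailing dim) becomes
//   %idx = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// and is returned directly, while `linalg.index 0` is additionally broadcast
// to vector<8x4xindex> and transposed back to vector<4x8xindex> so that the
// index varies along dim 0.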
724 
725 /// Helper function to check if the tensor.extract can be vectorized by the
726 /// custom hook vectorizeTensorExtract.
727 static LogicalResult
728 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
729  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
730  if (!extractOp)
731  return failure();
732 
733  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
734  return failure();
735 
736  // Check the index type, but only for non 0-d tensors (for which we do need
737  // access indices).
738   if (!extractOp.getIndices().empty()) {
739  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
740  return failure();
741  }
742 
743  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
744  return !VectorType::isValidElementType(type);
745  })) {
746  return failure();
747  }
748 
749  return success();
750 }
751 
752 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
753 /// generated from `tensor.extract`. The offset is calculated as follows
754 /// (example using scalar values):
755 ///
756 /// offset = extractOp.indices[0]
757 /// for (i = 1; i < numIndices; i++)
758 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
759 ///
760 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
761 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
762 static Value calculateGatherOffset(RewriterBase &rewriter,
763  VectorizationState &state,
764  tensor::ExtractOp extractOp,
765  const IRMapping &bvm) {
766  // The vector of indices for GatherOp should be shaped as the output vector.
767  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
768  auto loc = extractOp.getLoc();
769 
770  Value offset = broadcastIfNeeded(
771  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
772 
773  const size_t numIndices = extractOp.getIndices().size();
774  for (size_t i = 1; i < numIndices; i++) {
775  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
776 
777  auto dimSize = broadcastIfNeeded(
778  rewriter,
779  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
780  indexVecType);
781 
782  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
783 
784  auto extractOpIndex = broadcastIfNeeded(
785  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
786 
787  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
788  }
789 
790  return offset;
791 }
792 
793 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
794 
795 /// Checks whether \p val can be used for calculating a loop invariant index.
796 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
797 
798  auto targetShape = linalgOp.getStaticLoopRanges();
799  assert(((llvm::count_if(targetShape,
800  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
801  "n-D vectors are not yet supported");
802  assert(targetShape.back() != 1 &&
803          "1-D vectors with the trailing dim equal 1 are not yet supported");
804 
805  // Blocks outside _this_ linalg.generic are effectively loop invariant.
806  // However, analysing block arguments for _this_ linalg.generic Op is a bit
807  // tricky. Just bail out in the latter case.
808  // TODO: We could try analysing the corresponding affine map here.
809  auto *block = linalgOp.getBlock();
810  if (isa<BlockArgument>(val))
811  return llvm::all_of(block->getArguments(),
812  [&val](Value v) { return (v != val); });
813 
814  Operation *defOp = val.getDefiningOp();
815  assert(defOp && "This is neither a block argument nor an operation result");
816 
817  // IndexOp is loop invariant as long as its result remains constant across
818  // iterations. Given the assumptions on the loop ranges above, only the
819  // trailing loop dim ever changes.
820  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
821  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
822  return (indexOp.getDim() != trailingLoopDim);
823 
824  auto *ancestor = block->findAncestorOpInBlock(*defOp);
825 
826   // Values defined outside `linalgOp` are loop invariant.
827  if (!ancestor)
828  return true;
829 
830  // Values defined inside `linalgOp`, which are constant, are loop invariant.
831  if (isa<arith::ConstantOp>(ancestor))
832  return true;
833 
834  bool result = true;
835  for (auto op : ancestor->getOperands())
836  result &= isLoopInvariantIdx(linalgOp, op);
837 
838  return result;
839 }
840 
841 /// Check whether \p val could be used for calculating the trailing index for a
842 /// contiguous load operation.
843 ///
844 /// There are currently 3 types of values that are allowed here:
845 /// 1. loop-invariant values,
846 /// 2. values that increment by 1 with every loop iteration,
847 /// 3. results of basic arithmetic operations (linear and continuous)
848 /// involving 1., 2. and 3.
849 /// This method returns True if indeed only such values are used in calculating
850 /// \p val.
851 ///
852 /// Additionally, the trailing index for a contiguous load operation should
853 /// increment by 1 with every loop iteration, i.e. be based on:
854 /// * `linalg.index <dim>` ,
855 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
856 /// updated to `true` when such an op is found.
857 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
858  bool &foundIndexOp) {
859 
860  auto targetShape = linalgOp.getStaticLoopRanges();
861  assert(((llvm::count_if(targetShape,
862  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
863  "n-D vectors are not yet supported");
864  assert(targetShape.back() != 1 &&
865  "1-D vectors with the trailing dim 1 are not yet supported");
866 
867  // Blocks outside _this_ linalg.generic are effectively loop invariant.
868  // However, analysing block arguments for _this_ linalg.generic Op is a bit
869  // tricky. Just bail out in the latter case.
870  // TODO: We could try analysing the corresponding affine map here.
871  auto *block = linalgOp.getBlock();
872  if (isa<BlockArgument>(val))
873  return llvm::all_of(block->getArguments(),
874  [&val](Value v) { return (v != val); });
875 
876  Operation *defOp = val.getDefiningOp();
877  assert(defOp && "This is neither a block argument nor an operation result");
878 
879  // Given the assumption on the loop ranges above, only the trailing loop
880  // index is not constant.
881  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
882  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
883  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
884  return true;
885  }
886 
887  auto *ancestor = block->findAncestorOpInBlock(*defOp);
888 
889  if (!ancestor)
890  return false;
891 
892  // Conservatively reject Ops that could lead to indices with stride other
893  // than 1.
894  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
895  ancestor))
896  return false;
897 
898  bool result = false;
899  for (auto op : ancestor->getOperands())
900  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
901 
902  return result;
903 }
904 
905 /// Infer the memory access pattern for the input ExtractOp
906 ///
907 /// Based on the operation shapes and indices (usually based on the iteration
908 /// space of the parent `linalgOp` operation), decides whether the input
909 /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
910 /// gather load.
911 ///
912 /// Note that it is always safe to use gather load operations for contiguous
913 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
914 /// that `extractOp` is a gather load.
915 static VectorMemoryAccessKind
916 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
917  LinalgOp &linalgOp) {
918 
919  auto targetShape = linalgOp.getStaticLoopRanges();
920  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
921 
922  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
923  if (inputShape.getShape().empty())
924     return VectorMemoryAccessKind::ScalarBroadcast;
925 
926  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
927  // gather load.
928  // TODO: Relax this condition.
929  if (linalgOp.hasDynamicShape())
930     return VectorMemoryAccessKind::Gather;
931 
932  // 1. Assume that it's a gather load when reading _into_:
933   //   * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
934   //   * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
935  // TODO: Relax these conditions.
936  // FIXME: This condition assumes non-dynamic sizes.
937  if ((llvm::count_if(targetShape,
938  [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
939  targetShape.back() == 1)
940     return VectorMemoryAccessKind::Gather;
941 
942  // 2. Assume that it's a gather load when reading _from_ a tensor for which
943  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
944  // TODO: Relax this condition.
945  if (inputShape.getShape().back() == 1)
946     return VectorMemoryAccessKind::Gather;
947 
948  bool leadingIdxsLoopInvariant = true;
949 
950  // 3. Analyze the leading indices of `extractOp`.
951  // Look at the way each index is calculated and decide whether it is suitable
952  // for a contiguous load, i.e. whether it's loop invariant.
953  auto indices = extractOp.getIndices();
954  auto leadIndices = indices.drop_back(1);
955 
956  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
957  if (inputShape.getShape()[i] == 1)
958  continue;
959 
960  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
961  }
962 
963  if (!leadingIdxsLoopInvariant) {
964  LDBG("Found gather load: " << extractOp);
965     return VectorMemoryAccessKind::Gather;
966  }
967 
968  // 4. Analyze the trailing index for `extractOp`.
969  // At this point we know that the leading indices are loop invariant. This
970   // means that this is potentially a scalar or a contiguous load. We can decide
971  // based on the trailing idx.
972  auto extractOpTrailingIdx = indices.back();
973 
974  // 4a. Scalar broadcast load
975  // If the trailing index is loop invariant then this is a scalar load.
976  if (leadingIdxsLoopInvariant &&
977  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
978  LDBG("Found scalar broadcast load: " << extractOp);
979 
980     return VectorMemoryAccessKind::ScalarBroadcast;
981  }
982 
983  // 4b. Contiguous loads
984  // The trailing `extractOp` index should increment with every loop iteration.
985  // This effectively means that it must be based on the trailing loop index.
986  // This is what the following bool captures.
987  bool foundIndexOp = false;
988  bool isContiguousLoad =
989  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
990  isContiguousLoad &= foundIndexOp;
991 
992  if (isContiguousLoad) {
993     LDBG("Found contiguous load: " << extractOp);
994     return VectorMemoryAccessKind::Contiguous;
995   }
996 
997  // 5. Fallback case - gather load.
998  LDBG("Found gather load: " << extractOp);
999   return VectorMemoryAccessKind::Gather;
1000 }
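// Illustrative example (hypothetical shapes, not from the original source):
// inside a linalg.generic with a 1x4 iteration space, an access such as
//   %v = tensor.extract %src[%c0, %i] : tensor<3x8xf32>
// where %i is `linalg.index 1` is classified as Contiguous, whereas an access
// whose trailing index is itself loaded from another tensor is classified as
// Gather.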
1001 
1002 /// Helper function to vectorize the tensor.extract operations. Returns
1003 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1004 /// should map the produced operations. This function is meant to be used as a
1005 /// CustomVectorizationHook.
1006 static VectorizationResult
1007 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1008  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1009  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1010  if (!extractOp)
1011     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1012  auto loc = extractOp.getLoc();
1013 
1014  // Compute the static loop sizes of the extract op.
1015  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1016  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1017  loc,
1018  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1019  /*value=*/true));
1020  auto passThruConstantOp =
1021  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1022 
1023  // Base indices are currently set to 0. We will need to re-visit if more
1024  // generic scenarios are to be supported.
1025  SmallVector<Value> baseIndices(
1026  extractOp.getIndices().size(),
1027  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1028 
1029  VectorMemoryAccessKind memAccessKind =
1030  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1031 
1032  // 1. Handle gather access
1033  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1034  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1035 
1036  // Generate the gather load
1037  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1038  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1039  maskConstantOp, passThruConstantOp);
1040  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1041 
1042  LDBG("Vectorised as gather load: " << extractOp << "\n");
1043     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1044  }
1045 
1046  // 2. Handle:
1047  // a. scalar loads + broadcast,
1048  // b. contiguous loads.
1049  // Both cases use vector.transfer_read.
1050 
1051  // Collect indices for `vector.transfer_read`. At this point, the indices will
1052  // either be scalars or would have been broadcast to vectors matching the
1053  // result type. For indices that are vectors, there are two options:
1054  // * for non-trailing indices, all elements are identical (contiguous
1055  // loads are identified by looking for non-trailing indices that are
1056  // invariant with respect to the corresponding linalg.generic), or
1057  // * for trailing indices, the index vector will contain values with stride
1058  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1059  // needed.
1060  // This means that
1061  // * for scalar indices - just re-use it,
1062  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1063  // (0th) element and use that.
1064  SmallVector<Value> transferReadIdxs;
1065  auto resTrailingDim = resultType.getShape().back();
1066  auto zero = rewriter.create<arith::ConstantOp>(
1067  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1068  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1069  auto idx = bvm.lookup(extractOp.getIndices()[i]);
1070  if (idx.getType().isIndex()) {
1071  transferReadIdxs.push_back(idx);
1072  continue;
1073  }
1074 
1075  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1076  loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
1077  bvm.lookup(extractOp.getIndices()[i]));
1078  transferReadIdxs.push_back(
1079  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1080  }
1081 
1082   // `tensor.extract` is always in-bounds, hence the following holds.
1083  auto dstRank = resultType.getRank();
1084  auto srcRank = extractOp.getTensor().getType().getRank();
1085  SmallVector<bool> inBounds(dstRank, true);
1086 
1087  // 2a. Handle scalar broadcast access.
1088  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1089  MLIRContext *ctx = rewriter.getContext();
1090  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1091  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1092 
1093  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1094  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1095  permutationMap, inBounds);
1096 
1097  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1098  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1099  }
1100 
1101  // 2b. Handle contiguous access.
1102  auto permutationMap = AffineMap::getMinorIdentityMap(
1103  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1104 
1105  int32_t rankDiff = dstRank - srcRank;
1106  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1107  // dims so that the ranks match. This is done by extending the map with 0s.
1108  // For example, for dstRank = 3, srcRank = 2, the following map created
1109  // above:
1110  // (d0, d1) --> (d0, d1)
1111  // is extended as:
1112  // (d0, d1) --> (0, d0, d1)
1113  while (rankDiff > 0) {
1114  permutationMap = permutationMap.insertResult(
1115  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1116  rankDiff--;
1117  }
1118 
1119  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1120  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1121  inBounds);
1122 
1123  LDBG("Vectorised as contiguous load: " << extractOp);
1124  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1125 }
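// Illustrative example (hypothetical types, not from the original source):
// a gather access is emitted as
//   %g = vector.gather %src[%c0, %c0] [%offsets], %mask, %pass_thru
//          : tensor<3x8xf32>, vector<1x4xindex>, vector<1x4xi1>,
//            vector<1x4xf32> into vector<1x4xf32>
// while scalar-broadcast and contiguous accesses become vector.transfer_read
// ops with a broadcasting or minor-identity permutation map, respectively.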
1126 
1127 /// Emit reduction operations if the shape of the value to reduce differs from
1128 /// the result shape.
1129 // Note: this is a true builder that notifies the OpBuilder listener.
1130 // TODO: Consider moving as a static helper on the ReduceOp.
1131 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1132  Value reduceValue, Value initialValue,
1133  const IRMapping &bvm) {
1134  Value reduceVec = bvm.lookup(reduceValue);
1135  Value outputVec = bvm.lookup(initialValue);
1136  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1137  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1138   // Reduce only if needed as the value may already have been reduced for
1139  // contraction vectorization.
1140  if (!reduceType ||
1141  (outputType && reduceType.getShape() == outputType.getShape()))
1142  return nullptr;
1143  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1144  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1145 }
1146 
1147 /// Generic vectorization for a single operation `op`, given already vectorized
1148 /// operands carried by `bvm`. Vectorization occurs as follows:
1149 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1150 /// result on success.
1151 /// 2. Clone any constant in the current scope without vectorization: each
1152 /// consumer of the constant will later determine the shape to which the
1153 /// constant needs to be broadcast to.
1154 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1155 /// of the `customVectorizationHooks` to cover such cases.
1156 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1157 /// operand of maximal rank. Other operands have smaller rank and are
1158 /// broadcast accordingly. It is assumed this broadcast is always legal,
1159 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1160 ///
1161 /// This function assumes all operands of `op` have been vectorized and are in
1162 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1163 /// a topologically-sorted list of ops.
1164 /// This function does not update `bvm` but returns a VectorizationStatus that
1165 /// instructs the caller what `bvm` update needs to occur.
1166 static VectorizationResult
1167 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1168  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1169  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1170  LDBG("vectorize op " << *op << "\n");
1171 
1172  // 1. Try to apply any CustomVectorizationHook.
1173  if (!customVectorizationHooks.empty()) {
1174  for (auto &customFunc : customVectorizationHooks) {
1175  VectorizationResult result = customFunc(op, bvm);
1176  if (result.status == VectorizationStatus::Failure)
1177  continue;
1178  return result;
1179  }
1180  }
1181 
1182  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1183   // Clone so that the constant is not confined to the linalgOp block.
1184  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1185  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1186 
1187  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1188   if (!OpTrait::hasElementwiseMappableTraits(op))
1189     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1190 
1191   // 4. Check if the operation is a reduction.
1192  SmallVector<std::pair<Value, Value>> reductionOperands;
1193  for (Value operand : op->getOperands()) {
1194  auto blockArg = dyn_cast<BlockArgument>(operand);
1195  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1196  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1197  continue;
1198  SmallVector<Operation *> reductionOps;
1199  Value reduceValue = matchReduction(
1200  linalgOp.getRegionOutputArgs(),
1201  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1202  if (!reduceValue)
1203  continue;
1204  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1205  }
1206  if (!reductionOperands.empty()) {
1207  assert(reductionOperands.size() == 1);
1208  Operation *reduceOp =
1209  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1210  reductionOperands[0].second, bvm);
1211  if (reduceOp)
1212       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1213  }
1214 
1215  // 5. Generic vectorization path for ElementwiseMappable ops.
1216  // a. Get the first max ranked shape.
1217  VectorType firstMaxRankedType;
1218  for (Value operand : op->getOperands()) {
1219  auto vecOperand = bvm.lookup(operand);
1220  assert(vecOperand && "Vector operand couldn't be found");
1221 
1222  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1223  if (vecType && (!firstMaxRankedType ||
1224  firstMaxRankedType.getRank() < vecType.getRank()))
1225  firstMaxRankedType = vecType;
1226  }
1227  // b. Broadcast each op if needed.
1228  SmallVector<Value> vecOperands;
1229  for (Value scalarOperand : op->getOperands()) {
1230  Value vecOperand = bvm.lookup(scalarOperand);
1231  assert(vecOperand && "Vector operand couldn't be found");
1232 
1233  if (firstMaxRankedType) {
1234  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1235  getElementTypeOrSelf(vecOperand.getType()),
1236  firstMaxRankedType.getScalableDims());
1237  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1238  } else {
1239  vecOperands.push_back(vecOperand);
1240  }
1241  }
1242   // c. For elementwise ops, the result is a vector with the shape of firstMaxRankedType.
1243  SmallVector<Type> resultTypes;
1244  for (Type resultType : op->getResultTypes()) {
1245  resultTypes.push_back(
1246  firstMaxRankedType
1247  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1248  firstMaxRankedType.getScalableDims())
1249  : resultType);
1250  }
1251  // d. Build and return the new op.
1252  return VectorizationResult{
1253       VectorizationStatus::NewOp,
1254  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1255  resultTypes, op->getAttrs())};
1256 }
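// Illustrative example (not from the original source): with a canonical
// vector shape of 4x8, a scalar body op such as
//   %0 = arith.addf %a, %b : f32
// is recreated on the vectorized operands, preserving its attributes, as
//   %0 = arith.addf %va, %vb : vector<4x8xf32>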
1257 
1258 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1259 /// vector form. Generic vectorization proceeds as follows:
1260 /// 1. Verify the `linalgOp` has one non-empty region.
1261 /// 2. Values defined above the region are mapped to themselves and will be
1262 /// broadcasted on a per-need basis by their consumers.
1263 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1264 /// load).
1265 /// TODO: Reuse opportunities for RAR dependencies.
1266 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1267 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1268 /// iteration indices.
1269 /// 5. Iteratively call vectorizeOneOp on the region operations.
1270 ///
1271 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1272 /// performed to the maximal common vector size implied by the `linalgOp`
1273 /// iteration space. This eager broadcasting is introduced in the
1274 /// permutation_map of the vector.transfer_read operations. The eager
1275 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1276 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1277 /// the absence of good canonicalizations, the amount of work increases.
1278 /// This is not deemed a problem as we expect canonicalizations and foldings to
1279 /// aggressively clean up the useless work.
1280 static LogicalResult
1281 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1282  LinalgOp linalgOp,
1283  SmallVectorImpl<Value> &newResults) {
1284  LDBG("Vectorizing operation as linalg generic\n");
1285  Block *block = linalgOp.getBlock();
1286 
1287  // 2. Values defined above the region can only be broadcast for now. Make them
1288  // map to themselves.
1289  IRMapping bvm;
1290  SetVector<Value> valuesSet;
1291  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1292  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1293 
1294  if (linalgOp.getNumDpsInits() == 0)
1295  return failure();
1296 
1297  // 3. Turn all BBArgs into vector.transfer_read / load.
1298  Location loc = linalgOp.getLoc();
1299  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1300  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1301  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1302  if (linalgOp.isScalar(opOperand)) {
1303  bvm.map(bbarg, opOperand->get());
1304  continue;
1305  }
1306 
1307  // 3.a. Convert the indexing map for this input/output to a transfer read
1308  // permutation map and masking map.
1309  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1310 
1311  // Remove zeros from indexing map to use it as masking map.
1312  SmallVector<int64_t> zeroPos;
1313  auto results = indexingMap.getResults();
1314  for (const auto &result : llvm::enumerate(results)) {
1315  if (isa<AffineConstantExpr>(result.value())) {
1316  zeroPos.push_back(result.index());
1317  }
1318  }
1319  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1320 
1321  AffineMap readMap;
1322  VectorType readType;
1323  Type elemType = getElementTypeOrSelf(opOperand->get());
1324  if (linalgOp.isDpsInput(opOperand)) {
1325  // 3.a.i. For input reads we use the canonical vector shape.
1326  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1327  readType = state.getCanonicalVecType(elemType);
1328  } else {
1329  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1330  // reductions), the vector shape is computed by mapping the canonical
1331  // vector shape to the output domain and back to the canonical domain.
1332  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1333  readType =
1334  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1335  }
1336 
1337  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1338 
1339  Operation *read = rewriter.create<vector::TransferReadOp>(
1340  loc, readType, opOperand->get(), indices, readMap);
1341  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1342  Value readValue = read->getResult(0);
1343 
1344  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1345  // will be in-bounds.
1346  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1347  SmallVector<bool> inBounds(readType.getRank(), true);
1348  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1349  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1350  }
1351 
1352  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1353  // TODO: remove this.
1354  if (readType.getRank() == 0)
1355  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1356 
1357  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1358  << "\n");
1359  bvm.map(bbarg, readValue);
1360  bvm.map(opOperand->get(), readValue);
1361  }
1362 
1363   SmallVector<CustomVectorizationHook> hooks;
1364  // 4a. Register CustomVectorizationHook for yieldOp.
1365  CustomVectorizationHook vectorizeYield =
1366  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1367  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1368  };
1369  hooks.push_back(vectorizeYield);
1370 
1371  // 4b. Register CustomVectorizationHook for indexOp.
1372  CustomVectorizationHook vectorizeIndex =
1373  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1374  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1375  };
1376  hooks.push_back(vectorizeIndex);
1377 
1378  // 4c. Register CustomVectorizationHook for extractOp.
1379  CustomVectorizationHook vectorizeExtract =
1380  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1381  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1382  };
1383  hooks.push_back(vectorizeExtract);
1384 
1385  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1386  for (Operation &op : block->getOperations()) {
1387  VectorizationResult result =
1388  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1389  if (result.status == VectorizationStatus::Failure) {
1390  LDBG("failed to vectorize: " << op << "\n");
1391  return failure();
1392  }
1393  if (result.status == VectorizationStatus::NewOp) {
1394  Operation *maybeMaskedOp =
1395  state.maskOperation(rewriter, result.newOp, linalgOp);
1396  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1397  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1398  }
1399  }
1400 
1401  return success();
1402 }
1403 
1404 /// Given a tensor::PackOp, return the `dest` shape before any packing
1405 /// permutations.
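///
/// For example (illustrative, reusing the pack example documented in
/// vectorizeAsTensorPackOp below): with inner_dims_pos = [2, 1],
/// inner_tiles = [16, 2] and the permuted dest shape {32, 4, 1, 16, 2},
/// the un-permuted tiled shape returned here is {32, 4, 2, 1, 16}.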
1406 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1407  ArrayRef<int64_t> destShape) {
1408  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1409 }
1410 
1411 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
1412 /// vector type for the read is not the same as the type of `source`, then a
1413 /// mask is created on the read.
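///
/// A rough sketch of the masked case (illustrative only; %dim0, %dim1, %c0
/// and %padValue are stand-ins, not names produced by this helper):
///
///   %mask = vector.create_mask %dim0, %dim1 : vector<4x16xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %source[%c0, %c0], %padValue
///       {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x16xf32>
///   } : vector<4x16xi1> -> vector<4x16xf32>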
 1414 static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
 1415  Value source, ArrayRef<int64_t> readShape,
1416  Value padValue) {
1417  assert(llvm::none_of(readShape,
1418  [](int64_t s) { return s == ShapedType::kDynamic; }));
1419  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
1420  assert(sourceShape.size() == readShape.size());
1421  auto maskType = VectorType::get(readShape, builder.getI1Type());
1422  auto vectorType = VectorType::get(readShape, padValue.getType());
1423  int64_t readRank = readShape.size();
1424  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1425  auto transferReadOp = builder.create<vector::TransferReadOp>(
1426  loc,
1427  /*vectorType=*/vectorType,
1428  /*source=*/source,
1429  /*indices=*/SmallVector<Value>(readRank, zero),
1430  /*padding=*/padValue,
1431  /*inBounds=*/SmallVector<bool>(readRank, true));
1432  if (llvm::equal(readShape, sourceShape)) {
1433  return transferReadOp;
1434  }
1435  SmallVector<OpFoldResult> mixedSourceDims =
1436  tensor::getMixedSizes(builder, loc, source);
1437  Value mask =
1438  builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1439  return mlir::vector::maskOperation(builder, transferReadOp, mask)
1440  ->getResult(0);
1441 }
1442 
1443 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1444 /// create an empty destination tensor and create a TransferWriteOp from the
1445 /// input to the empty tensor. If the destination shape is not the same as the
1446 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1447 /// mask for the write.
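///
/// A rough sketch of the masked case (illustrative only; %d0, %c16 and %c0
/// are stand-ins for the destination sizes and index constants):
///
///   %empty = tensor.empty(%d0) : tensor<?x16xf32>
///   %wmask = vector.create_mask %d0, %c16 : vector<8x16xi1>
///   %write = vector.mask %wmask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<?x16xf32>
///   } : vector<8x16xi1> -> tensor<?x16xf32>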
 1448 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 1449  Value input,
1450  SmallVector<OpFoldResult> destSizes,
1451  ArrayRef<int64_t> inputVectorSizes) {
1452  auto inputType = cast<VectorType>(input.getType());
1453  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1454  inputType.getElementType());
1455  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1456  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1457  Operation *write = builder.create<vector::TransferWriteOp>(
1458  loc,
1459  /*vector=*/input,
1460  /*source=*/dest,
1461  /*indices=*/SmallVector<Value>(rank, zero),
1462  /*inBounds=*/SmallVector<bool>(rank, true));
1463  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1464  assert(llvm::none_of(
1465  destShape.drop_front(inputVectorSizes.size()),
1466  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1467  "Only dims aligned with inputVectorSizes may be dynamic");
1468  bool needMaskForWrite = !llvm::equal(
1469  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1470  if (needMaskForWrite) {
1471  SmallVector<int64_t> writeMaskShape;
1472  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1473  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1474  destShape.end());
1475  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1476  Value maskForWrite =
1477  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1478  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1479  }
1480  return write;
1481 }
1482 
1483 /// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
1484 /// padding value into:
1485 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1486 /// As in the following example:
1487 ///
1488 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 1489 /// into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1490 ///
1491 /// This pack would be vectorized to:
1492 ///
1493 /// %load = vector.mask %mask {
1494 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1495 /// {in_bounds = [true, true, true]} :
1496 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1497 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1498 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1499 /// to vector<32x4x2x1x16xf32>
1500 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1501 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1502 /// %write = vector.transfer_write %transpose,
1503 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1504 /// {in_bounds = [true, true, true, true, true]}
1505 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1506 static LogicalResult
1507 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1508  ArrayRef<int64_t> inputVectorSizes,
1509  SmallVectorImpl<Value> &newResults) {
1510  OpBuilder::InsertionGuard g(rewriter);
1511  rewriter.setInsertionPoint(packOp);
1512 
1513  Location loc = packOp.getLoc();
1514  auto padValue = packOp.getPaddingValue();
1515  if (!padValue) {
1516  padValue = rewriter.create<arith::ConstantOp>(
1517  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1518  }
1519  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1520  LogicalResult status =
1521  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1522  .reifyResultShapes(rewriter, reifiedReturnShapes);
1523  (void)status; // prevent unused variable warning on non-assert builds.
1524  assert(succeeded(status) && "failed to reify result shapes");
1525 
1526  // Create masked TransferReadOp.
1527  SmallVector<int64_t> inputShape(inputVectorSizes);
1528  auto innerTiles = packOp.getStaticInnerTiles();
1529  auto innerDimsPos = packOp.getInnerDimsPos();
1530  auto outerDimsPerm = packOp.getOuterDimsPerm();
1531  if (!outerDimsPerm.empty())
1532  applyPermutationToVector(inputShape,
1533  invertPermutationVector(outerDimsPerm));
1534  for (auto [idx, size] : enumerate(innerTiles))
1535  inputShape[innerDimsPos[idx]] *= size;
1536  auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
1537  inputShape, padValue);
1538 
1539  // Create ShapeCastOp.
1540  SmallVector<int64_t> destShape(inputVectorSizes);
1541  destShape.append(innerTiles.begin(), innerTiles.end());
1542  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1543  packOp.getDestType().getElementType());
1544  auto shapeCastOp =
1545  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1546 
1547  // Create TransposeOp.
1548  auto destPermutation =
 1549  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
 1550  auto transposeOp = rewriter.create<vector::TransposeOp>(
1551  loc, shapeCastOp.getResult(), destPermutation);
1552 
1553  // Create TransferWriteOp.
1554  Operation *write =
1555  createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
1556  reifiedReturnShapes[0], inputVectorSizes);
1557  newResults.push_back(write->getResult(0));
1558  return success();
1559 }
1560 
 1561 /// Vectorize a `tensor::UnPackOp` to these 4 ops:
 1562 ///   vector::TransferReadOp - reads a vector from the source tensor
 1563 ///   vector::TransposeOp - transposes the read vector
 1564 ///   vector::ShapeCastOp - reshapes the data based on the target shape
 1565 ///   vector::TransferWriteOp - writes the result vector back to the
 1566 ///   destination tensor
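///
/// E.g., reversing the pack example documented above (illustrative only; the
/// exact permutation comes from getUnPackInverseSrcPerm):
///
///   tensor.unpack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x4x1x16x2xf32> -> tensor<32x8x16xf32>
///
/// would roughly become:
///
///   %read = vector.transfer_read %src[...] : vector<32x4x1x16x2xf32>
///   %transp = vector.transpose %read, [0, 1, 4, 2, 3]
///     : vector<32x4x1x16x2xf32> to vector<32x4x2x1x16xf32>
///   %cast = vector.shape_cast %transp
///     : vector<32x4x2x1x16xf32> to vector<32x8x16xf32>
///   %write = vector.transfer_write %cast, %dst[...]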
1567 static LogicalResult
1568 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1569  ArrayRef<int64_t> inputVectorSizes,
1570  SmallVectorImpl<Value> &newResults) {
1571 
1572  OpBuilder::InsertionGuard g(rewriter);
1573  rewriter.setInsertionPoint(unpackOp);
1574 
1575  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1576 
1577  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1578  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1579 
1580  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
1581  inputVectorSizes.end());
1582  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1583  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1584 
 1585  // readMaskShape is the shape of the vector used to read the source
 1586  // tensor and to build the read mask. It is computed as follows. Let the
 1587  // vector sizes (VS) array have size 'N', the source shape (SS) have size
 1588  // 'M' with M >= N, and the inner tile sizes (IT) have size M-N.
 1589  // Then:
 1590  // - initially: readMaskShape = inputVectorSizes
 1591  // - divide each readMaskShape entry indexed by innerDimPos by the
 1592  //   corresponding inner tile size
 1593  // - if outer_dims_perm is present: apply that permutation to readMaskShape
 1594  // - append the remaining dims of SS
 1595  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
 1596  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
 1597  // 128] and outer_dims_perm = [1, 0]. Then the read shape is:
 1598  // ReadMaskShape (initial): [512, 128]
 1599  // After the inner-dim adjustment: [512/32, 128/16]
 1600  // = [16, 8]
 1601  // After applying outer_dims_perm: [8, 16]
 1602  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1603 
1604  for (auto [index, size] : enumerate(innerTiles)) {
1605  readMaskShape[innerDimPos[index]] =
1606  llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
1607  }
1608  if (!outerDimsPerm.empty()) {
1609  applyPermutationToVector(readMaskShape, outerDimsPerm);
1610  }
1611  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
1612  sourceShape.end());
1613 
1614  ReifiedRankedShapedTypeDims reifiedRetShapes;
1615  LogicalResult status =
1616  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1617  .reifyResultShapes(rewriter, reifiedRetShapes);
1618  if (status.failed()) {
1619  LDBG("Unable to reify result shapes of " << unpackOp);
1620  return failure();
1621  }
1622  Location loc = unpackOp->getLoc();
1623 
1624  auto padValue = rewriter.create<arith::ConstantOp>(
1625  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1626 
1627  // Read result, mask if necessary. If transferReadOp shape is not equal
1628  // to shape of source, then a mask is necessary.
1629  Value readResult = createReadOrMaskedRead(
1630  rewriter, loc, unpackOp.getSource(),
1631  ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
1632 
1633  PackingMetadata packMetadata;
1634  SmallVector<int64_t> lastDimToInsertPosPerm =
1635  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1636  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1637  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1638  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1639  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1640  RankedTensorType stripMineTensorType =
1641  RankedTensorType::get(stripMineShape, stripMineElemType);
1642  // Transpose the appropriate rows to match output.
1643  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1644  loc, readResult, lastDimToInsertPosPerm);
1645 
1646  // Collapse the vector to the size required by result.
1647  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1648  stripMineTensorType, packMetadata.reassociations);
1649  mlir::VectorType vecCollapsedType =
1650  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1651  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1652  loc, vecCollapsedType, transposeOp->getResult(0));
1653 
 1654  // For dynamic sizes, the write mask shape has to match the shape_cast
 1655  // result shape; otherwise the verifier rejects the mask size as invalid.
1656  SmallVector<int64_t> writeMaskShape(
1657  unpackOp.getDestType().hasStaticShape()
1658  ? inputVectorSizes
1659  : shapeCastOp.getResultVectorType().getShape());
1660  Operation *write =
1661  createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
1662  reifiedRetShapes[0], writeMaskShape);
1663  newResults.push_back(write->getResult(0));
1664  return success();
1665 }
1666 
1667 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1668 /// and (3) all-zero lowPad to
1669 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
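///
/// E.g. (illustrative only; %m, %cst and %c0 are stand-ins), padding a
/// tensor<?x?xf32> source up to a static tensor<2x4xf32> result becomes
/// roughly:
///
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<?x?xf32>, vector<2x4xf32>
///   } : vector<2x4xi1> -> vector<2x4xf32>
///   %empty = tensor.empty() : tensor<2x4xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///     {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>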
1670 static LogicalResult
1671 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1672  ArrayRef<int64_t> inputVectorSizes,
1673  SmallVectorImpl<Value> &newResults) {
1674  auto padValue = padOp.getConstantPaddingValue();
1675  Location loc = padOp.getLoc();
1676 
1677  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1678  OpBuilder::InsertionGuard g(rewriter);
1679  rewriter.setInsertionPoint(padOp);
1680 
1681  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1682  LogicalResult status =
1683  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1684  .reifyResultShapes(rewriter, reifiedReturnShapes);
1685  (void)status; // prevent unused variable warning on non-assert builds
1686  assert(succeeded(status) && "failed to reify result shapes");
1687  auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
1688  inputVectorSizes, padValue);
 1689  Operation *write = createWriteOrMaskedWrite(
 1690  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
1691  newResults.push_back(write->getResult(0));
1692  return success();
1693 }
1694 
1695 // TODO: probably need some extra checks for reduction followed by consumer
1696 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 1697 static LogicalResult reductionPreconditions(LinalgOp op) {
 1698  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1699  LDBG("reduction precondition failed: no reduction iterator\n");
1700  return failure();
1701  }
1702  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1703  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1704  if (indexingMap.isPermutation())
1705  continue;
1706 
1707  Operation *reduceOp = matchLinalgReduction(&opOperand);
1708  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1709  LDBG("reduction precondition failed: reduction detection failed\n");
1710  return failure();
1711  }
1712  }
1713  return success();
1714 }
1715 
 1716 static LogicalResult vectorizeDynamicLinalgOpPrecondition(LinalgOp op) {
 1717  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1718  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1719  if (!isElementwise(op) &&
1720  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1721  op.getOperation()))
1722  return failure();
1723 
1724  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1725  return success();
1726 }
1727 
 1728 /// Returns success if `inputVectorSizes` is a valid masking configuration
 1729 /// for the given `shape`, i.e., it meets:
 1730 /// 1. The numbers of elements in both arrays are equal.
 1731 /// 2. `inputVectorSizes` does not have dynamic dimensions.
 1732 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 1733 /// static sizes in `shape`.
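///
/// E.g. (illustrative): for shape = [?, 32], inputVectorSizes = [8, 32] is
/// valid, whereas [8, 16] is rejected (16 < 32) and [8] is rejected because
/// the ranks differ.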
1734 static LogicalResult
 1735 isValidMaskedInputVector(ArrayRef<int64_t> shape,
 1736  ArrayRef<int64_t> inputVectorSizes) {
1737  LDBG("Iteration space static sizes:");
1738  LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
1739  LLVM_DEBUG(llvm::dbgs() << "\n");
1740 
1741  if (inputVectorSizes.size() != shape.size()) {
1742  LDBG("Input vector sizes don't match the number of loops");
1743  return failure();
1744  }
1745  if (ShapedType::isDynamicShape(inputVectorSizes)) {
1746  LDBG("Input vector sizes can't have dynamic dimensions");
1747  return failure();
1748  }
1749  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
1750  [](std::tuple<int64_t, int64_t> sizePair) {
1751  int64_t staticSize = std::get<0>(sizePair);
1752  int64_t inputSize = std::get<1>(sizePair);
1753  return ShapedType::isDynamic(staticSize) ||
1754  staticSize <= inputSize;
1755  })) {
1756  LDBG("Input vector sizes must be greater than or equal to iteration space "
1757  "static sizes");
1758  return failure();
1759  }
1760  return success();
1761 }
1762 
 1763 /// Precondition: the unpack op's inner tiles must be static/constant.
1764 static LogicalResult
1765 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1766  ArrayRef<int64_t> inputVectorSizes) {
1767 
1768  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1769  return !getConstantIntValue(res).has_value();
1770  })) {
1771  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1772  return failure();
1773  }
1774  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1775  if (!inputVectorSizes.empty() &&
1776  failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
1777  return failure();
1778 
1779  return success();
1780 }
1781 
1782 static LogicalResult
 1783 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
 1784  ArrayRef<int64_t> inputVectorSizes,
1785  bool vectorizeNDExtract) {
 1786  // A tensor with a dimension of size 0 cannot be vectorized.
1787  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1788  return failure();
1789  // Check API contract for input vector sizes.
1790  if (!inputVectorSizes.empty() &&
1791  failed(isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1792  inputVectorSizes)))
1793  return failure();
1794 
1795  if (linalgOp.hasDynamicShape() &&
 1796  failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
 1797  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1798  return failure();
1799  }
1800 
 1801  SmallVector<CustomVectorizationPrecondition> customPreconditions;
 1802 
1803  // Register CustomVectorizationPrecondition for extractOp.
1804  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1805 
1806  // All types in the body should be a supported element type for VectorType.
1807  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1808  // Check if any custom hook can vectorize the inner op.
1809  if (llvm::any_of(
1810  customPreconditions,
1811  [&](const CustomVectorizationPrecondition &customPrecondition) {
1812  return succeeded(
1813  customPrecondition(&innerOp, vectorizeNDExtract));
1814  })) {
1815  continue;
1816  }
1817  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1818  return !VectorType::isValidElementType(type);
1819  })) {
1820  return failure();
1821  }
1822  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1823  return !VectorType::isValidElementType(type);
1824  })) {
1825  return failure();
1826  }
1827  }
1828  if (isElementwise(linalgOp))
1829  return success();
1830 
1831  // TODO: isaConvolutionOpInterface that can also infer from generic
1832  // features. But we will still need stride/dilation attributes that will be
1833  // annoying to reverse-engineer...
1834  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1835  return success();
1836  // TODO: the common vector shape is equal to the static loop sizes only when
1837  // all indexing maps are projected permutations. For convs and stencils the
1838  // logic will need to evolve.
1839  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1840  LDBG("precondition failed: not projected permutations\n");
1841  return failure();
1842  }
1843  if (failed(reductionPreconditions(linalgOp))) {
1844  LDBG("precondition failed: reduction preconditions\n");
1845  return failure();
1846  }
1847  return success();
1848 }
1849 
1850 /// TODO: Use a matcher to check for a constant padding value.
1851 static LogicalResult
1852 vectorizePackOpPrecondition(tensor::PackOp packOp,
1853  ArrayRef<int64_t> inputVectorSizes) {
1854  auto padValue = packOp.getPaddingValue();
1855  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
1856  LDBG("pad value is not constant: " << packOp << "\n");
1857  return failure();
1858  }
1859 
1860  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
 1861  if (failed(isValidMaskedInputVector(
 1862  resultTensorShape.take_front(packOp.getSourceRank()),
1863  inputVectorSizes)))
1864  return failure();
1865 
1866  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1867  return !getConstantIntValue(v).has_value();
1868  })) {
1869  LDBG("inner_tiles must be constant: " << packOp << "\n");
1870  return failure();
1871  }
1872 
1873  return success();
1874 }
1875 
1876 static LogicalResult
1877 vectorizePadOpPrecondition(tensor::PadOp padOp,
1878  ArrayRef<int64_t> inputVectorSizes) {
1879  auto padValue = padOp.getConstantPaddingValue();
1880  if (!padValue) {
1881  LDBG("pad value is not constant: " << padOp << "\n");
1882  return failure();
1883  }
1884 
1885  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1886  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
1887  return failure();
1888 
1889  if (llvm::any_of(padOp.getLow(), [](Value v) {
1890  std::optional<int64_t> res = getConstantIntValue(v);
1891  return !res.has_value() || res.value() != 0;
1892  })) {
1893  LDBG("low pad must all be zero: " << padOp << "\n");
1894  return failure();
1895  }
1896 
1897  return success();
1898 }
1899 
1900 /// Preconditions for scalable vectors.
1901 static LogicalResult
 1902 vectorizeScalableVectorPrecondition(Operation *op,
 1903  ArrayRef<int64_t> inputVectorSizes,
1904  ArrayRef<bool> inputScalableVecDims) {
1905  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1906  "Number of input vector sizes and scalable dims doesn't match");
1907 
1908  if (inputVectorSizes.empty())
1909  return success();
1910 
1911  bool isScalable = inputScalableVecDims.back();
1912  if (!isScalable)
1913  return success();
1914 
1915  // Only element-wise ops supported in the presence of scalable dims.
1916  auto linalgOp = dyn_cast<LinalgOp>(op);
1917  return success(linalgOp && isElementwise(linalgOp));
1918 }
1919 
 1920 LogicalResult mlir::linalg::vectorizeOpPrecondition(
 1921  Operation *op, ArrayRef<int64_t> inputVectorSizes,
1922  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
1923  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
1924  inputScalableVecDims)))
1925  return failure();
1926 
 1927  return TypeSwitch<Operation *, LogicalResult>(op)
 1928  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1929  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1930  vectorizeNDExtract);
1931  })
1932  .Case<tensor::PadOp>([&](auto padOp) {
1933  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
1934  })
1935  .Case<tensor::PackOp>([&](auto packOp) {
1936  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
1937  })
1938  .Case<tensor::UnPackOp>([&](auto unpackOp) {
1939  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
1940  })
1941  .Default([](auto) { return failure(); });
1942 }
1943 
1944 /// Converts affine.apply Ops to arithmetic operations.
1945 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
1946  OpBuilder::InsertionGuard g(rewriter);
1947  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
1948 
1949  for (auto op : make_early_inc_range(toReplace)) {
1950  rewriter.setInsertionPoint(op);
1951  auto expanded = affine::expandAffineExpr(
1952  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
1953  op.getOperands().take_front(op.getAffineMap().getNumDims()),
1954  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
1955  rewriter.replaceOp(op, expanded);
1956  }
1957 }
1958 
1959 /// Emit a suitable vector form for an operation. If provided,
1960 /// `inputVectorSizes` are used to vectorize this operation.
1961 /// `inputVectorSizes` must match the rank of the iteration space of the
1962 /// operation and the input vector sizes must be greater than or equal to
 1963 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
 1964 /// also allows the vectorization of operations with dynamic shapes.
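///
/// E.g. (illustrative): a linalg.matmul with dynamic sizes and iteration
/// space (m, n, k) may be vectorized with inputVectorSizes = {8, 16, 4};
/// the generated reads/writes are then masked with the actual runtime sizes.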
 1965 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
 1966  ArrayRef<int64_t> inputVectorSizes,
1967  ArrayRef<bool> inputScalableVecDims,
1968  bool vectorizeNDExtract,
1969  bool flatten1DDepthwiseConv) {
1970  LDBG("Attempting to vectorize:\n" << *op << "\n");
1971  LDBG("Input vector sizes: ");
1972  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
1973  LLVM_DEBUG(llvm::dbgs() << "\n");
1974  LDBG("Input scalable vector dims: ");
1975  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
1976  LLVM_DEBUG(llvm::dbgs() << "\n");
1977 
1978  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
1979  vectorizeNDExtract))) {
1980  LDBG("Vectorization pre-conditions failed\n");
1981  return failure();
1982  }
1983 
1984  // Initialize vectorization state.
1985  VectorizationState state(rewriter);
1986  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
1987  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
1988  inputScalableVecDims))) {
1989  LDBG("Vectorization state couldn't be initialized\n");
1990  return failure();
1991  }
1992  }
1993 
1994  SmallVector<Value> results;
1995  auto vectorizeResult =
 1996  TypeSwitch<Operation *, LogicalResult>(op)
 1997  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1998  // TODO: isaConvolutionOpInterface that can also infer from
1999  // generic features. Will require stride/dilation attributes
2000  // inference.
2001  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
 2002  FailureOr<Operation *> convOr = vectorizeConvolution(
 2003  rewriter, linalgOp, flatten1DDepthwiseConv);
2004  if (succeeded(convOr)) {
2005  llvm::append_range(results, (*convOr)->getResults());
2006  return success();
2007  }
2008 
2009  LDBG("Unsupported convolution can't be vectorized.\n");
2010  return failure();
2011  }
2012 
2013  LDBG("Vectorize generic by broadcasting to the canonical vector "
2014  "shape\n");
2015 
2016  // Pre-process before proceeding.
2017  convertAffineApply(rewriter, linalgOp);
2018 
2019  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2020  // to 'OpBuilder' when it is passed over to some methods like
2021  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2022  // erase an op within these methods, the actual rewriter won't be
2023  // notified and we will end up with read-after-free issues!
2024  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2025  })
2026  .Case<tensor::PadOp>([&](auto padOp) {
2027  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2028  results);
2029  })
2030  .Case<tensor::PackOp>([&](auto packOp) {
2031  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2032  results);
2033  })
2034  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2035  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2036  inputVectorSizes, results);
2037  })
2038  .Default([](auto) { return failure(); });
2039 
2040  if (failed(vectorizeResult)) {
2041  LDBG("Vectorization failed\n");
2042  return failure();
2043  }
2044 
2045  if (!results.empty())
2046  rewriter.replaceOp(op, results);
2047  else
2048  rewriter.eraseOp(op);
2049 
2050  return success();
2051 }
2052 
 2053 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 2054  memref::CopyOp copyOp) {
2055  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2056  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2057  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2058  return failure();
2059 
2060  auto srcElementType = getElementTypeOrSelf(srcType);
2061  auto dstElementType = getElementTypeOrSelf(dstType);
2062  if (!VectorType::isValidElementType(srcElementType) ||
2063  !VectorType::isValidElementType(dstElementType))
2064  return failure();
2065 
2066  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2067  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2068 
2069  Location loc = copyOp->getLoc();
2070  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2071  SmallVector<Value> indices(srcType.getRank(), zero);
2072 
2073  Value readValue = rewriter.create<vector::TransferReadOp>(
2074  loc, readType, copyOp.getSource(), indices,
2075  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2076  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2077  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2078  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2079  }
2080  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2081  loc, readValue, copyOp.getTarget(), indices,
2082  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2083  rewriter.replaceOp(copyOp, writeValue->getResults());
2084  return success();
2085 }
2086 
2087 //----------------------------------------------------------------------------//
2088 // Misc. vectorization patterns.
2089 //----------------------------------------------------------------------------//
2090 
2091 /// Helper function that retrieves the value of an IntegerAttr.
2092 static int64_t getIntFromAttr(Attribute attr) {
2093  return cast<IntegerAttr>(attr).getInt();
2094 }
2095 
2096 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2097 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2098 /// not supported.
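/// E.g. (illustrative): {%v, IntegerAttr(3)} becomes {%v, %c3}, where %c3 is
/// a freshly created `arith.constant 3 : index`.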
 2099 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
 2100  ArrayRef<OpFoldResult> ofrs) {
2101  SmallVector<Value> result;
2102  for (auto o : ofrs) {
2103  if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2104  result.push_back(val);
2105  } else {
2106  result.push_back(rewriter.create<arith::ConstantIndexOp>(
2107  loc, getIntFromAttr(o.template get<Attribute>())));
2108  }
2109  }
2110  return result;
2111 }
2112 
2113 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2114 /// InsertSliceOp. For now, only constant padding values are supported.
2115 /// If there is enough static type information, TransferReadOps and
2116 /// TransferWriteOps may be generated instead of InsertSliceOps.
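///
/// A rough sketch of the generalized form (illustrative only; %pad_value and
/// the slice offsets/sizes are stand-ins):
///
///   %init = tensor.empty() : tensor<17x5xf32>
///   %fill = linalg.fill ins(%pad_value : f32)
///             outs(%init : tensor<17x5xf32>) -> tensor<17x5xf32>
///   %res  = tensor.insert_slice %src into %fill[0, 0] [%s0, %s1] [1, 1]
///             : tensor<?x?xf32> into tensor<17x5xf32>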
 2117 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
 2118  GenericPadOpVectorizationPattern(MLIRContext *context,
 2119  PatternBenefit benefit = 1)
2120  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2121  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
 2122 /// each dimension size is statically known in the source type or the result
2123  /// type (or both).
 2124  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
 2125  tensor::PadOp padOp, Value dest) {
2126  auto sourceType = padOp.getSourceType();
2127  auto resultType = padOp.getResultType();
2128  if (!VectorType::isValidElementType(sourceType.getElementType()))
2129  return failure();
2130 
2131  // Copy cannot be vectorized if pad value is non-constant and source shape
2132  // is dynamic. In case of a dynamic source shape, padding must be appended
2133  // by TransferReadOp, but TransferReadOp supports only constant padding.
2134  auto padValue = padOp.getConstantPaddingValue();
2135  if (!padValue) {
2136  if (!sourceType.hasStaticShape())
2137  return failure();
2138  // Create dummy padding value.
2139  auto elemType = sourceType.getElementType();
2140  padValue = rewriter.create<arith::ConstantOp>(
2141  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2142  }
2143 
2144  SmallVector<int64_t> vecShape;
2145  SmallVector<bool> readInBounds;
2146  SmallVector<bool> writeInBounds;
2147  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2148  if (!sourceType.isDynamicDim(i)) {
2149  vecShape.push_back(sourceType.getDimSize(i));
 2150  // Source shape is statically known: neither read nor write is
 2151  // out-of-bounds.
2152  readInBounds.push_back(true);
2153  writeInBounds.push_back(true);
2154  } else if (!resultType.isDynamicDim(i)) {
2155  // Source shape is not statically known, but result shape is.
2156  // Vectorize with size of result shape. This may be larger than the
2157  // source size.
2158  vecShape.push_back(resultType.getDimSize(i));
2159  // Read may be out-of-bounds because the result size could be larger
2160  // than the source size.
2161  readInBounds.push_back(false);
2162  // Write is out-of-bounds if low padding > 0.
2163  writeInBounds.push_back(
2164  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2165  static_cast<int64_t>(0));
2166  } else {
2167  // Neither source nor result dim of padOp is static. Cannot vectorize
2168  // the copy.
2169  return failure();
2170  }
2171  }
2172  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2173 
2174  // Generate TransferReadOp.
2175  SmallVector<Value> readIndices(
2176  vecType.getRank(),
2177  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2178  auto read = rewriter.create<vector::TransferReadOp>(
2179  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2180  ArrayRef<bool>{readInBounds});
2181 
2182  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2183  // entire tensor, write directly to the FillOp's operand.
2184  if (llvm::equal(vecShape, resultType.getShape()) &&
2185  llvm::all_of(writeInBounds, [](bool b) { return b; }))
2186  if (auto fill = dest.getDefiningOp<FillOp>())
2187  dest = fill.output();
2188 
2189  // Generate TransferWriteOp.
2190  auto writeIndices =
2191  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2192  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2193  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2194 
2195  return success();
2196  }
2197 };
2198 
2199 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2200 /// given operation type OpTy.
2201 template <typename OpTy>
2202 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
 2203  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
 2204 
2205  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2206  PatternRewriter &rewriter) const final {
2207  bool changed = false;
 2208  // Collect users into a vector first, because some may be replaced/removed.
2209  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2210  if (auto op = dyn_cast<OpTy>(user))
2211  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2212  return success(changed);
2213  }
2214 
2215 protected:
 2216  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
 2217  tensor::PadOp padOp, OpTy op) const = 0;
2218 };
2219 
2220 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2221 /// ```
2222 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2223 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2224 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2225 /// ```
2226 /// is rewritten to:
2227 /// ```
2228 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2229 /// {in_bounds = [true, true]}
2230 /// : tensor<?x?xf32>, vector<17x5xf32>
2231 /// ```
2232 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2233 /// sure that the original padding value %cst was never used.
2234 ///
2235 /// This rewrite is possible if:
2236 /// - `xferOp` has no out-of-bounds dims or mask.
2237 /// - Low padding is static 0.
2238 /// - Single, scalar padding value.
 2239 struct PadOpVectorizationWithTransferReadPattern
 2240  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
 2241  using VectorizePadOpUserPattern<
 2242  vector::TransferReadOp>::VectorizePadOpUserPattern;
2243 
2244  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2245  vector::TransferReadOp xferOp) const override {
2246  // Low padding must be static 0.
2247  if (!padOp.hasZeroLowPad())
2248  return failure();
2249  // Pad value must be a constant.
2250  auto padValue = padOp.getConstantPaddingValue();
2251  if (!padValue)
2252  return failure();
2253  // Padding value of existing `xferOp` is unused.
2254  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2255  return failure();
2256 
2257  rewriter.modifyOpInPlace(xferOp, [&]() {
2258  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2259  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2260  rewriter.getBoolArrayAttr(inBounds));
2261  xferOp.getSourceMutable().assign(padOp.getSource());
2262  xferOp.getPaddingMutable().assign(padValue);
2263  });
2264 
2265  return success();
2266  }
2267 };
2268 
2269 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2270 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2271 /// value, where the same amount of padding is immediately removed again after
2272 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2273 /// tensor value and apply out-of-bounds masking. E.g.:
2274 /// ```
2275 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2276 /// : tensor<...> to tensor<?x?xf32>
2277 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2278 /// %2 = vector.transfer_write %vec, %1[...]
2279 /// : vector<17x5xf32>, tensor<17x5xf32>
2280 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2281 /// : tensor<17x5xf32> to tensor<?x?xf32>
2282 /// ```
2283 /// is rewritten to:
2284 /// ```
2285 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2286 /// : tensor<...> to tensor<?x?xf32>
2287 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2288 /// tensor<?x?xf32>
2289 /// ```
2290 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
 2291 /// TransferWriteOp to the same size as the input of the tensor::PadOp (or an
2292 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2293 /// from %r's old dimensions.
2294 ///
2295 /// This rewrite is possible if:
2296 /// - Low padding is static 0.
2297 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2298 /// ExtractSliceOp trims the same amount of padding that was added
2299 /// beforehand.
2300 /// - Single, scalar padding value.
 2301 struct PadOpVectorizationWithTransferWritePattern
 2302  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
 2303  using VectorizePadOpUserPattern<
 2304  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2305 
2306  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2307  vector::TransferWriteOp xferOp) const override {
2308  // TODO: support 0-d corner case.
2309  if (xferOp.getTransferRank() == 0)
2310  return failure();
2311 
2312  // Low padding must be static 0.
2313  if (!padOp.hasZeroLowPad())
2314  return failure();
2315  // Pad value must be a constant.
2316  auto padValue = padOp.getConstantPaddingValue();
2317  if (!padValue)
2318  return failure();
2319  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2320  if (!xferOp->hasOneUse())
2321  return failure();
2322  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2323  if (!trimPadding)
2324  return failure();
2325  // Only static zero offsets supported when trimming padding.
2326  if (!trimPadding.hasZeroOffset())
2327  return failure();
2328  // trimPadding must remove the amount of padding that was added earlier.
2329  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2330  return failure();
2331 
2332  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2333  rewriter.setInsertionPoint(xferOp);
2334 
2335  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2336  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2337  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2338  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2339  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2340  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2341 
2342  return success();
2343  }
2344 
2345  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2346  /// i.e., same dimensions.
2347  ///
2348  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2349  /// dimensions, this function tries to infer the (static) tensor size by
2350  /// looking at the defining op and utilizing op-specific knowledge.
2351  ///
2352  /// This is a conservative analysis. In case equal tensor sizes cannot be
2353  /// proven statically, this analysis returns `false` even though the tensor
2354  /// sizes may turn out to be equal at runtime.
2355  bool hasSameTensorSize(Value beforePadding,
2356  tensor::ExtractSliceOp afterTrimming) const {
2357  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2358  // result and CastOp operand.
2359  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2360  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2361  return true;
2362 
2363  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2364  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2365  // Only RankedTensorType supported.
2366  if (!t1 || !t2)
2367  return false;
2368  // Rank of both values must be the same.
2369  if (t1.getRank() != t2.getRank())
2370  return false;
2371 
2372  // All static dimensions must be the same. Mixed cases (e.g., dimension
2373  // static in `t1` but dynamic in `t2`) are not supported.
2374  for (unsigned i = 0; i < t1.getRank(); ++i) {
2375  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2376  return false;
2377  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2378  return false;
2379  }
2380 
2381  // Nothing more to check if all dimensions are static.
2382  if (t1.getNumDynamicDims() == 0)
2383  return true;
2384 
2385  // All dynamic sizes must be the same. The only supported case at the
2386  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2387  // thereof).
2388 
2389  // Apart from CastOp, only ExtractSliceOp is supported.
2390  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2391  if (!beforeSlice)
2392  return false;
2393 
2394  assert(static_cast<size_t>(t1.getRank()) ==
2395  beforeSlice.getMixedSizes().size());
2396  assert(static_cast<size_t>(t2.getRank()) ==
2397  afterTrimming.getMixedSizes().size());
2398 
2399  for (unsigned i = 0; i < t1.getRank(); ++i) {
2400  // Skip static dimensions.
2401  if (!t1.isDynamicDim(i))
2402  continue;
2403  auto size1 = beforeSlice.getMixedSizes()[i];
2404  auto size2 = afterTrimming.getMixedSizes()[i];
2405 
2406  // Case 1: Same value or same constant int.
2407  if (isEqualConstantIntOrValue(size1, size2))
2408  continue;
2409 
2410  // Other cases: Take a deeper look at defining ops of values.
2411  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2412  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2413  if (!v1 || !v2)
2414  return false;
2415 
2416  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2417  // CSE is run.)
2418  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2419  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2420  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2421  minOp1.getOperands() == minOp2.getOperands())
2422  continue;
2423 
2424  // Add additional cases as needed.
2425  }
2426 
2427  // All tests passed.
2428  return true;
2429  }
2430 };
2431 
2432 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2433 /// ```
2434 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2435 /// %r = tensor.insert_slice %0
2436 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2437 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2438 /// ```
2439 /// is rewritten to:
2440 /// ```
2441 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2442 /// : tensor<?x?xf32>, vector<17x5xf32>
2443 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2444 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2445 /// ```
2446 ///
2447 /// This rewrite is possible if:
2448 /// - Low padding is static 0.
2449 /// - `padOp` result shape is static.
2450 /// - The entire padded tensor is inserted.
2451 /// (Implies that sizes of `insertOp` are all static.)
2452 /// - Only unit strides in `insertOp`.
2453 /// - Single, scalar padding value.
2454 /// - `padOp` result not used as destination.
 2455 struct PadOpVectorizationWithInsertSlicePattern
 2456  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
 2457  using VectorizePadOpUserPattern<
 2458  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2459 
2460  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2461  tensor::InsertSliceOp insertOp) const override {
2462  // Low padding must be static 0.
2463  if (!padOp.hasZeroLowPad())
2464  return failure();
2465  // Only unit stride supported.
2466  if (!insertOp.hasUnitStride())
2467  return failure();
2468  // Pad value must be a constant.
2469  auto padValue = padOp.getConstantPaddingValue();
2470  if (!padValue)
2471  return failure();
2472  // Dynamic shapes not supported.
2473  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2474  return failure();
2475  // Pad result not used as destination.
2476  if (insertOp.getDest() == padOp.getResult())
2477  return failure();
2478 
2479  auto vecType = VectorType::get(padOp.getType().getShape(),
2480  padOp.getType().getElementType());
2481  unsigned vecRank = vecType.getRank();
2482  unsigned tensorRank = insertOp.getType().getRank();
2483 
2484  // Check if sizes match: Insert the entire tensor into most minor dims.
2485  // (No permutations allowed.)
2486  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2487  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2488  if (!llvm::all_of(
2489  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2490  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2491  }))
2492  return failure();
2493 
2494  // Insert the TransferReadOp and TransferWriteOp at the position of the
2495  // InsertSliceOp.
2496  rewriter.setInsertionPoint(insertOp);
2497 
2498  // Generate TransferReadOp: Read entire source tensor and add high
2499  // padding.
2500  SmallVector<Value> readIndices(
2501  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2502  auto read = rewriter.create<vector::TransferReadOp>(
2503  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2504 
2505  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
 2506  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2507  // source must fit into the destination at the specified offsets.
2508  auto writeIndices =
2509  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2510  SmallVector<bool> inBounds(vecRank, true);
2511  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2512  insertOp, read, insertOp.getDest(), writeIndices,
2513  ArrayRef<bool>{inBounds});
2514 
2515  return success();
2516  }
2517 };
2518 
 2519 void mlir::linalg::populatePadOpVectorizationPatterns(
 2520  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2521  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2522  baseBenefit);
2523  // Try these specialized patterns first before resorting to the generic one.
 2524  patterns.add<PadOpVectorizationWithTransferReadPattern,
 2525  PadOpVectorizationWithTransferWritePattern,
 2526  PadOpVectorizationWithInsertSlicePattern>(
 2527  patterns.getContext(), baseBenefit.getBenefit() + 1);
2528 }
2529 
2530 //----------------------------------------------------------------------------//
2531 // Forwarding patterns
2532 //----------------------------------------------------------------------------//
2533 
2534 /// Check whether there is any interleaved use of any `values` between
2535 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2536 /// is in a different block.
2537 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2538  ValueRange values) {
2539  if (firstOp->getBlock() != secondOp->getBlock() ||
2540  !firstOp->isBeforeInBlock(secondOp)) {
2541  LDBG("interleavedUses precondition failed, firstOp: "
2542  << *firstOp << ", second op: " << *secondOp << "\n");
2543  return true;
2544  }
2545  for (auto v : values) {
2546  for (auto &u : v.getUses()) {
2547  Operation *owner = u.getOwner();
2548  if (owner == firstOp || owner == secondOp)
2549  continue;
2550  // TODO: this is too conservative, use dominance info in the future.
2551  if (owner->getBlock() == firstOp->getBlock() &&
2552  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2553  continue;
2554  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2555  << ", second op: " << *secondOp << "\n");
2556  return true;
2557  }
2558  }
2559  return false;
2560 }
2561 
2562 /// Return the unique subview use of `v` if it is indeed unique, null
2563 /// otherwise.
2564 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2565  memref::SubViewOp subViewOp;
2566  for (auto &u : v.getUses()) {
2567  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2568  if (subViewOp)
2569  return memref::SubViewOp();
2570  subViewOp = newSubViewOp;
2571  }
2572  }
2573  return subViewOp;
2574 }
2575 
2576 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2577 /// when available.
 2578 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
 2579  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2580 
2581  // TODO: support mask.
2582  if (xferOp.getMask())
2583  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2584 
2585  // Transfer into `view`.
2586  Value viewOrAlloc = xferOp.getSource();
2587  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2588  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2589  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2590 
2591  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2592  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2593  if (!subViewOp)
2594  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2595  Value subView = subViewOp.getResult();
2596 
2597  // Find the copy into `subView` without interleaved uses.
2598  memref::CopyOp copyOp;
2599  for (auto &u : subView.getUses()) {
2600  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2601  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2602  if (newCopyOp.getTarget() != subView)
2603  continue;
2604  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2605  continue;
2606  copyOp = newCopyOp;
2607  break;
2608  }
2609  }
2610  if (!copyOp)
2611  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2612 
2613  // Find the fill into `viewOrAlloc` without interleaved uses before the
2614  // copy.
2615  FillOp maybeFillOp;
2616  for (auto &u : viewOrAlloc.getUses()) {
2617  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2618  assert(isa<MemRefType>(newFillOp.output().getType()));
2619  if (newFillOp.output() != viewOrAlloc)
2620  continue;
2621  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2622  continue;
2623  maybeFillOp = newFillOp;
2624  break;
2625  }
2626  }
2627  // Ensure padding matches.
2628  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2629  return rewriter.notifyMatchFailure(xferOp,
2630  "padding value does not match fill");
2631 
2632  // `in` is the subview that memref.copy reads. Replace it.
2633  Value in = copyOp.getSource();
2634 
2635  // memref.copy + linalg.fill can be used to create a padded local buffer.
 2636  // The `in_bounds` attribute is only valid on this padded buffer.
2637  // When forwarding to vector.transfer_read, the attribute must be reset
2638  // conservatively.
2639  Value res = rewriter.create<vector::TransferReadOp>(
2640  xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2641  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2642  // in_bounds is explicitly reset
2643  /*inBoundsAttr=*/ArrayAttr());
2644 
2645  if (maybeFillOp)
2646  rewriter.eraseOp(maybeFillOp);
2647  rewriter.eraseOp(copyOp);
2648  rewriter.replaceOp(xferOp, res);
2649 
2650  return success();
2651 }
2652 
2653 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2654 /// when available.
 2655 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
 2656  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2657  // TODO: support mask.
2658  if (xferOp.getMask())
2659  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2660 
2661  // Transfer into `viewOrAlloc`.
2662  Value viewOrAlloc = xferOp.getSource();
2663  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2664  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2665  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2666 
2667  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2668  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2669  if (!subViewOp)
2670  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2671  Value subView = subViewOp.getResult();
2672 
2673  // Find the copy from `subView` without interleaved uses.
2674  memref::CopyOp copyOp;
2675  for (auto &u : subViewOp.getResult().getUses()) {
2676  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2677  if (newCopyOp.getSource() != subView)
2678  continue;
2679  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2680  continue;
2681  copyOp = newCopyOp;
2682  break;
2683  }
2684  }
2685  if (!copyOp)
2686  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2687 
 2688  // `out` is the buffer the subview is copied into; write to it directly.
2689  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2690  Value out = copyOp.getTarget();
2691 
2692  // Forward vector.transfer into copy.
2693  // memref.copy + linalg.fill can be used to create a padded local buffer.
 2694  // The `in_bounds` attribute is only valid on this padded buffer.
2695  // When forwarding to vector.transfer_write, the attribute must be reset
2696  // conservatively.
2697  rewriter.create<vector::TransferWriteOp>(
2698  xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2699  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2700  // in_bounds is explicitly reset
2701  /*inBoundsAttr=*/ArrayAttr());
2702 
2703  rewriter.eraseOp(copyOp);
2704  rewriter.eraseOp(xferOp);
2705 
2706  return success();
2707 }
2708 
2709 //===----------------------------------------------------------------------===//
2710 // Convolution vectorization patterns
2711 //===----------------------------------------------------------------------===//
2712 
2713 template <int N>
2714 static void bindShapeDims(ShapedType shapedType) {}
2715 
2716 template <int N, typename IntTy, typename... IntTy2>
2717 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2718  val = shapedType.getShape()[N];
2719  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2720 }
2721 
2722 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
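/// E.g. (illustrative):
///   int64_t n, w, c;
///   bindShapeDims(lhsShapedType, n, w, c); // n = dim 0, w = dim 1, c = dim 2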
2723 template <typename... IntTy>
2724 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2725  bindShapeDims<0>(shapedType, vals...);
2726 }
2727 
2728 namespace {
2729 bool isCastOfBlockArgument(Operation *op) {
2730  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2731  isa<BlockArgument>(op->getOperand(0));
2732 }
2733 
2734 bool isSupportedPoolKind(vector::CombiningKind kind) {
2735  switch (kind) {
2736  case vector::CombiningKind::ADD:
2737  case vector::CombiningKind::MAXNUMF:
2738  case vector::CombiningKind::MAXIMUMF:
2739  case vector::CombiningKind::MAXSI:
2740  case vector::CombiningKind::MAXUI:
2741  case vector::CombiningKind::MINNUMF:
2742  case vector::CombiningKind::MINIMUMF:
2743  case vector::CombiningKind::MINSI:
 2744  case vector::CombiningKind::MINUI:
 2745  return true;
2746  default:
2747  return false;
2748  }
2749 }
2750 
2751 /// Generate a vector implementation for either:
2752 /// ```
2753 /// Op def: ( w, kw )
2754 /// Iters: ({Par(), Red()})
2755 /// Layout: {{w + kw}, {kw}, {w}}
2756 /// ```
2757 /// kw is unrolled.
2758 ///
2759 /// or
2760 ///
2761 /// ```
2762 /// Op def: ( n, w, c, kw, f )
2763 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2764 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2765 /// ```
2766 /// kw is unrolled, w is unrolled iff dilationW > 1.
2767 ///
2768 /// or
2769 ///
2770 /// ```
2771 /// Op def: ( n, c, w, f, kw )
2772 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2773 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2774 /// ```
2775 /// kw is unrolled, w is unrolled iff dilationW > 1.
2776 ///
2777 /// or
2778 ///
2779 /// ```
2780 /// Op def: ( n, w, c, kw )
2781 /// Iters: ({Par(), Par(), Par(), Red()})
2782 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2783 /// ```
2784  /// kw is unrolled, w is unrolled iff strideW > 1.
2785 struct Conv1DGenerator
2786  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2787  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2788  int dilationW)
2789  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2790  strideW(strideW), dilationW(dilationW) {
2791  // Determine whether `linalgOp` can be handled by this generator.
2792  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2793  return;
2794  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2795  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2796  resShaped = linalgOp.getDpsInitOperand(0)->get();
2797  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2798  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2799  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2800  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2801  return;
2802  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2803  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2804  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2805  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2806  return;
2807 
2808  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2809  if (!reduceOp)
2810  return;
2811  redOp = reduceOp->getName().getIdentifier();
2812 
2813  if (!setOperKind(reduceOp))
2814  return;
2815  auto maybeKind = getCombinerOpKind(reduceOp);
2816  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2817  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2818  return;
2819  }
2820 
2821  auto rhsRank = rhsShapedType.getRank();
2822  switch (oper) {
2823  case Conv:
2824  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2825  return;
2826  break;
2827  case Pool:
2828  if (rhsRank != 1)
2829  return;
2830  break;
2831  }
2832  // The op is now known to be valid.
2833  valid = true;
2834  }
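  // For example (illustrative, not taken from the original source), a
  // linalg.conv_1d_nwc_wcf with rank-3 lhs/res and a rank-3 filter, or a
  // linalg.pooling_nwc_sum with a rank-1 window, passes the checks above and
  // sets `valid`. Any other rank combination or unsupported combiner leaves
  // `valid == false`, and conv()/depthwiseConv() bail out later with a match
  // failure.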
2835 
2836  /// Generate a vector implementation for:
2837  /// ```
2838  /// Op def: ( w, kw )
2839  /// Iters: ({Par(), Red()})
2840  /// Layout: {{w + kw}, {kw}, {w}}
2841  /// ```
2842  /// kw is always unrolled.
2843  ///
2844  /// or
2845  ///
2846  /// ```
2847  /// Op def: ( n, w, c, kw, f )
2848  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2849  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2850  /// ```
2851  /// kw is always unrolled.
2852  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
2853  /// > 1.
2854  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2855  if (!valid)
2856  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2857 
2858  int64_t nSize, wSize, cSize, kwSize, fSize;
2859  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
2860  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
2861  switch (conv1DOpOrder) {
2862  case Conv1DOpOrder::W:
2863  // Initialize unused dimensions
2864  nSize = fSize = cSize = 0;
2865  // out{W}
2866  bindShapeDims(resShapedType, wSize);
2867  // kernel{kw}
2868  bindShapeDims(rhsShapedType, kwSize);
2869  lhsShape = {// iw = ow + kw - 1
2870  // (i.e. 16 convolved with 3 -> 14)
2871  (wSize + kwSize - 1)};
2872  rhsShape = {kwSize};
2873  resShape = {wSize};
2874  break;
2875  case Conv1DOpOrder::Nwc:
2876  // out{n, w, f}
2877  bindShapeDims(resShapedType, nSize, wSize, fSize);
2878  switch (oper) {
2879  case Conv:
2880  // kernel{kw, c, f}
2881  bindShapeDims(rhsShapedType, kwSize, cSize);
2882  break;
2883  case Pool:
2884  // kernel{kw}
2885  bindShapeDims(rhsShapedType, kwSize);
2886  cSize = fSize;
2887  break;
2888  }
2889  lhsShape = {nSize,
2890  // iw = ow * sw + kw * dw - 1
2891  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2892  // Perform the proper inclusive -> exclusive -> inclusive.
2893  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2894  1,
2895  cSize};
2896  switch (oper) {
2897  case Conv:
2898  rhsShape = {kwSize, cSize, fSize};
2899  break;
2900  case Pool:
2901  rhsShape = {kwSize};
2902  break;
2903  }
2904  resShape = {nSize, wSize, fSize};
2905  break;
2906  case Conv1DOpOrder::Ncw:
2907  // out{n, f, w}
2908  bindShapeDims(resShapedType, nSize, fSize, wSize);
2909  switch (oper) {
2910  case Conv:
2911  // kernel{f, c, kw}
2912  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
2913  break;
2914  case Pool:
2915  // kernel{kw}
2916  bindShapeDims(rhsShapedType, kwSize);
2917  cSize = fSize;
2918  break;
2919  }
2920  lhsShape = {nSize, cSize,
2921  // iw = ow * sw + kw * dw - 1
2922  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2923  // Perform the proper inclusive -> exclusive -> inclusive.
2924  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2925  1};
2926  switch (oper) {
2927  case Conv:
2928  rhsShape = {fSize, cSize, kwSize};
2929  break;
2930  case Pool:
2931  rhsShape = {kwSize};
2932  break;
2933  }
2934  resShape = {nSize, fSize, wSize};
2935  break;
2936  }
2937 
2938  vector::TransferWriteOp write;
2939  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2940 
2941  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2942  // When strideW == 1, we can batch the contiguous loads and avoid
2943  // unrolling
2944  int64_t wSizeStep = strideW == 1 ? wSize : 1;
2945 
2946  Type lhsEltType = lhsShapedType.getElementType();
2947  Type rhsEltType = rhsShapedType.getElementType();
2948  Type resEltType = resShapedType.getElementType();
2949  auto lhsType = VectorType::get(lhsShape, lhsEltType);
2950  auto rhsType = VectorType::get(rhsShape, rhsEltType);
2951  auto resType = VectorType::get(resShape, resEltType);
2952  // Zero padding with the corresponding dimensions for lhs, rhs and res.
2953  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
2954  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
2955  SmallVector<Value> resPadding(resShape.size(), zero);
2956 
2957  // Read the whole lhs, rhs and res in one shot (with zero padding).
2958  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2959  lhsPadding);
2960  // This is needed only for Conv.
2961  Value rhs = nullptr;
2962  if (oper == Conv)
2963  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2964  rhsPadding);
2965  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
2966  resPadding);
2967 
2968  // The base vectorization case for channeled convolution is input:
2969  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
2970  // vectorization case, we pre-transpose the input, weight, and output.
2971  switch (conv1DOpOrder) {
2972  case Conv1DOpOrder::W:
2973  case Conv1DOpOrder::Nwc:
2974  // Base case, so no transposes necessary.
2975  break;
2976  case Conv1DOpOrder::Ncw: {
2977  // To match base vectorization case, we pre-transpose current case.
2978  // ncw -> nwc
2979  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
2980  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
2981  // fcw -> wcf
2982  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
2983 
2984  // This is needed only for Conv.
2985  if (oper == Conv)
2986  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
2987  // nfw -> nwf
2988  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
2989  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
2990  break;
2991  }
2992  }
2993 
2994  //===------------------------------------------------------------------===//
2995  // Begin vector-only rewrite part
2996  //===------------------------------------------------------------------===//
2997  // Unroll along kw and read slices of lhs and rhs.
2998  SmallVector<Value> lhsVals, rhsVals, resVals;
2999  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3000  kwSize, strideW, dilationW, wSizeStep,
3001  isSingleChanneled);
3002  // Not needed for pooling.
3003  if (oper == Conv)
3004  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3005  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3006  wSizeStep, isSingleChanneled);
3007 
3008  auto linearIndex = [&](int64_t kw, int64_t w) {
3009  return kw * (wSize / wSizeStep) + w;
3010  };
3011 
3012  // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3013  // or an outer product for non-channeled convolution, or a simple arith
3014  // reduction for pooling.
3015  for (int64_t kw = 0; kw < kwSize; ++kw) {
3016  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3017  switch (oper) {
3018  case Conv:
3019  if (isSingleChanneled) {
3020  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3021  lhsVals[linearIndex(kw, w)],
3022  rhsVals[kw], resVals[w]);
3023  } else {
3024  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3025  lhsVals[linearIndex(kw, w)],
3026  rhsVals[kw], resVals[w]);
3027  }
3028  break;
3029  case Pool:
3030  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3031  resVals[w]);
3032  break;
3033  }
3034  }
3035  }
3036 
3037  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3038  isSingleChanneled);
3039  //===------------------------------------------------------------------===//
3040  // End vector-only rewrite part
3041  //===------------------------------------------------------------------===//
3042 
3043  // The base vectorization case for channeled convolution is output:
3044  // {n,w,f}. To reuse the result from the base pattern vectorization case,
3045  // we post-transpose the base case result.
3046  switch (conv1DOpOrder) {
3047  case Conv1DOpOrder::W:
3048  case Conv1DOpOrder::Nwc:
3049  // Base case, so no transposes necessary.
3050  break;
3051  case Conv1DOpOrder::Ncw: {
3052  // nwf -> nfw
3053  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3054  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3055  break;
3056  }
3057  }
3058 
3059  return rewriter
3060  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3061  .getOperation();
3062  }
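  // A worked example for the slicing above (illustrative, not from the
  // original source): with strideW == 2 (hence wSizeStep == 1), wSize == 4 and
  // kwSize == 3, the kw/w loops extract 12 lhs slices in row-major (kw, w)
  // order and linearIndex(kw, w) == kw * 4 + w recovers the slice for a given
  // pair. With strideW == 1, wSizeStep == wSize, so only w == 0 is visited and
  // linearIndex(kw, 0) == kw.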
3063 
3064  // Take a value and widen it to have the same element type as `ty`.
3065  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3066  const Type srcElementType = getElementTypeOrSelf(val.getType());
3067  const Type dstElementType = getElementTypeOrSelf(ty);
3068  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3069  if (srcElementType == dstElementType)
3070  return val;
3071 
3072  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3073  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3074  const Type dstType =
3075  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3076 
3077  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3078  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3079  }
3080 
3081  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3082  srcWidth < dstWidth)
3083  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3084 
3085  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3086  srcWidth < dstWidth)
3087  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3088 
3089  assert(false && "unhandled promotion case");
3090  return nullptr;
3091  }
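  // For example (illustrative, not from the original source), promote() turns
  // a vector<4xi8> into vector<4xf32> via arith.sitofp when the destination
  // element type is f32, widens vector<4xf16> to vector<4xf32> via arith.extf,
  // widens vector<4xi8> to vector<4xi32> via arith.extsi, and returns
  // same-typed values unchanged.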
3092 
3093  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3094  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3095  Value lhs, Value rhs, Value res) {
3096  vector::IteratorType par = vector::IteratorType::parallel;
3097  vector::IteratorType red = vector::IteratorType::reduction;
3098  AffineExpr n, w, f, c;
3099  bindDims(ctx, n, w, f, c);
3100  lhs = promote(rewriter, loc, lhs, res.getType());
3101  rhs = promote(rewriter, loc, rhs, res.getType());
3102  return rewriter.create<vector::ContractionOp>(
3103  loc, lhs, rhs, res,
3104  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3105  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3106  }
3107 
3108  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3109  // convolution.
3110  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3111  Value lhs, Value rhs, Value res) {
3112  return rewriter.create<vector::OuterProductOp>(
3113  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3114  }
3115 
3116  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3117  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3118  Value res) {
3119  if (isPoolExt)
3120  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3121  return rewriter
3122  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3123  ->getResult(0);
3124  }
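  // For instance (illustrative, not from the original source), for a sum
  // pooling whose body sign-extends i8 inputs to an i32 accumulator, redOp is
  // arith.addi and poolExtOp is arith.extsi: the slice is first extended to
  // the accumulator type and then combined with the running result.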
3125 
3126  /// Generate a vector implementation for:
3127  /// ```
3128  /// Op def: ( n, w, c, kw)
3129  /// Iters: ({Par(), Par(), Par(), Red()})
3130  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3131  /// ```
3132  /// kw is always unrolled.
3133  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3134  /// > 1.
3135  FailureOr<Operation *> depthwiseConv(bool flatten) {
3136  if (!valid)
3137  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3138 
3139  int64_t nSize, wSize, cSize, kwSize;
3140  // kernel{kw, c}
3141  bindShapeDims(rhsShapedType, kwSize, cSize);
3142  // out{n, w, c}
3143  bindShapeDims(resShapedType, nSize, wSize);
3144 
3145  vector::TransferWriteOp write;
3146  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3147 
3148  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3149  // When strideW == 1, we can batch the contiguous loads and avoid
3150  // unrolling
3151  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3152 
3153  Type lhsEltType = lhsShapedType.getElementType();
3154  Type rhsEltType = rhsShapedType.getElementType();
3155  Type resEltType = resShapedType.getElementType();
3156  VectorType lhsType = VectorType::get(
3157  {nSize,
3158  // iw = ow * sw + kw * dw - 1
3159  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3160  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3161  cSize},
3162  lhsEltType);
3163  VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
3164  VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
3165 
3166  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3167  // 0].
3168  Value lhs = rewriter.create<vector::TransferReadOp>(
3169  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3170  // Read rhs slice of size {kw, c} @ [0, 0].
3171  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3172  ValueRange{zero, zero});
3173  // Read res slice of size {n, w, c} @ [0, 0, 0].
3174  Value res = rewriter.create<vector::TransferReadOp>(
3175  loc, resType, resShaped, ValueRange{zero, zero, zero});
3176 
3177  //===------------------------------------------------------------------===//
3178  // Begin vector-only rewrite part
3179  //===------------------------------------------------------------------===//
3180  // Unroll along kw and read slices of lhs and rhs.
3181  SmallVector<Value> lhsVals, rhsVals, resVals;
3182  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3183  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3184 
3185  // Extract lhs slice of size {n, wSizeStep, c}
3186  // @ [0, sw * w + dw * kw, 0].
3187  for (int64_t kw = 0; kw < kwSize; ++kw) {
3188  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3189  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3190  loc, lhs,
3191  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3192  inOutSliceSizes, inOutStrides));
3193  }
3194  }
3195  // Extract rhs slice of size {c} @ [kw].
3196  for (int64_t kw = 0; kw < kwSize; ++kw) {
3197  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3198  loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
3199  }
3200  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3201  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3202  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3203  loc, res,
3204  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3205  inOutStrides));
3206  }
3207 
3208  auto linearIndex = [&](int64_t kw, int64_t w) {
3209  return kw * (wSize / wSizeStep) + w;
3210  };
3211 
3212  auto inOutFlattenSliceSizes =
3213  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3214  auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3215  auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
3216  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3217  for (int64_t kw = 0; kw < kwSize; ++kw) {
3218  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3219  Value lhsVal = lhsVals[linearIndex(kw, w)];
3220  Value resVal = resVals[w];
3221  if (flatten) {
3222  // Flatten the input and output vectors (collapse the channel
3223  // dimension)
3224  lhsVal = rewriter.create<vector::ShapeCastOp>(
3225  loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
3226  resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
3227  resVals[w]);
3228  }
3229  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3230  rhsVals[kw], resVal, flatten);
3231  if (flatten) {
3232  // Un-flatten the output vector (restore the channel dimension)
3233  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3234  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3235  }
3236  }
3237  }
3238 
3239  // It's possible we failed to create the FMA.
3240  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3241  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3242  for (auto &collection :
3243  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3244  for (Value v : collection)
3245  rewriter.eraseOp(v.getDefiningOp());
3246  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3247  }
3248 
3249  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3250  // This does not depend on kw.
3251  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3252  res = rewriter.create<vector::InsertStridedSliceOp>(
3253  loc, resVals[w], res,
3254  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3255  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3256  }
3257  //===------------------------------------------------------------------===//
3258  // End vector-only rewrite part
3259  //===------------------------------------------------------------------===//
3260 
3261  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3262  return rewriter
3263  .create<vector::TransferWriteOp>(loc, res, resShaped,
3264  ValueRange{zero, zero, zero})
3265  .getOperation();
3266  }
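  // A small worked example of the flattening above (illustrative, not from
  // the original source): with nSize == 1, wSizeStep == 4 and cSize == 8, each
  // extracted {1, 4, 8} slice is shape_cast to {1, 32} before the multiply-add
  // and shape_cast back to {1, 4, 8} afterwards, so the channel dimension is
  // folded into w only for the arithmetic.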
3267 
3268  /// Lower:
3269  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3270  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3271  /// to MulAcc.
3272  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3273  Value lhs, Value rhs, Value res,
3274  bool flatten) {
3275  auto rhsTy = cast<ShapedType>(rhs.getType());
3276  auto resTy = cast<ShapedType>(res.getType());
3277 
3278  // TODO(suderman): Change this to use a vector.ima intrinsic.
3279  lhs = promote(rewriter, loc, lhs, resTy);
3280 
3281  if (flatten) {
3282  // There are two options for handling the filter:
3283  // * shape_cast(broadcast(filter))
3284  // * broadcast(shuffle(filter))
3285  // Opt for the option without shape_cast to simplify the codegen.
3286  auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
3287  auto resSize = res.getType().cast<VectorType>().getShape()[1];
3288 
3289  SmallVector<int64_t, 16> indices;
3290  for (int i = 0; i < resSize / rhsSize; ++i) {
3291  for (int j = 0; j < rhsSize; ++j)
3292  indices.push_back(j);
3293  }
3294 
3295  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3296  }
3297  // Broadcast the filter to match the output vector.
3298  rhs = rewriter.create<vector::BroadcastOp>(
3299  loc, resTy.clone(rhsTy.getElementType()), rhs);
3300 
3301  rhs = promote(rewriter, loc, rhs, resTy);
3302 
3303  if (!lhs || !rhs)
3304  return nullptr;
3305 
3306  if (isa<FloatType>(resTy.getElementType()))
3307  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3308 
3309  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3310  return rewriter.create<arith::AddIOp>(loc, mul, res);
3311  }
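  // A worked example of the shuffle above (illustrative, not from the
  // original source): with a filter of size rhsSize == 4 and a flattened
  // output row of resSize == 12, `indices` becomes
  // [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], i.e. the filter is tiled three times
  // so that, after the broadcast, it lines up element-wise with the flattened
  // {n, w * c} operands.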
3312 
3313  /// Entry point for non-channeled convolution:
3314  /// {{w + kw}, {kw}, {w}}
3315  FailureOr<Operation *> generateNonChanneledConv() {
3316  AffineExpr w, kw;
3317  bindDims(ctx, w, kw);
3318  if (!iters({Par(), Red()}))
3319  return rewriter.notifyMatchFailure(op,
3320  "failed to match conv::W 1-par 1-red");
3321 
3322  // No transposition needed.
3323  if (layout({/*lhsIndex*/ {w + kw},
3324  /*rhsIndex*/ {kw},
3325  /*resIndex*/ {w}}))
3326  return conv(Conv1DOpOrder::W);
3327 
3328  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3329  }
3330 
3331  /// Entry point that transposes into the common form:
3332  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3333  FailureOr<Operation *> generateNwcConv() {
3334  AffineExpr n, w, f, kw, c;
3335  bindDims(ctx, n, w, f, kw, c);
3336  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3337  return rewriter.notifyMatchFailure(
3338  op, "failed to match conv::Nwc 3-par 2-red");
3339 
3340  // No transposition needed.
3341  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3342  /*rhsIndex*/ {kw, c, f},
3343  /*resIndex*/ {n, w, f}}))
3344  return conv(Conv1DOpOrder::Nwc);
3345 
3346  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3347  }
3348 
3349  /// Entry point that transposes into the common form:
3350  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3351  FailureOr<Operation *> generateNcwConv() {
3352  AffineExpr n, w, f, kw, c;
3353  bindDims(ctx, n, f, w, c, kw);
3354  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3355  return rewriter.notifyMatchFailure(
3356  op, "failed to match conv::Ncw 3-par 2-red");
3357 
3358  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3359  /*rhsIndex*/ {f, c, kw},
3360  /*resIndex*/ {n, f, w}}))
3361  return conv(Conv1DOpOrder::Ncw);
3362 
3363  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3364  }
3365 
3366  /// Entry point that transposes into the common form:
3367  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3368  FailureOr<Operation *> generateNwcPooling() {
3369  AffineExpr n, w, c, kw;
3370  bindDims(ctx, n, w, c, kw);
3371  if (!iters({Par(), Par(), Par(), Red()}))
3372  return rewriter.notifyMatchFailure(op,
3373  "failed to match pooling 3-par 1-red");
3374 
3375  // No transposition needed.
3376  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3377  /*rhsIndex*/ {kw},
3378  /*resIndex*/ {n, w, c}}))
3379  return conv(Conv1DOpOrder::Nwc);
3380 
3381  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3382  }
3383 
3384  /// Entry point that transposes into the common form:
3385  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3386  FailureOr<Operation *> generateNcwPooling() {
3387  AffineExpr n, w, c, kw;
3388  bindDims(ctx, n, c, w, kw);
3389  if (!iters({Par(), Par(), Par(), Red()}))
3390  return rewriter.notifyMatchFailure(op,
3391  "failed to match pooling 3-par 1-red");
3392 
3393  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3394  /*rhsIndex*/ {kw},
3395  /*resIndex*/ {n, c, w}}))
3396  return conv(Conv1DOpOrder::Ncw);
3397 
3398  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3399  }
3400 
3401  /// Entry point that transposes into the common form:
3402  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3403  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
3404  AffineExpr n, w, c, kw;
3405  bindDims(ctx, n, w, c, kw);
3406  if (!iters({Par(), Par(), Par(), Red()}))
3407  return rewriter.notifyMatchFailure(
3408  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3409 
3410  // No transposition needed.
3411  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3412  /*rhsIndex*/ {kw, c},
3413  /*resIndex*/ {n, w, c}}))
3414  return depthwiseConv(flatten);
3415 
3416  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3417  }
3418 
3419 private:
3420  enum OperKind { Conv, Pool };
3421  bool valid = false;
3422  OperKind oper = Conv;
3423  StringAttr redOp;
3424  StringAttr poolExtOp;
3425  bool isPoolExt = false;
3426  int strideW, dilationW;
3427  Value lhsShaped, rhsShaped, resShaped;
3428  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3429 
3430  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3431  // Returns true iff it is a valid conv/pooling op.
3432  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3433  // reduction + yield) and rhs is not used, then it is the body of a pooling.
3434  // If it is a conv, check for a single `mul` predecessor. The `mul` operands
3435  // must be block arguments or extensions of block arguments.
3436  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
3437  // must be block arguments or extensions of block arguments.
3438  bool setOperKind(Operation *reduceOp) {
3439  int numBlockArguments = llvm::count_if(
3440  reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
3441  switch (numBlockArguments) {
3442  case 1: {
3443  // It will be a convolution if the feeder is a MulOp.
3444  // Otherwise, it can be pooling.
3445  auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
3446  return !isa<BlockArgument>(v);
3447  });
3448  Operation *feedOp = (*feedValIt).getDefiningOp();
3449  if (isCastOfBlockArgument(feedOp)) {
3450  oper = Pool;
3451  isPoolExt = true;
3452  poolExtOp = feedOp->getName().getIdentifier();
3453  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3454  llvm::all_of(feedOp->getOperands(), [](Value v) {
3455  if (isa<BlockArgument>(v))
3456  return true;
3457  if (Operation *op = v.getDefiningOp())
3458  return isCastOfBlockArgument(op);
3459  return false;
3460  }))) {
3461  return false;
3462  }
3463  return true;
3464  }
3465  case 2:
3466  // Must be pooling
3467  oper = Pool;
3468  isPoolExt = false;
3469  return true;
3470  default:
3471  return false;
3472  }
3473  }
3474 };
3475 } // namespace
3476 
3477 /// Helper function to vectorize a LinalgOp with convolution semantics.
3478 // TODO: extend the generic vectorization to support windows and drop this.
3479 static FailureOr<Operation *>
3480 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
3481  bool flatten1DDepthwiseConv) {
3482  // The ConvolutionOpInterface gives us guarantees of existence for
3483  // strides/dilations. However, we do not need to rely on those, we can
3484  // simply use them if present, otherwise use the default and let the generic
3485  // conv. matcher in the ConvGenerator succeed or fail.
3486  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3487  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3488  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3489  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3490  Conv1DGenerator e(rewriter, op, stride, dilation);
3491  auto res = e.generateNonChanneledConv();
3492  if (succeeded(res))
3493  return res;
3494  res = e.generateNwcConv();
3495  if (succeeded(res))
3496  return res;
3497  res = e.generateNcwConv();
3498  if (succeeded(res))
3499  return res;
3500  res = e.generateNwcPooling();
3501  if (succeeded(res))
3502  return res;
3503  res = e.generateNcwPooling();
3504  if (succeeded(res))
3505  return res;
3506  return e.generateDilatedConv(flatten1DDepthwiseConv);
3507 }
3508 
3509 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3510  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3511 
3512  LogicalResult matchAndRewrite(LinalgOp op,
3513  PatternRewriter &rewriter) const override {
3514  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3515  if (failed(resultOrFail))
3516  return failure();
3517  Operation *newOp = *resultOrFail;
3518  if (newOp->getNumResults() == 0) {
3519  rewriter.eraseOp(op.getOperation());
3520  return success();
3521  }
3522  assert(newOp->getNumResults() == 1 && "expected single result");
3523  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3524  return success();
3525  }
3526 };
3527 
3528 void mlir::linalg::populateConvolutionVectorizationPatterns(
3529  RewritePatternSet &patterns, PatternBenefit benefit) {
3530  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3531 }
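// A minimal usage sketch (illustrative only, assuming the standard greedy
// driver from mlir/Transforms/GreedyPatternRewriteDriver.h; not part of the
// original source):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
//
// This registers VectorizeConvolution alongside any other patterns and lets
// the driver rewrite every matching 1-D convolution or pooling op.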