MLIR  17.0.0git
Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Support/LLVM.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include <optional>
35 #include <type_traits>
36 
37 using namespace mlir;
38 using namespace mlir::linalg;
39 
40 #define DEBUG_TYPE "linalg-vectorization"
41 
42 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
43 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
44 
45 /// Try to vectorize `convOp` as a convolution.
46 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
47                                                    LinalgOp convOp);
48 
49 /// Return the unique instance of OpType in `block` if it is indeed unique.
50 /// Return null if none or more than one instance exists.
51 template <typename OpType>
52 static OpType getSingleOpOfType(Block &block) {
53  OpType res;
54  block.walk([&](OpType op) {
55  if (res) {
56  res = nullptr;
57  return WalkResult::interrupt();
58  }
59  res = op;
60  return WalkResult::advance();
61  });
62  return res;
63 }
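// For illustration, a hypothetical use of the helper above: grab the single
// combiner op of a reduction body, if any.
//   if (auto add = getSingleOpOfType<arith::AddFOp>(*linalgOp.getBlock()))
//     ...; // `add` is the only arith.addf in the body, or null otherwise.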
64 
65 /// Helper function to extract the input slices after filter is unrolled along
66 /// kw.
67 static SmallVector<Value>
68 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
69                        int64_t nSize, int64_t wSize, int64_t cSize,
70  int64_t kwSize, int strideW, int dilationW,
71  int64_t wSizeStep, bool isSingleChanneled) {
72  SmallVector<Value> result;
73  if (isSingleChanneled) {
74  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
75  // convolution.
76  SmallVector<int64_t> sizes{wSizeStep};
77  SmallVector<int64_t> strides{1};
78  for (int64_t kw = 0; kw < kwSize; ++kw) {
79  for (int64_t w = 0; w < wSize; w += wSizeStep) {
80  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
81  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
82  }
83  }
84  } else {
85  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
86  // for channeled convolution.
87  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
88  SmallVector<int64_t> strides{1, 1, 1};
89  for (int64_t kw = 0; kw < kwSize; ++kw) {
90  for (int64_t w = 0; w < wSize; w += wSizeStep) {
91  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
92  loc, input,
93  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
94  sizes, strides));
95  }
96  }
97  }
98  return result;
99 }
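// For illustration, with hypothetical sizes nSize = 1, wSize = 8, cSize = 4,
// kwSize = 3, strideW = 1, dilationW = 1 and wSizeStep = 4, the channeled
// branch above extracts slices of shape {1, 4, 4} at offsets [0, w + kw, 0]
// for (kw, w) iterating over {0, 1, 2} x {0, 4}, i.e. at w-offsets
// 0, 4, 1, 5, 2, 6.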
100 
101 /// Helper function to extract the filter slices after filter is unrolled along
102 /// kw.
103 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
104                                                   Location loc, Value filter,
105  int64_t kwSize) {
106  SmallVector<Value> result;
107  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
108 /// non-channeled convolution] @ [kw].
109  for (int64_t kw = 0; kw < kwSize; ++kw) {
110  result.push_back(rewriter.create<vector::ExtractOp>(
111  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
112  }
113  return result;
114 }
115 
116 /// Helper function to extract the result slices after filter is unrolled along
117 /// kw.
118 static SmallVector<Value>
119 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
120                         int64_t nSize, int64_t wSize, int64_t fSize,
121  int64_t wSizeStep, bool isSingleChanneled) {
122  SmallVector<Value> result;
123  if (isSingleChanneled) {
124  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
125  SmallVector<int64_t> sizes{wSizeStep};
126  SmallVector<int64_t> strides{1};
127  for (int64_t w = 0; w < wSize; w += wSizeStep) {
128  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
129  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
130  }
131  } else {
132  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
133  // convolution.
134  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
135  SmallVector<int64_t> strides{1, 1, 1};
136  for (int64_t w = 0; w < wSize; w += wSizeStep) {
137  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
138  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
139  }
140  }
141  return result;
142 }
143 
144 /// Helper function to insert the computed result slices.
145 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
146                                     Value res, int64_t wSize, int64_t wSizeStep,
147  SmallVectorImpl<Value> &resVals,
148  bool isSingleChanneled) {
149 
150  if (isSingleChanneled) {
151  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
152  // This does not depend on kw.
153  SmallVector<int64_t> strides{1};
154  for (int64_t w = 0; w < wSize; w += wSizeStep) {
155  res = rewriter.create<vector::InsertStridedSliceOp>(
156  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
157  }
158  } else {
159  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
160  // convolution. This does not depend on kw.
161  SmallVector<int64_t> strides{1, 1, 1};
162  for (int64_t w = 0; w < wSize; w += wSizeStep) {
163  res = rewriter.create<vector::InsertStridedSliceOp>(
164  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
165  strides);
166  }
167  }
168  return res;
169 }
170 
171 /// Contains the vectorization state and related methods used across the
172 /// vectorization process of a given operation.
173 struct VectorizationState {
174   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
175 
176  /// Initializes the vectorization state, including the computation of the
177  /// canonical vector shape for vectorization.
178  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
179  ArrayRef<int64_t> inputVectorSizes);
180 
181  /// Returns the canonical vector shape used to vectorize the iteration space.
182  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
183 
184  /// Masks an operation with the canonical vector mask if the operation needs
185  /// masking. Returns the masked operation or the original operation if masking
186  /// is not needed. If provided, the canonical mask for this operation is
187  /// permuted using `maybeMaskingMap`.
188  Operation *
189  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
190  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
191 
192 private:
193  /// Initializes the iteration space static sizes using the Linalg op
194  /// information. This may become more complicated in the future.
195  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
196  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
197  }
198 
199  /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
200  /// iteration space to be vectorized and stores them in
201  /// `iterSpaceDynamicSizes`.
202  LogicalResult precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
203  LinalgOp linalgOp);
204 
205  /// Create or retrieve an existing mask value to mask `opToMask` in the
206  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
207  /// mask is permuted using that permutation map. If a new mask is created, it
208  /// will be cached for future users.
209  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
210  LinalgOp linalgOp,
211  std::optional<AffineMap> maybeMaskingMap);
212 
213  // Holds the compile-time static sizes of the iteration space to vectorize.
214  // Dynamic dimensions are represented using ShapedType::kDynamic.
215  SmallVector<int64_t> iterSpaceStaticSizes;
216 
217  /// Holds the runtime sizes of the iteration spaces to vectorize. Static
218  /// dimensions are represented with an empty value.
219  SmallVector<Value> iterSpaceDynamicSizes;
220 
221  /// Holds the canonical vector shape used to vectorize the iteration space.
222  SmallVector<int64_t> canonicalVecShape;
223 
224  /// Holds the active masks for permutations of the canonical vector iteration
225  /// space.
226  DenseMap<AffineMap, Value> activeMaskCache;
227 
228  /// Global vectorization guard for the incoming rewriter. It's initialized
229  /// when the vectorization state is initialized.
230   OpBuilder::InsertionGuard rewriterGuard;
231 };
232 
233 /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
234 /// iteration space to be vectorized and stores them in
235 /// `iterSpaceDynamicSizes`.
236 LogicalResult
237 VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
238  LinalgOp linalgOp) {
239  // TODO: Support 0-d vectors.
240  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
241  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
242   // Add an empty value for static dimensions.
243  iterSpaceDynamicSizes.push_back(Value());
244  continue;
245  }
246 
247  // Find an operand defined on this dimension of the iteration space to
248  // extract the runtime dimension size.
249  Value operand;
250  unsigned operandDimPos;
251  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
252  operandDimPos)))
253  return failure();
254 
255  Value dynamicDim = linalgOp.hasTensorSemantics()
256  ? (Value)rewriter.create<tensor::DimOp>(
257  linalgOp.getLoc(), operand, operandDimPos)
258  : (Value)rewriter.create<memref::DimOp>(
259  linalgOp.getLoc(), operand, operandDimPos);
260  iterSpaceDynamicSizes.push_back(dynamicDim);
261  }
262 
263  return success();
264 }
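// For illustration, assuming a hypothetical iteration space (?, 8) whose
// dynamic dimension maps to dim 0 of a tensor<?x8xf32> operand %arg0, the loop
// above records something like:
//   %d0 = tensor.dim %arg0, %c0 : tensor<?x8xf32>
//   iterSpaceDynamicSizes = [%d0, <empty Value>]  // static dims stay empty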
265 
266 /// Initializes the vectorization state, including the computation of the
267 /// canonical vector shape for vectorization.
268 // TODO: Move this to the constructor when we can remove the failure cases.
269 LogicalResult
270 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
271  ArrayRef<int64_t> inputVectorSizes) {
272  // Initialize the insertion point.
273  rewriter.setInsertionPoint(linalgOp);
274 
275  if (!inputVectorSizes.empty()) {
276  // Get the canonical vector shape from the input vector sizes provided. This
277  // path should be taken to vectorize code with dynamic shapes and when using
278  // vector sizes greater than the iteration space sizes.
279  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
280  } else {
281  // Compute the canonical vector shape from the operation shape. If there are
282  // dynamic shapes, the operation won't be vectorized.
283  canonicalVecShape = linalgOp.getStaticLoopRanges();
284  }
285 
286  LDBG("Canonical vector shape: ");
287  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
288  LLVM_DEBUG(llvm::dbgs() << "\n");
289 
290  if (ShapedType::isDynamicShape(canonicalVecShape))
291  return failure();
292 
293  // Initialize iteration space static sizes.
294  initIterSpaceStaticSizes(linalgOp);
295 
296  // Extract and register the runtime value of any potential dynamic shape
297  // needed to compute a mask during vectorization.
298  if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
299  return failure();
300 
301  return success();
302 }
303 
304 /// Create or retrieve an existing mask value to mask `opToMask` in the
305 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
306 /// is permuted using that permutation map. If a new mask is created, it will be
307 /// cached for future users.
308 Value VectorizationState::getOrCreateMaskFor(
309  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
310  std::optional<AffineMap> maybeMaskingMap) {
311  // No mask is needed if the operation is not maskable.
312  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
313  if (!maskableOp)
314  return Value();
315 
316  assert(!maskableOp.isMasked() &&
317  "Masking an operation that is already masked");
318 
319  // If no masking map was provided, use an identity map with the loop dims.
320  assert((!maybeMaskingMap || *maybeMaskingMap) &&
321  "Unexpected null mask permutation map");
322  AffineMap maskingMap =
323  maybeMaskingMap ? *maybeMaskingMap
324                   : AffineMap::getMultiDimIdentityMap(
325                         linalgOp.getNumLoops(), rewriter.getContext());
326 
327  LDBG("Masking map: " << maskingMap << "\n");
328 
329  // Return the active mask for the masking map of this operation if it was
330  // already created.
331  auto activeMaskIt = activeMaskCache.find(maskingMap);
332  if (activeMaskIt != activeMaskCache.end()) {
333  Value mask = activeMaskIt->second;
334  LDBG("Reusing mask: " << mask << "\n");
335  return mask;
336  }
337 
338  // Compute permuted projection of the iteration space to be masked and the
339  // corresponding mask shape. If the resulting iteration space dimensions are
340  // static and identical to the mask shape, masking is not needed for this
341  // operation.
342  // TODO: Improve this check. Only projected permutation indexing maps are
343  // supported.
344  SmallVector<int64_t> permutedStaticSizes =
345  applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
346  SmallVector<int64_t> maskShape =
347  applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));
348  LDBG("Mask shape: ");
349  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
350  LLVM_DEBUG(llvm::dbgs() << "\n");
351 
352  if (permutedStaticSizes == maskShape) {
353  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
354  activeMaskCache[maskingMap] = Value();
355  return Value();
356  }
357 
358  // Compute the mask upper bound values by combining the permuted iteration
359  // space static sizes and the dynamic values.
360  SmallVector<Value> permutedDynamicSizes =
361  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
362  SmallVector<Value> upperBounds;
363  for (auto [staticBound, dynBound] :
364  llvm::zip(permutedStaticSizes, permutedDynamicSizes))
365  upperBounds.push_back(ShapedType::isDynamic(staticBound)
366  ? dynBound
367  : rewriter.create<arith::ConstantIndexOp>(
368  linalgOp.getLoc(), staticBound));
369 
370  assert(!maskShape.empty() && !upperBounds.empty() &&
371  "Masked 0-d vectors are not supported yet");
372 
373  // Create the mask based on the dimension size values.
374  auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
375  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
376  maskType, upperBounds);
377  LDBG("Creating new mask: " << mask << "\n");
378  activeMaskCache[maskingMap] = mask;
379  return mask;
380 }
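// For illustration, assuming a hypothetical canonical vector shape [8, 16], an
// identity masking map and an iteration space (%d0, 16) with %d0 dynamic, the
// code above materializes:
//   %c16  = arith.constant 16 : index
//   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
// and caches %mask under the identity map for reuse by later operations.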
381 
382 /// Masks an operation with the canonical vector mask if the operation needs
383 /// masking. Returns the masked operation or the original operation if masking
384 /// is not needed. If provided, the canonical mask for this operation is
385 /// permuted using `maybeMaskingMap`.
386 Operation *
387 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
388                                   LinalgOp linalgOp,
389  std::optional<AffineMap> maybeMaskingMap) {
390  LDBG("Trying to mask: " << *opToMask << "\n");
391 
392  // Create or retrieve mask for this operation.
393  Value mask =
394  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
395 
396  if (!mask) {
397  LDBG("No mask required\n");
398  return opToMask;
399  }
400 
401  // Wrap the operation with a new `vector.mask` and update D-U chain.
402  assert(opToMask && "Expected a valid operation to mask");
403  auto maskOp = cast<vector::MaskOp>(
404  mlir::vector::maskOperation(rewriter, opToMask, mask));
405  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
406 
407  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
408  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
409  maskOpTerminator);
410 
411  LDBG("Masked operation: " << *maskOp << "\n");
412  return maskOp;
413 }
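// For illustration, wrapping a hypothetical transfer_read with the mask from
// the previous example yields IR roughly of the form:
//   %0 = vector.mask %mask {
//          vector.transfer_read %src[%c0, %c0], %pad
//              : tensor<?x16xf32>, vector<8x16xf32>
//        } : vector<8x16xi1> -> vector<8x16xf32>
// and all uses of the original result are redirected to %0.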
414 
415 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
416 /// projectedPermutation, compress the unused dimensions to serve as a
417 /// permutation_map for a vector transfer operation.
418 /// For example, given a linalg op such as:
419 ///
420 /// ```
421 /// %0 = linalg.generic {
422 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
423 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
424 /// }
425 /// ins(%0 : tensor<2x3x4xf32>)
426 /// outs(%1 : tensor<5x6xf32>)
427 /// ```
428 ///
429 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
430 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
431 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
432 static AffineMap reindexIndexingMap(AffineMap map) {
433   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
434  "expected projected permutation");
435  auto res = compressUnusedDims(map);
436  assert(res.getNumDims() == res.getNumResults() &&
437  "expected reindexed map with same number of dims and results");
438  return res;
439 }
440 
441 /// Helper enum to represent conv1d input traversal order.
442 enum class Conv1DOpOrder {
443  W, // Corresponds to non-channeled 1D convolution operation.
444  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
445  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
446 };
447 
448 /// Helper data structure to represent the result of vectorization.
449 /// In certain specific cases, like terminators, we do not want to propagate.
450 enum VectorizationStatus {
451   /// Op failed to vectorize.
452   Failure = 0,
453   /// Op vectorized and custom function took care of replacement logic
454   NoReplace,
455   /// Op vectorized into a new Op whose results will replace original Op's
456   /// results.
457   NewOp
458   // TODO: support values if Op vectorized to Many-Ops whose results we need to
459   // aggregate for replacement.
460 };
461 struct VectorizationResult {
462   /// Return status from vectorizing the current op.
463   enum VectorizationStatus status;
464   /// New vectorized operation to replace the current op.
465   /// Replacement behavior is specified by `status`.
466   Operation *newOp;
467 };
468 
469 std::optional<vector::CombiningKind>
470 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
471   using ::mlir::vector::CombiningKind;
472 
473  if (!combinerOp)
474  return std::nullopt;
475   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
476       .Case<arith::AddIOp, arith::AddFOp>(
477  [&](auto op) { return CombiningKind::ADD; })
478  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
479  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
480  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
481  .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
482  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
483  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
484  .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
485  .Case<arith::MulIOp, arith::MulFOp>(
486  [&](auto op) { return CombiningKind::MUL; })
487  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
488  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
489  .Default([&](auto op) { return std::nullopt; });
490 }
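// For illustration, in a hypothetical reduction body such as
//   ^bb0(%in: f32, %acc: f32):
//     %0 = arith.addf %in, %acc : f32
//     linalg.yield %0 : f32
// the combiner is the arith.addf, which the switch above maps to
// CombiningKind::ADD.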
491 
492 /// Check whether `outputOperand` is a reduction with a single combiner
493 /// operation. Return the combiner operation of the reduction. Return
494 /// nullptr otherwise. Multiple reduction operations would impose an
495 /// ordering between reduction dimensions and is currently unsupported in
496 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
497 /// max(min(X))
498 // TODO: use in LinalgOp verification, there is a circular dependency atm.
499 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
500  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
501  unsigned outputPos =
502  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
503  // Only single combiner operations are supported for now.
504  SmallVector<Operation *, 4> combinerOps;
505  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
506  combinerOps.size() != 1)
507  return nullptr;
508 
509  // Return the combiner operation.
510  return combinerOps[0];
511 }
512 
513 /// Broadcast `value` to a vector of `shape` if possible. Return value
514 /// otherwise.
515 static Value broadcastIfNeeded(OpBuilder &b, Value value,
516                                ArrayRef<int64_t> shape) {
517  // If no shape to broadcast to, just return `value`.
518  if (shape.empty())
519  return value;
520  VectorType targetVectorType =
521  VectorType::get(shape, getElementTypeOrSelf(value));
522  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
523       vector::BroadcastableToResult::Success)
524     return value;
525  Location loc = b.getInsertionPoint()->getLoc();
526  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
527 }
528 
529 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
530 /// assumes that `reductionOp` has two operands and one of them is the reduction
531 /// initial value.
532 // Note: this is a true builder that notifies the OpBuilder listener.
533 // TODO: Consider moving as a static helper on the ReduceOp.
534 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
535                                       Value valueToReduce, Value acc,
536  ArrayRef<bool> dimsToMask) {
537  auto maybeKind = getCombinerOpKind(reduceOp);
538  assert(maybeKind && "Failed precondition: could not get reduction kind");
539  return b.create<vector::MultiDimReductionOp>(
540  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
541 }
542 
543 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
544  return llvm::to_vector(
545  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
546 }
547 
548 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
549 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
550 /// currently being vectorized. If `dest` has rank 0, build a memref.store.
551 /// Return the produced value or null if no value is produced.
552 // Note: this is a true builder that notifies the OpBuilder listener.
553 // TODO: Consider moving as a static helper on the ReduceOp.
554 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
555  OpOperand *outputOperand,
556  VectorizationState &state) {
557  Location loc = value.getLoc();
558  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
559  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
560  auto vectorType =
561  VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
562  getElementTypeOrSelf(outputOperand->get().getType()));
563 
564  Operation *write;
565  if (vectorType.getRank() > 0) {
566  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
567  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
568  rewriter.create<arith::ConstantIndexOp>(loc, 0));
569  value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
570  write = rewriter.create<vector::TransferWriteOp>(
571  loc, value, outputOperand->get(), indices, writeMap);
572  } else {
573  // 0-d case is still special: do not invert the reindexing writeMap.
574  if (!value.getType().isa<VectorType>())
575  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
576  assert(value.getType() == vectorType && "incorrect type");
577  write = rewriter.create<vector::TransferWriteOp>(
578  loc, value, outputOperand->get(), ValueRange{});
579  }
580 
581  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
582 
583  // If masked, set in-bounds to true. Masking guarantees that the access will
584  // be in-bounds.
585  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
586  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
587  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
588  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
589  }
590 
591  LDBG("vectorized op: " << *write << "\n");
592  if (!write->getResults().empty())
593  return write->getResult(0);
594  return Value();
595 }
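// For illustration, writing a hypothetical vector<4x8xf32> value into a
// tensor<4x8xf32> init operand with an identity indexing map produces:
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//       : vector<4x8xf32>, tensor<4x8xf32>
// possibly wrapped in a vector.mask when the iteration space is dynamic.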
596 
597 // Custom vectorization precondition function type. This is intended to be used
598 // with CustomVectorizationHook. Returns success if the corresponding custom
599 // hook can vectorize the op.
600 using CustomVectorizationPrecondition =
601     std::function<LogicalResult(Operation *, bool)>;
602 
603 // Custom vectorization function type. Produce a vector form of Operation*
604 // assuming all its vectorized operands are already in the IRMapping.
605 // Return nullptr if the Operation cannot be vectorized.
606 using CustomVectorizationHook =
607     std::function<VectorizationResult(Operation *, const IRMapping &)>;
608 
609 /// Helper function to vectorize the terminator of a `linalgOp`. New result
610 /// vector values are appended to `newResults`. Return
611 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
612 /// should not try to map produced operations and instead return the results
613 /// using the `newResults` vector making them available to the vectorization
614 /// algorithm for RAUW. This function is meant to be used as a
615 /// CustomVectorizationHook.
616 static VectorizationResult
617 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
618                      const IRMapping &bvm, VectorizationState &state,
619  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
620  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
621  if (!yieldOp)
622     return VectorizationResult{VectorizationStatus::Failure, nullptr};
623   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
624  // TODO: Scan for an opportunity for reuse.
625  // TODO: use a map.
626  Value vectorValue = bvm.lookup(output.value());
627  Value newResult =
628  buildVectorWrite(rewriter, vectorValue,
629  linalgOp.getDpsInitOperand(output.index()), state);
630  if (newResult)
631  newResults.push_back(newResult);
632  }
633 
634   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
635 }
636 
637 /// Helper function to vectorize the index operations of a `linalgOp`. Return
638 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
639 /// should map the produced operations. This function is meant to be used as a
640 /// CustomVectorizationHook.
641 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
642                                                 VectorizationState &state,
643  Operation *op,
644  LinalgOp linalgOp) {
645  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
646  if (!indexOp)
647     return VectorizationResult{VectorizationStatus::Failure, nullptr};
648   auto loc = indexOp.getLoc();
649  // Compute the static loop sizes of the index op.
650  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
651  // Compute a one-dimensional index vector for the index op dimension.
652  SmallVector<int64_t> constantSeq =
653  llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
654  auto constantOp = rewriter.create<arith::ConstantOp>(
655  loc, rewriter.getIndexVectorAttr(constantSeq));
656  // Return the one-dimensional index vector if it lives in the trailing
657  // dimension of the iteration space since the vectorization algorithm in this
658  // case can handle the broadcast.
659  if (indexOp.getDim() == targetShape.size() - 1)
660     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
661   // Otherwise permute the targetShape to move the index dimension last,
662  // broadcast the one-dimensional index vector to the permuted shape, and
663  // finally transpose the broadcasted index vector to undo the permutation.
664  std::swap(targetShape[indexOp.getDim()], targetShape.back());
665  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
666  loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp);
667  SmallVector<int64_t> transposition =
668  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
669  std::swap(transposition.back(), transposition[indexOp.getDim()]);
670  auto transposeOp =
671  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
672  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
673 }
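// For illustration, with a hypothetical 4x8 canonical vector shape, vectorizing
// `%i = linalg.index 0` produces a constant [0, 1, 2, 3] : vector<4xindex>,
// broadcast to vector<8x4xindex> and then transposed to vector<4x8xindex>, so
// that element (i, j) holds the value i.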
674 
675 /// Helper function to check if the tensor.extract can be vectorized by the
676 /// custom hook vectorizeTensorExtract.
677 static LogicalResult
678 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
679  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
680  if (!extractOp)
681  return failure();
682 
683  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
684  return failure();
685 
686  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
687  return failure();
688 
689  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
690  return !VectorType::isValidElementType(type);
691  })) {
692  return failure();
693  }
694 
695  return success();
696 }
697 
698 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
699 /// generated from `tensor.extract`. The offset is calculated as follows
700 /// (example using scalar values):
701 ///
702 /// offset = extractOp.indices[0]
703 /// for (i = 1; i < numIndices; i++)
704 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
705 ///
706 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
707 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
708 static Value calculateGatherOffset(RewriterBase &rewriter,
709                                    tensor::ExtractOp extractOp,
710  const IRMapping &bvm,
711  const ArrayRef<int64_t> targetShape) {
712  // The vector of indices for GatherOp should be shaped as the output vector
713  auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
714  auto loc = extractOp.getLoc();
715 
716  Value offset = broadcastIfNeeded(
717  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
718 
719  const size_t numIndices = extractOp.getIndices().size();
720  for (size_t i = 1; i < numIndices; i++) {
721  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
722 
723  auto dimSize = broadcastIfNeeded(
724  rewriter,
725  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
726  indexVecType.getShape());
727 
728  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
729 
730  auto extractOpIndex =
731  broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
732  indexVecType.getShape());
733 
734  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
735  }
736 
737  return offset;
738 }
739 
740 enum VectorMemoryAccessKind {
741   // TODO: ScalarBroadcast,
742   Contiguous,
743   Gather
744 };
745 
746 /// Checks whether \p val can be used for calculating a loop invariant index.
747 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
748 
749  auto targetShape = linalgOp.getStaticLoopRanges();
750  assert(((llvm::count_if(targetShape,
751  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
752  "n-D vectors are not yet supported");
753  assert(targetShape.back() != 1 &&
754          "1-D vectors with the trailing dim equal 1 are not yet supported");
755 
756  // Blocks outside _this_ linalg.generic are effectively loop invariant.
757  // However, analysing block arguments for _this_ linalg.generic Op is a bit
758  // tricky. Just bail out in the latter case.
759  // TODO: We could try analysing the corresponding affine map here.
760  auto *block = linalgOp.getBlock();
761  if (isa<BlockArgument>(val))
762  return llvm::all_of(block->getArguments(),
763  [&val](Value v) { return (v != val); });
764 
765  Operation *defOp = val.getDefiningOp();
766  assert(defOp && "This is neither a block argument nor an operation result");
767 
768  // IndexOp is loop invariant as long as its result remains constant across
769  // iterations. Given the assumptions on the loop ranges above, only the
770  // trailing loop dim ever changes.
771  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
772  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
773  return (indexOp.getDim() != trailingLoopDim);
774 
775  auto *ancestor = block->findAncestorOpInBlock(*defOp);
776 
777   // Values defined outside `linalgOp` are loop invariant.
778  if (!ancestor)
779  return true;
780 
781  // Values defined inside `linalgOp`, which are constant, are loop invariant.
782  if (isa<arith::ConstantOp>(ancestor))
783  return true;
784 
785  bool result = true;
786  for (auto op : ancestor->getOperands())
787  result &= isLoopInvariantIdx(linalgOp, op);
788 
789  return result;
790 }
791 
792 /// Check whether \p val could be used for calculating the trailing index for a
793 /// contiguous load operation.
794 ///
795 /// There are currently 3 types of values that are allowed here:
796 /// 1. loop-invariant values,
797 /// 2. values that increment by 1 with every loop iteration,
798 /// 3. results of basic arithmetic operations (linear and continuous)
799 /// involving 1., 2. and 3.
800 /// This method returns True if indeed only such values are used in calculating
801 /// \p val.
802 ///
803 /// Additionally, the trailing index for a contiguous load operation should
804 /// increment by 1 with every loop iteration, i.e. be based on:
805 /// * `linalg.index <dim>` ,
806 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
807 /// updated to `true` when such an op is found.
808 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
809  bool &foundIndexOp) {
810 
811  auto targetShape = linalgOp.getStaticLoopRanges();
812  assert(((llvm::count_if(targetShape,
813  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
814  "n-D vectors are not yet supported");
815  assert(targetShape.back() != 1 &&
816  "1-D vectors with the trailing dim 1 are not yet supported");
817 
818  // Blocks outside _this_ linalg.generic are effectively loop invariant.
819  // However, analysing block arguments for _this_ linalg.generic Op is a bit
820  // tricky. Just bail out in the latter case.
821  // TODO: We could try analysing the corresponding affine map here.
822  auto *block = linalgOp.getBlock();
823  if (isa<BlockArgument>(val))
824  return llvm::all_of(block->getArguments(),
825  [&val](Value v) { return (v != val); });
826 
827  Operation *defOp = val.getDefiningOp();
828  assert(defOp && "This is neither a block argument nor an operation result");
829 
830  // Given the assumption on the loop ranges above, only the trailing loop
831  // index is not constant.
832  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
833  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
834  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
835  return true;
836  }
837 
838  auto *ancestor = block->findAncestorOpInBlock(*defOp);
839 
840  if (!ancestor)
841  return false;
842 
843  // Conservatively reject Ops that could lead to indices with stride other
844  // than 1.
845  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
846  ancestor))
847  return false;
848 
849  bool result = false;
850  for (auto op : ancestor->getOperands())
851  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
852 
853  return result;
854 }
855 
856 /// Check whether \p extractOp would be a gather or a contiguous load Op after
857 /// vectorising \p linalgOp. Note that it is always safe to use gather load
858 /// operations for contiguous loads (albeit slow), but not vice-versa. When in
859 /// doubt, bail out and assume that \p extractOp is a gather load.
860 static VectorMemoryAccessKind
861 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
862  LinalgOp &linalgOp) {
863 
864  auto targetShape = linalgOp.getStaticLoopRanges();
865 
866  // 1. Assume that it's a gather load when reading _into_:
867   // * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
868   // * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
869  // TODO: Relax these conditions.
870  if ((llvm::count_if(targetShape,
871  [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
872       targetShape.back() == 1)
873     return VectorMemoryAccessKind::Gather;
874 
875  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
876 
877  // 2. Assume that it's a gather load when reading _from_ a tensor for which
878  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
879  // TODO: Relax this condition.
880  if (inputShape.getShape().back() == 1)
881     return VectorMemoryAccessKind::Gather;
882 
883  bool isContiguous = true;
884 
885  // 3a. Analyze the leading indices of `extractOp`.
886  // Look at the way each index is calculated and decide whether it is suitable
887  // for a contiguous load, i.e. whether it's loop invariant.
888  auto indices = extractOp.getIndices();
889  auto leadIndices = ValueRange(indices.drop_back(1));
890 
891  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
892  if (inputShape.getShape()[i] == 1)
893  continue;
894 
895  isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
896  }
897 
898  // 3b. Analyze the trailing index for `extractOp`.
899  auto extractOpTrailingIdx = indices.back();
900  // For contiguous loads, the trailing `extractOp` index should increment with
901  // every loop iteration. This effectively means that it must be based on the
902  // trailing loop index. This is what the following bool captures.
903  bool foundIndexOp = false;
904  isContiguous &=
905  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
906  isContiguous &= foundIndexOp;
907 
908   if (isContiguous) {
909     LDBG("Found contiguous load: " << extractOp);
910     return VectorMemoryAccessKind::Contiguous;
911   }
912 
913   return VectorMemoryAccessKind::Gather;
914 }
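// For illustration, in a hypothetical generic with iteration space (1, 1, 4)
// reading from %t : tensor<3x4xf32>, an extract of the form
//   %i0 = <loop-invariant index>
//   %i1 = linalg.index 2
//   %e  = tensor.extract %t[%i0, %i1] : tensor<3x4xf32>
// is classified as Contiguous, whereas an extract whose trailing index is
// itself loaded from another tensor falls back to Gather.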
915 
916 /// Helper function to vectorize the tensor.extract operations. Returns
917 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
918 /// should map the produced operations. This function is meant to be used as a
919 /// CustomVectorizationHook.
920 static VectorizationResult
921 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
922                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
923  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
924  if (!extractOp)
925     return VectorizationResult{VectorizationStatus::Failure, nullptr};
926   auto loc = extractOp.getLoc();
927 
928  // Compute the static loop sizes of the extract op.
929  auto targetShape = state.getCanonicalVecShape();
930 
931  auto resultType =
932  VectorType::get(targetShape, extractOp.getResult().getType());
933  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
934       loc, DenseIntElementsAttr::get(
935                VectorType::get(targetShape, rewriter.getI1Type()),
936  /*value=*/true));
937  auto passThruConstantOp =
938  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
939 
940  // Base indices are currently set to 0. We will need to re-visit if more
941  // generic scenarios are to be supported.
942  SmallVector<Value> baseIndices(
943  extractOp.getIndices().size(),
944  rewriter.create<arith::ConstantIndexOp>(loc, 0));
945 
946  VectorMemoryAccessKind memAccessKind =
947  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
948 
949  // 1. Handle gather access
950  if (memAccessKind == VectorMemoryAccessKind::Gather) {
951  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
952 
953  // Generate the gather load
954  Operation *gatherOp = rewriter.create<vector::GatherOp>(
955  loc, resultType, extractOp.getTensor(), baseIndices, offset,
956  maskConstantOp, passThruConstantOp);
957  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
958 
959  LDBG("Vectorised as gather load: " << extractOp);
960     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
961   }
962 
963  // 2. Handle contiguous access.
964  SmallVector<Value> transferReadIdxs;
965  auto resTrailingDim = resultType.getShape().back();
966  auto zero = rewriter.create<arith::ConstantOp>(
967  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
968 
969  // Collect indices for `vector.transfer_read`. At this point, the indices will
970  // either be scalars or would have been broadcast to vectors matching the
971  // result type. For indices that are vectors, there are two options:
972  // * for non-trailing indices, all elements are identical (contiguous
973  // loads are identified by looking for non-trailing indices that are
974  // invariant with respect to the corresponding linalg.generic), or
975  // * for trailing indices, the index vector will contain values with stride
976  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
977  // needed.
978  // This means that
979  // * for scalar indices - just re-use it,
980  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
981  // (0th) element and use that.
982  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
983  auto idx = bvm.lookup(extractOp.getIndices()[i]);
984  if (idx.getType().isIndex()) {
985  transferReadIdxs.push_back(idx);
986  continue;
987  }
988 
989  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
990  loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
991  bvm.lookup(extractOp.getIndices()[i]));
992  transferReadIdxs.push_back(
993  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
994  }
995 
996   // `tensor.extract` is always in-bounds, hence the following holds.
997  SmallVector<bool> inBounds(resultType.getRank(), true);
998 
999  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1000  loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
1001 
1002  LDBG("Vectorised as contiguous load: " << extractOp);
1003  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1004 }
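// For illustration, with a hypothetical vector<1x1x4xf32> result, the gather
// path above emits roughly:
//   %g = vector.gather %t[%c0, %c0] [%offsets], %all_true, %zeroes
//       : tensor<3x4xf32>, vector<1x1x4xindex>, vector<1x1x4xi1>,
//         vector<1x1x4xf32> into vector<1x1x4xf32>
// while the contiguous path emits a plain vector.transfer_read whose indices
// are the scalars extracted above.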
1005 
1006 /// Emit reduction operations if the shape of the value to reduce differs from
1007 /// the result shape.
1008 // Note: this is a true builder that notifies the OpBuilder listener.
1009 // TODO: Consider moving as a static helper on the ReduceOp.
1010 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1011  Value reduceValue, Value initialValue,
1012  const IRMapping &bvm) {
1013  Value reduceVec = bvm.lookup(reduceValue);
1014  Value outputVec = bvm.lookup(initialValue);
1015  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
1016  auto outputType = outputVec.getType().dyn_cast<VectorType>();
1017   // Reduce only if needed as the value may already have been reduced for
1018  // contraction vectorization.
1019  if (!reduceType ||
1020  (outputType && reduceType.getShape() == outputType.getShape()))
1021  return nullptr;
1022  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1023  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1024 }
1025 
1026 /// Generic vectorization for a single operation `op`, given already vectorized
1027 /// operands carried by `bvm`. Vectorization occurs as follows:
1028 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1029 /// result on success.
1030 /// 2. Clone any constant in the current scope without vectorization: each
1031 /// consumer of the constant will later determine the shape to which the
1032 /// constant needs to be broadcast to.
1033 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1034 /// of the `customVectorizationHooks` to cover such cases.
1035 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1036 /// operand of maximal rank. Other operands have smaller rank and are
1037 /// broadcast accordingly. It is assumed this broadcast is always legal,
1038 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1039 ///
1040 /// This function assumes all operands of `op` have been vectorized and are in
1041 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1042 /// a topologically-sorted list of ops.
1043 /// This function does not update `bvm` but returns a VectorizationStatus that
1044 /// instructs the caller what `bvm` update needs to occur.
1045 static VectorizationResult
1046 vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
1047  const IRMapping &bvm,
1048  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1049  LDBG("vectorize op " << *op << "\n");
1050 
1051  // 1. Try to apply any CustomVectorizationHook.
1052  if (!customVectorizationHooks.empty()) {
1053  for (auto &customFunc : customVectorizationHooks) {
1054  VectorizationResult result = customFunc(op, bvm);
1055  if (result.status == VectorizationStatus::Failure)
1056  continue;
1057  return result;
1058  }
1059  }
1060 
1061  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1062   // Clone so that the constant is not confined to the linalgOp block.
1063  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1064  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1065 
1066  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1067   if (!OpTrait::hasElementwiseMappableTraits(op))
1068     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1069 
1070   // 4. Check if the operation is a reduction.
1071  SmallVector<std::pair<Value, Value>> reductionOperands;
1072  for (Value operand : op->getOperands()) {
1073  auto blockArg = operand.dyn_cast<BlockArgument>();
1074  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1075  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1076  continue;
1077  SmallVector<Operation *> reductionOps;
1078  Value reduceValue = matchReduction(
1079  linalgOp.getRegionOutputArgs(),
1080  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1081  if (!reduceValue)
1082  continue;
1083  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1084  }
1085  if (!reductionOperands.empty()) {
1086  assert(reductionOperands.size() == 1);
1087  Operation *reduceOp =
1088  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1089  reductionOperands[0].second, bvm);
1090  if (reduceOp)
1091     return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1092   }
1093 
1094  // 5. Generic vectorization path for ElementwiseMappable ops.
1095  // a. first get the first max ranked shape.
1096  SmallVector<int64_t, 4> firstMaxRankedShape;
1097  for (Value operand : op->getOperands()) {
1098  auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
1099  if (vt && firstMaxRankedShape.size() < vt.getShape().size())
1100  firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
1101  }
1102   // b. Broadcast each op if needed.
1103  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
1104  return firstMaxRankedShape.empty()
1105  ? bvm.lookup(v)
1106  : broadcastIfNeeded(rewriter, bvm.lookup(v),
1107  firstMaxRankedShape);
1108  });
1109  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1110  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
1111  return firstMaxRankedShape.empty()
1112  ? t
1113  : VectorType::get(firstMaxRankedShape, t);
1114  });
1115 
1116  // Build and return the new op.
1117  return VectorizationResult{
1118       VectorizationStatus::NewOp,
1119       rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1120  llvm::to_vector<4>(vectorizedOperands),
1121  llvm::to_vector<4>(returnTypes), op->getAttrs())};
1122 }
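// For illustration, a hypothetical elementwise body op
//   %0 = arith.mulf %a, %b : f32
// whose operands were vectorized to vector<4x8xf32> is cloned by the generic
// path above as
//   %0 = arith.mulf %a_vec, %b_vec : vector<4x8xf32>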
1123 
1124 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1125 /// vector form. Generic vectorization proceeds as follows:
1126 /// 1. Verify the `linalgOp` has one non-empty region.
1127 /// 2. Values defined above the region are mapped to themselves and will be
1128 /// broadcasted on a per-need basis by their consumers.
1129 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1130 /// load).
1131 /// TODO: Reuse opportunities for RAR dependencies.
1132 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1133 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1134 /// iteration indices.
1135 /// 5. Iteratively call vectorizeOneOp on the region operations.
1136 ///
1137 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1138 /// performed to the maximal common vector size implied by the `linalgOp`
1139 /// iteration space. This eager broadcasting is introduced in the
1140 /// permutation_map of the vector.transfer_read operations. The eager
1141 /// broadcasting makes it trivial to determine where broadcast, transposes and
1142 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1143 /// the absence of good canonicalizations, the amount of work increases.
1144 /// This is not deemed a problem as we expect canonicalizations and foldings to
1145 /// aggressively clean up the useless work.
1146 static LogicalResult
1147 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1148                          LinalgOp linalgOp,
1149  SmallVectorImpl<Value> &newResults) {
1150  LDBG("Vectorizing operation as linalg generic\n");
1151  Block *block = linalgOp.getBlock();
1152 
1153  // 2. Values defined above the region can only be broadcast for now. Make them
1154  // map to themselves.
1155  IRMapping bvm;
1156  SetVector<Value> valuesSet;
1157  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1158  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1159 
1160  if (linalgOp.getNumDpsInits() == 0)
1161  return failure();
1162 
1163  // 3. Turn all BBArgs into vector.transfer_read / load.
1164  Location loc = linalgOp.getLoc();
1165  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1166  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1167  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1168  if (linalgOp.isScalar(opOperand)) {
1169  bvm.map(bbarg, opOperand->get());
1170  continue;
1171  }
1172 
1173  // 3.a. Convert the indexing map for this input/output to a transfer read
1174  // permutation map and masking map.
1175  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1176 
1177  // Remove zeros from indexing map to use it as masking map.
1178  SmallVector<int64_t> zeroPos;
1179  auto results = indexingMap.getResults();
1180  for (const auto &result : llvm::enumerate(results)) {
1181  if (result.value().isa<AffineConstantExpr>()) {
1182  zeroPos.push_back(result.index());
1183  }
1184  }
1185  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1186 
1187  AffineMap readMap;
1188  SmallVector<int64_t> readVecShape;
1189  if (linalgOp.isDpsInput(opOperand)) {
1190  // 3.a.i. For input reads we use the canonical vector shape.
1191  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1192  readVecShape = llvm::to_vector(state.getCanonicalVecShape());
1193  } else {
1194  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1195  // reductions), the vector shape is computed by mapping the canonical
1196  // vector shape to the output domain and back to the canonical domain.
1197  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1198  readVecShape =
1199  readMap.compose(indexingMap.compose(state.getCanonicalVecShape()));
1200  }
1201 
1202  auto readType =
1203  VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
1204  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1205 
1206  Operation *read = rewriter.create<vector::TransferReadOp>(
1207  loc, readType, opOperand->get(), indices, readMap);
1208  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1209  Value readValue = read->getResult(0);
1210 
1211  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1212  // will be in-bounds.
1213  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1214  SmallVector<bool> inBounds(readType.getRank(), true);
1215  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1216  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1217  }
1218 
1219  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1220  // TODO: remove this.
1221  if (readValue.getType().cast<VectorType>().getRank() == 0)
1222  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1223 
1224  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1225  << "\n");
1226  bvm.map(bbarg, readValue);
1227  bvm.map(opOperand->get(), readValue);
1228  }
1229 
1230   SmallVector<CustomVectorizationHook> hooks;
1231   // 4a. Register CustomVectorizationHook for yieldOp.
1232  CustomVectorizationHook vectorizeYield =
1233  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1234  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1235  };
1236  hooks.push_back(vectorizeYield);
1237 
1238  // 4b. Register CustomVectorizationHook for indexOp.
1239  CustomVectorizationHook vectorizeIndex =
1240  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1241  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1242  };
1243  hooks.push_back(vectorizeIndex);
1244 
1245  // 4c. Register CustomVectorizationHook for extractOp.
1246  CustomVectorizationHook vectorizeExtract =
1247  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1248  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1249  };
1250  hooks.push_back(vectorizeExtract);
1251 
1252  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1253  for (Operation &op : block->getOperations()) {
1254  VectorizationResult result =
1255  vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
1256  if (result.status == VectorizationStatus::Failure) {
1257  LDBG("failed to vectorize: " << op << "\n");
1258  return failure();
1259  }
1260  if (result.status == VectorizationStatus::NewOp) {
1261  Operation *maybeMaskedOp =
1262  state.maskOperation(rewriter, result.newOp, linalgOp);
1263  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1264  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1265  }
1266  }
1267 
1268  return success();
1269 }
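// For illustration, a hypothetical elementwise addition expressed as a
// linalg.generic over tensor<8xf32> operands is rewritten along the lines of:
//   %lhs = vector.transfer_read %a[%c0], %pad : tensor<8xf32>, vector<8xf32>
//   %rhs = vector.transfer_read %b[%c0], %pad : tensor<8xf32>, vector<8xf32>
//   %sum = arith.addf %lhs, %rhs : vector<8xf32>
//   %res = vector.transfer_write %sum, %init[%c0]
//       : vector<8xf32>, tensor<8xf32>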
1270 
1271 // TODO: probably need some extra checks for reduction followed by consumer
1272 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1273 static LogicalResult reductionPreconditions(LinalgOp op) {
1274   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1275  LDBG("reduction precondition failed: no reduction iterator\n");
1276  return failure();
1277  }
1278  for (OpOperand *opOperand : op.getDpsInitOperands()) {
1279  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1280  if (indexingMap.isPermutation())
1281  continue;
1282 
1283  Operation *reduceOp = matchLinalgReduction(opOperand);
1284  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1285  LDBG("reduction precondition failed: reduction detection failed\n");
1286  return failure();
1287  }
1288  }
1289  return success();
1290 }
1291 
1292 static LogicalResult vectorizeDynamicLinalgOpPrecondition(LinalgOp op) {
1293   // TODO: Masking only supports dynamic generic ops for now.
1294  if (!isa<linalg::GenericOp, linalg::FillOp>(op))
1295  return failure();
1296 
1297  // TODO: Index vectorization assumes static shape.
1298  if (op.hasIndexSemantics())
1299  return failure();
1300 
1301  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1302  return success();
1303 }
1304 
1305 LogicalResult
1306 mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
1307                                             ArrayRef<int64_t> inputVectorSizes,
1308  bool vectorizeNDExtract) {
1309   // A tensor with a dimension of size 0 cannot be vectorized.
1310  if (llvm::any_of(linalgOp.getStaticShape(),
1311  [](int64_t dim) { return dim == 0; }))
1312  return failure();
1313  // Check API contract for input vector sizes.
1314  if (!inputVectorSizes.empty()) {
1315  assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
1316  "Input vector sizes don't match the number of loops");
1317  assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
1318  "Input vector sizes can't have dynamic dimensions");
1319  assert(
1320  llvm::all_of(
1321  llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
1322  [](std::tuple<int64_t, int64_t> sizePair) {
1323  int64_t staticSize = std::get<0>(sizePair);
1324  int64_t inputSize = std::get<1>(sizePair);
1325  return ShapedType::isDynamic(staticSize) ||
1326  staticSize <= inputSize;
1327  }) &&
1328  "Input vector sizes must be greater than or equal to iteration space "
1329  "static sizes");
1330  }
1331 
1332  if (linalgOp.hasDynamicShape() &&
1333       failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
1334     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1335  return failure();
1336  }
1337 
1338   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1339 
1340  // Register CustomVectorizationPrecondition for extractOp.
1341  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1342 
1343  // All types in the body should be a supported element type for VectorType.
1344  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1345  // Check if any custom hook can vectorize the inner op.
1346  if (llvm::any_of(
1347  customPreconditions,
1348  [&](const CustomVectorizationPrecondition &customPrecondition) {
1349  return succeeded(
1350  customPrecondition(&innerOp, vectorizeNDExtract));
1351  })) {
1352  continue;
1353  }
1354  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1355  return !VectorType::isValidElementType(type);
1356  })) {
1357  return failure();
1358  }
1359  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1360  return !VectorType::isValidElementType(type);
1361  })) {
1362  return failure();
1363  }
1364  }
1365  if (isElementwise(linalgOp))
1366  return success();
1367  // TODO: isaConvolutionOpInterface that can also infer from generic features.
1368  // But we will still need stride/dilation attributes that will be annoying to
1369  // reverse-engineer...
1370  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1371  return success();
1372  // TODO: the common vector shape is equal to the static loop sizes only when
1373  // all indexing maps are projected permutations. For convs and stencils the
1374  // logic will need to evolve.
1375  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1376  LDBG("precondition failed: not projected permutations\n");
1377  return failure();
1378  }
1379  if (failed(reductionPreconditions(linalgOp))) {
1380  LDBG("precondition failed: reduction preconditions\n");
1381  return failure();
1382  }
1383  return success();
1384 }
1385 
1386 /// Converts affine.apply Ops to arithmetic operations.
1387 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
1388  OpBuilder::InsertionGuard g(rewriter);
1389  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
1390 
1391  for (auto op : make_early_inc_range(toReplace)) {
1392  rewriter.setInsertionPoint(op);
1393  auto expanded = expandAffineExpr(
1394  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
1395  op.getOperands().take_front(op.getAffineMap().getNumDims()),
1396  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
1397  rewriter.replaceOp(op, expanded);
1398  }
1399 }
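// --- Illustrative example (not part of the original source) ---
// A sketch of the effect of convertAffineApply on a body that contains an
// affine.apply over an index %i (the exact op order produced by
// expandAffineExpr may differ):
//
//   Before:  %0 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i)
//   After:   %c3 = arith.constant 3 : index
//            %0  = arith.addi %i, %c3 : index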
1400 
1401 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
1402 /// are used to vectorize this operation. `inputVectorSizes` must match the rank
1403 /// of the iteration space of the operation and the input vector sizes must be
1404 /// greater than or equal to their counterpart iteration space sizes, if static.
1405 /// `inputVectorSizes` also allows the vectorization of operations with dynamic
1406 /// shapes.
1407 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
1408  ArrayRef<int64_t> inputVectorSizes,
1409  bool vectorizeNDExtract) {
1410  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
1411  LDBG("Input vector sizes: ");
1412  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
1413  LLVM_DEBUG(llvm::dbgs() << "\n");
1414 
1415  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1416  vectorizeNDExtract))) {
1417  LDBG("Vectorization pre-conditions failed\n");
1418  return failure();
1419  }
1420 
1421  // Initialize vectorization state.
1422  VectorizationState state(rewriter);
1423  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
1424  LDBG("Vectorization state couldn't be initialized\n");
1425  return failure();
1426  }
1427 
1428  SmallVector<Value> results;
1429  // TODO: isaConvolutionOpInterface that can also infer from generic
1430  // features. Will require stride/dilation attributes inference.
1431  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
1432  if (succeeded(convOr)) {
1433  llvm::append_range(results, (*convOr)->getResults());
1434  } else {
1435  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1436  vectorizeNDExtract)))
1437  return failure();
1438  LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
1439 
1440  // Pre-process before proceeding.
1441  convertAffineApply(rewriter, linalgOp);
1442 
1443  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
1444  // 'OpBuilder' when it is passed over to some methods like
1445  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
1446  // within these methods, the actual rewriter won't be notified and we will
1447  // end up with read-after-free issues!
1448  if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
1449  return failure();
1450  }
1451 
1452  if (!results.empty())
1453  rewriter.replaceOp(linalgOp, results);
1454  else
1455  rewriter.eraseOp(linalgOp);
1456 
1457  return success();
1458 }
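// --- Illustrative usage (not part of the original source) ---
// A minimal sketch of driving the entry point above; `funcOp` and `rewriter`
// (an IRRewriter) are hypothetical and would come from the calling pass or
// transform:
//
//   funcOp.walk([&](LinalgOp linalgOp) {
//     rewriter.setInsertionPoint(linalgOp);
//     // Empty vector sizes: derive the canonical vector shape from the
//     // static loop ranges; do not vectorize n-D tensor.extract.
//     (void)vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
//                     /*vectorizeNDExtract=*/false);
//   });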
1459 
1460 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
1461  memref::CopyOp copyOp) {
1462 
1463  auto srcType = copyOp.getSource().getType().cast<MemRefType>();
1464  auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
1465  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
1466  return failure();
1467 
1468  auto srcElementType = getElementTypeOrSelf(srcType);
1469  auto dstElementType = getElementTypeOrSelf(dstType);
1470  if (!VectorType::isValidElementType(srcElementType) ||
1471  !VectorType::isValidElementType(dstElementType))
1472  return failure();
1473 
1474  auto readType = VectorType::get(srcType.getShape(), srcElementType);
1475  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
1476 
1477  Location loc = copyOp->getLoc();
1478  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1479  SmallVector<Value> indices(srcType.getRank(), zero);
1480 
1481  Value readValue = rewriter.create<vector::TransferReadOp>(
1482  loc, readType, copyOp.getSource(), indices,
1483  rewriter.getMultiDimIdentityMap(srcType.getRank()));
1484  if (readValue.getType().cast<VectorType>().getRank() == 0) {
1485  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1486  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
1487  }
1488  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
1489  loc, readValue, copyOp.getTarget(), indices,
1490  rewriter.getMultiDimIdentityMap(srcType.getRank()));
1491  rewriter.replaceOp(copyOp, writeValue->getResults());
1492  return success();
1493 }
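// --- Illustrative example (not part of the original source) ---
// A sketch of what the rewrite above produces for a statically shaped
// memref.copy (the padding operand %pad stands for whatever the builder
// materializes):
//
//   Before:  memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
//   After:   %c0 = arith.constant 0 : index
//            %v  = vector.transfer_read %src[%c0, %c0], %pad
//                    : memref<4x8xf32>, vector<4x8xf32>
//            vector.transfer_write %v, %dst[%c0, %c0]
//                    : vector<4x8xf32>, memref<4x8xf32>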
1494 
1495 //----------------------------------------------------------------------------//
1496 // Misc. vectorization patterns.
1497 //----------------------------------------------------------------------------//
1498 
1499 /// Helper function that retrieves the value of an IntegerAttr.
1500 static int64_t getIntFromAttr(Attribute attr) {
1501  return attr.cast<IntegerAttr>().getInt();
1502 }
1503 
1504 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
1505 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
1506 /// not supported.
1507 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
1508  ArrayRef<OpFoldResult> ofrs) {
1509  SmallVector<Value> result;
1510  for (auto o : ofrs) {
1511  if (auto val = o.template dyn_cast<Value>()) {
1512  result.push_back(val);
1513  } else {
1514  result.push_back(rewriter.create<arith::ConstantIndexOp>(
1515  loc, getIntFromAttr(o.template get<Attribute>())));
1516  }
1517  }
1518  return result;
1519 }
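// --- Illustrative example (not part of the original source) ---
// Sketch: for ofrs = {Value %i, IntegerAttr 4}, the result is {%i, %c4},
// where %c4 is a newly created `arith.constant 4 : index`.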
1520 
1521 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
1522 /// InsertSliceOp. For now, only constant padding values are supported.
1523 /// If there is enough static type information, TransferReadOps and
1524 /// TransferWriteOps may be generated instead of InsertSliceOps.
1525 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
1526  GenericPadOpVectorizationPattern(MLIRContext *context,
1527  PatternBenefit benefit = 1)
1528  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
1529  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
1530  /// each dimension size is statically known in the source type or the result
1531  /// type (or both).
1532  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
1533  tensor::PadOp padOp, Value dest) {
1534  auto sourceType = padOp.getSourceType();
1535  auto resultType = padOp.getResultType();
1536  if (!VectorType::isValidElementType(sourceType.getElementType()))
1537  return failure();
1538 
1539  // Copy cannot be vectorized if pad value is non-constant and source shape
1540  // is dynamic. In case of a dynamic source shape, padding must be appended
1541  // by TransferReadOp, but TransferReadOp supports only constant padding.
1542  auto padValue = padOp.getConstantPaddingValue();
1543  if (!padValue) {
1544  if (!sourceType.hasStaticShape())
1545  return failure();
1546  // Create dummy padding value.
1547  auto elemType = sourceType.getElementType();
1548  padValue = rewriter.create<arith::ConstantOp>(
1549  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
1550  }
1551 
1552  SmallVector<int64_t> vecShape;
1553  SmallVector<bool> readInBounds;
1554  SmallVector<bool> writeInBounds;
1555  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
1556  if (!sourceType.isDynamicDim(i)) {
1557  vecShape.push_back(sourceType.getDimSize(i));
1558  // Source shape is statically known: neither read nor write is
1559  // out-of-bounds.
1560  readInBounds.push_back(true);
1561  writeInBounds.push_back(true);
1562  } else if (!resultType.isDynamicDim(i)) {
1563  // Source shape is not statically known, but result shape is.
1564  // Vectorize with size of result shape. This may be larger than the
1565  // source size.
1566  vecShape.push_back(resultType.getDimSize(i));
1567  // Read may be out-of-bounds because the result size could be larger
1568  // than the source size.
1569  readInBounds.push_back(false);
1570  // Write is out-of-bounds if low padding > 0.
1571  writeInBounds.push_back(
1572  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
1573  static_cast<int64_t>(0));
1574  } else {
1575  // Neither source nor result dim of padOp is static. Cannot vectorize
1576  // the copy.
1577  return failure();
1578  }
1579  }
1580  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
1581 
1582  // Generate TransferReadOp.
1583  SmallVector<Value> readIndices(
1584  vecType.getRank(),
1585  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
1586  auto read = rewriter.create<vector::TransferReadOp>(
1587  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
1588  ArrayRef<bool>{readInBounds});
1589 
1590  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
1591  // entire tensor, write directly to the FillOp's operand.
1592  if (llvm::equal(vecShape, resultType.getShape()) &&
1593  llvm::all_of(writeInBounds, [](bool b) { return b; }))
1594  if (auto fill = dest.getDefiningOp<FillOp>())
1595  dest = fill.output();
1596 
1597  // Generate TransferWriteOp.
1598  auto writeIndices =
1599  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
1600  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1601  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
1602 
1603  return success();
1604  }
1605 };
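// --- Illustrative example (not part of the original source) ---
// A sketch of the transfer ops emitted by tryVectorizeCopy when the source
// shape is fully static; %dest is the filled destination produced by the
// surrounding GeneralizePadOpPattern, and %cst is the constant pad value:
//
//   %p = tensor.pad %src low[0, 0] high[2, 3] { ... }
//          : tensor<6x5xf32> to tensor<8x8xf32>
//   // becomes (roughly):
//   %v = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true]}
//          : tensor<6x5xf32>, vector<6x5xf32>
//   %r = vector.transfer_write %v, %dest[%c0, %c0]
//          : vector<6x5xf32>, tensor<8x8xf32>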
1606 
1607 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
1608 /// given operation type OpTy.
1609 template <typename OpTy>
1610 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
1611  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
1612 
1613  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1614  PatternRewriter &rewriter) const final {
1615  bool changed = false;
1616  // Insert users in vector, because some users may be replaced/removed.
1617  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
1618  if (auto op = dyn_cast<OpTy>(user))
1619  changed |= rewriteUser(rewriter, padOp, op).succeeded();
1620  return success(changed);
1621  }
1622 
1623 protected:
1624  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
1625  tensor::PadOp padOp, OpTy op) const = 0;
1626 };
1627 
1628 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
1629 /// ```
1630 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
1631 /// %r = vector.transfer_read %0[%c0, %c0], %cst
1632 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
1633 /// ```
1634 /// is rewritten to:
1635 /// ```
1636 /// %r = vector.transfer_read %src[%c0, %c0], %padding
1637 /// {in_bounds = [true, true]}
1638 /// : tensor<?x?xf32>, vector<17x5xf32>
1639 /// ```
1640 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
1641 /// sure that the original padding value %cst was never used.
1642 ///
1643 /// This rewrite is possible if:
1644 /// - `xferOp` has no out-of-bounds dims or mask.
1645 /// - Low padding is static 0.
1646 /// - Single, scalar padding value.
1647 struct PadOpVectorizationWithTransferReadPattern
1648  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
1649  using VectorizePadOpUserPattern<
1650  vector::TransferReadOp>::VectorizePadOpUserPattern;
1651 
1652  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1653  vector::TransferReadOp xferOp) const override {
1654  // Low padding must be static 0.
1655  if (!padOp.hasZeroLowPad())
1656  return failure();
1657  // Pad value must be a constant.
1658  auto padValue = padOp.getConstantPaddingValue();
1659  if (!padValue)
1660  return failure();
1661  // Padding value of existing `xferOp` is unused.
1662  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
1663  return failure();
1664 
1665  rewriter.updateRootInPlace(xferOp, [&]() {
1666  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
1667  xferOp->setAttr(xferOp.getInBoundsAttrName(),
1668  rewriter.getBoolArrayAttr(inBounds));
1669  xferOp.getSourceMutable().assign(padOp.getSource());
1670  xferOp.getPaddingMutable().assign(padValue);
1671  });
1672 
1673  return success();
1674  }
1675 };
1676 
1677 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
1678 /// This pattern rewrites TransferWriteOps that write to a padded tensor
1679 /// value, where the same amount of padding is immediately removed again after
1680 /// the write. In such cases, the TransferWriteOp can write to the non-padded
1681 /// tensor value and apply out-of-bounds masking. E.g.:
1682 /// ```
1683 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
1684 /// : tensor<...> to tensor<?x?xf32>
1685 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
1686 /// %2 = vector.transfer_write %vec, %1[...]
1687 /// : vector<17x5xf32>, tensor<17x5xf32>
1688 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
1689 /// : tensor<17x5xf32> to tensor<?x?xf32>
1690 /// ```
1691 /// is rewritten to:
1692 /// ```
1693 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
1694 /// : tensor<...> to tensor<?x?xf32>
1695 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
1696 /// tensor<?x?xf32>
1697 /// ```
1698 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
1699 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
1700 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
1701 /// from %r's old dimensions.
1702 ///
1703 /// This rewrite is possible if:
1704 /// - Low padding is static 0.
1705 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
1706 /// ExtractSliceOp trims the same amount of padding that was added
1707 /// beforehand.
1708 /// - Single, scalar padding value.
1709 struct PadOpVectorizationWithTransferWritePattern
1710  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
1711  using VectorizePadOpUserPattern<
1712  vector::TransferWriteOp>::VectorizePadOpUserPattern;
1713 
1714  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1715  vector::TransferWriteOp xferOp) const override {
1716  // TODO: support 0-d corner case.
1717  if (xferOp.getTransferRank() == 0)
1718  return failure();
1719 
1720  // Low padding must be static 0.
1721  if (!padOp.hasZeroLowPad())
1722  return failure();
1723  // Pad value must be a constant.
1724  auto padValue = padOp.getConstantPaddingValue();
1725  if (!padValue)
1726  return failure();
1727  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
1728  if (!xferOp->hasOneUse())
1729  return failure();
1730  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
1731  if (!trimPadding)
1732  return failure();
1733  // Only static zero offsets supported when trimming padding.
1734  if (!trimPadding.hasZeroOffset())
1735  return failure();
1736  // trimPadding must remove the amount of padding that was added earlier.
1737  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
1738  return failure();
1739 
1740  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
1741  rewriter.setInsertionPoint(xferOp);
1742 
1743  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
1744  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1745  xferOp, padOp.getSource().getType(), xferOp.getVector(),
1746  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
1747  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
1748  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
1749 
1750  return success();
1751  }
1752 
1753  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
1754  /// i.e., same dimensions.
1755  ///
1756  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
1757  /// dimensions, this function tries to infer the (static) tensor size by
1758  /// looking at the defining op and utilizing op-specific knowledge.
1759  ///
1760  /// This is a conservative analysis. In case equal tensor sizes cannot be
1761  /// proven statically, this analysis returns `false` even though the tensor
1762  /// sizes may turn out to be equal at runtime.
1763  bool hasSameTensorSize(Value beforePadding,
1764  tensor::ExtractSliceOp afterTrimming) const {
1765  // If the input to tensor::PadOp is a CastOp, try with both the CastOp
1766  // result and CastOp operand.
1767  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
1768  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
1769  return true;
1770 
1771  auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
1772  auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
1773  // Only RankedTensorType supported.
1774  if (!t1 || !t2)
1775  return false;
1776  // Rank of both values must be the same.
1777  if (t1.getRank() != t2.getRank())
1778  return false;
1779 
1780  // All static dimensions must be the same. Mixed cases (e.g., dimension
1781  // static in `t1` but dynamic in `t2`) are not supported.
1782  for (unsigned i = 0; i < t1.getRank(); ++i) {
1783  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
1784  return false;
1785  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
1786  return false;
1787  }
1788 
1789  // Nothing more to check if all dimensions are static.
1790  if (t1.getNumDynamicDims() == 0)
1791  return true;
1792 
1793  // All dynamic sizes must be the same. The only supported case at the
1794  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
1795  // thereof).
1796 
1797  // Apart from CastOp, only ExtractSliceOp is supported.
1798  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
1799  if (!beforeSlice)
1800  return false;
1801 
1802  assert(static_cast<size_t>(t1.getRank()) ==
1803  beforeSlice.getMixedSizes().size());
1804  assert(static_cast<size_t>(t2.getRank()) ==
1805  afterTrimming.getMixedSizes().size());
1806 
1807  for (unsigned i = 0; i < t1.getRank(); ++i) {
1808  // Skip static dimensions.
1809  if (!t1.isDynamicDim(i))
1810  continue;
1811  auto size1 = beforeSlice.getMixedSizes()[i];
1812  auto size2 = afterTrimming.getMixedSizes()[i];
1813 
1814  // Case 1: Same value or same constant int.
1815  if (isEqualConstantIntOrValue(size1, size2))
1816  continue;
1817 
1818  // Other cases: Take a deeper look at defining ops of values.
1819  auto v1 = size1.dyn_cast<Value>();
1820  auto v2 = size2.dyn_cast<Value>();
1821  if (!v1 || !v2)
1822  return false;
1823 
1824  // Case 2: Both values are identical AffineMinOps. (Should not happen if
1825  // CSE is run.)
1826  auto minOp1 = v1.getDefiningOp<AffineMinOp>();
1827  auto minOp2 = v2.getDefiningOp<AffineMinOp>();
1828  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
1829  minOp1.getOperands() == minOp2.getOperands())
1830  continue;
1831 
1832  // Add additional cases as needed.
1833  }
1834 
1835  // All tests passed.
1836  return true;
1837  }
1838 };
1839 
1840 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
1841 /// ```
1842 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
1843 /// %r = tensor.insert_slice %0
1844 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
1845 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
1846 /// ```
1847 /// is rewritten to:
1848 /// ```
1849 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
1850 /// : tensor<?x?xf32>, vector<17x5xf32>
1851 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
1852 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
1853 /// ```
1854 ///
1855 /// This rewrite is possible if:
1856 /// - Low padding is static 0.
1857 /// - `padOp` result shape is static.
1858 /// - The entire padded tensor is inserted.
1859 /// (Implies that sizes of `insertOp` are all static.)
1860 /// - Only unit strides in `insertOp`.
1861 /// - Single, scalar padding value.
1862 /// - `padOp` result not used as destination.
1863 struct PadOpVectorizationWithInsertSlicePattern
1864  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
1865  using VectorizePadOpUserPattern<
1866  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
1867 
1868  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1869  tensor::InsertSliceOp insertOp) const override {
1870  // Low padding must be static 0.
1871  if (!padOp.hasZeroLowPad())
1872  return failure();
1873  // Only unit stride supported.
1874  if (!insertOp.hasUnitStride())
1875  return failure();
1876  // Pad value must be a constant.
1877  auto padValue = padOp.getConstantPaddingValue();
1878  if (!padValue)
1879  return failure();
1880  // Dynamic shapes not supported.
1881  if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
1882  return failure();
1883  // Pad result not used as destination.
1884  if (insertOp.getDest() == padOp.getResult())
1885  return failure();
1886 
1887  auto vecType = VectorType::get(padOp.getType().getShape(),
1888  padOp.getType().getElementType());
1889  unsigned vecRank = vecType.getRank();
1890  unsigned tensorRank = insertOp.getType().getRank();
1891 
1892  // Check if sizes match: Insert the entire tensor into most minor dims.
1893  // (No permutations allowed.)
1894  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
1895  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
1896  if (!llvm::all_of(
1897  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
1898  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
1899  }))
1900  return failure();
1901 
1902  // Insert the TransferReadOp and TransferWriteOp at the position of the
1903  // InsertSliceOp.
1904  rewriter.setInsertionPoint(insertOp);
1905 
1906  // Generate TransferReadOp: Read entire source tensor and add high
1907  // padding.
1908  SmallVector<Value> readIndices(
1909  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
1910  auto read = rewriter.create<vector::TransferReadOp>(
1911  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
1912 
1913  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
1914  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
1915  // source must fit into the destination at the specified offsets.
1916  auto writeIndices =
1917  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
1918  SmallVector<bool> inBounds(vecRank, true);
1919  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1920  insertOp, read, insertOp.getDest(), writeIndices,
1921  ArrayRef<bool>{inBounds});
1922 
1923  return success();
1924  }
1925 };
1926 
1927 void mlir::linalg::populatePadOpVectorizationPatterns(
1928  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
1929  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
1930  baseBenefit);
1931  // Try these specialized patterns first before resorting to the generic one.
1932  patterns.add<PadOpVectorizationWithTransferReadPattern,
1933  PadOpVectorizationWithTransferWritePattern,
1934  PadOpVectorizationWithInsertSlicePattern>(
1935  patterns.getContext(), baseBenefit.getBenefit() + 1);
1936 }
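// --- Illustrative usage (not part of the original source) ---
// A sketch of applying the pad vectorization patterns greedily from a pass;
// `funcOp` is hypothetical:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();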
1937 
1938 //----------------------------------------------------------------------------//
1939 // Forwarding patterns
1940 //----------------------------------------------------------------------------//
1941 
1942 /// Check whether there is any interleaved use of any `values` between
1943 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
1944 /// is in a different block.
1945 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
1946  ValueRange values) {
1947  if (firstOp->getBlock() != secondOp->getBlock() ||
1948  !firstOp->isBeforeInBlock(secondOp)) {
1949  LDBG("interleavedUses precondition failed, firstOp: "
1950  << *firstOp << ", second op: " << *secondOp << "\n");
1951  return true;
1952  }
1953  for (auto v : values) {
1954  for (auto &u : v.getUses()) {
1955  Operation *owner = u.getOwner();
1956  if (owner == firstOp || owner == secondOp)
1957  continue;
1958  // TODO: this is too conservative, use dominance info in the future.
1959  if (owner->getBlock() == firstOp->getBlock() &&
1960  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
1961  continue;
1962  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
1963  << ", second op: " << *secondOp << "\n");
1964  return true;
1965  }
1966  }
1967  return false;
1968 }
1969 
1970 /// Return the unique subview use of `v` if it is indeed unique, null
1971 /// otherwise.
1972 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
1973  memref::SubViewOp subViewOp;
1974  for (auto &u : v.getUses()) {
1975  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
1976  if (subViewOp)
1977  return memref::SubViewOp();
1978  subViewOp = newSubViewOp;
1979  }
1980  }
1981  return subViewOp;
1982 }
1983 
1984 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1985 /// when available.
1986 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
1987  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
1988 
1989  // TODO: support mask.
1990  if (xferOp.getMask())
1991  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
1992 
1993  // Transfer into `view`.
1994  Value viewOrAlloc = xferOp.getSource();
1995  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1996  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1997  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
1998 
1999  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2000  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2001  if (!subViewOp)
2002  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2003  Value subView = subViewOp.getResult();
2004 
2005  // Find the copy into `subView` without interleaved uses.
2006  memref::CopyOp copyOp;
2007  for (auto &u : subView.getUses()) {
2008  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2009  assert(newCopyOp.getTarget().getType().isa<MemRefType>());
2010  if (newCopyOp.getTarget() != subView)
2011  continue;
2012  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2013  continue;
2014  copyOp = newCopyOp;
2015  break;
2016  }
2017  }
2018  if (!copyOp)
2019  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2020 
2021  // Find the fill into `viewOrAlloc` without interleaved uses before the
2022  // copy.
2023  FillOp maybeFillOp;
2024  for (auto &u : viewOrAlloc.getUses()) {
2025  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2026  assert(newFillOp.output().getType().isa<MemRefType>());
2027  if (newFillOp.output() != viewOrAlloc)
2028  continue;
2029  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2030  continue;
2031  maybeFillOp = newFillOp;
2032  break;
2033  }
2034  }
2035  // Ensure padding matches.
2036  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2037  return rewriter.notifyMatchFailure(xferOp,
2038  "padding value does not match fill");
2039 
2040  // `in` is the subview that memref.copy reads. Replace it.
2041  Value in = copyOp.getSource();
2042 
2043  // memref.copy + linalg.fill can be used to create a padded local buffer.
2044  // The `masked` attribute is only valid on this padded buffer.
2045  // When forwarding to vector.transfer_read, the attribute must be reset
2046  // conservatively.
2047  Value res = rewriter.create<vector::TransferReadOp>(
2048  xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2049  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2050  // in_bounds is explicitly reset
2051  /*inBoundsAttr=*/ArrayAttr());
2052 
2053  if (maybeFillOp)
2054  rewriter.eraseOp(maybeFillOp);
2055  rewriter.eraseOp(copyOp);
2056  rewriter.replaceOp(xferOp, res);
2057 
2058  return success();
2059 }
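// --- Illustrative example (not part of the original source) ---
// A sketch of the forwarding above: a fill + copy into a padded local buffer
// followed by a read of that buffer collapses into a read of the original
// source (types, maps and offsets elided):
//
//   Before:  linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
//            %sv = memref.subview %alloc[...]
//            memref.copy %in, %sv
//            %v = vector.transfer_read %alloc[%c0, %c0], %cst : ...
//   After:   %v = vector.transfer_read %in[%c0, %c0], %cst : ...
//            (the fill and the copy are erased)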
2060 
2061 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2062 /// when available.
2063 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2064  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2065  // TODO: support mask.
2066  if (xferOp.getMask())
2067  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2068 
2069  // Transfer into `viewOrAlloc`.
2070  Value viewOrAlloc = xferOp.getSource();
2071  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2072  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2073  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2074 
2075  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2076  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2077  if (!subViewOp)
2078  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2079  Value subView = subViewOp.getResult();
2080 
2081  // Find the copy from `subView` without interleaved uses.
2082  memref::CopyOp copyOp;
2083  for (auto &u : subViewOp.getResult().getUses()) {
2084  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2085  if (newCopyOp.getSource() != subView)
2086  continue;
2087  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2088  continue;
2089  copyOp = newCopyOp;
2090  break;
2091  }
2092  }
2093  if (!copyOp)
2094  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2095 
2096  // `out` is the subview copied into that we replace.
2097  assert(copyOp.getTarget().getType().isa<MemRefType>());
2098  Value out = copyOp.getTarget();
2099 
2100  // Forward vector.transfer into copy.
2101  // memref.copy + linalg.fill can be used to create a padded local buffer.
2102  // The `masked` attribute is only valid on this padded buffer.
2103  // When forwarding to vector.transfer_write, the attribute must be reset
2104  // conservatively.
2105  rewriter.create<vector::TransferWriteOp>(
2106  xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2107  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2108  // in_bounds is explicitly reset
2109  /*inBoundsAttr=*/ArrayAttr());
2110 
2111  rewriter.eraseOp(copyOp);
2112  rewriter.eraseOp(xferOp);
2113 
2114  return success();
2115 }
2116 
2117 //===----------------------------------------------------------------------===//
2118 // Convolution vectorization patterns
2119 //===----------------------------------------------------------------------===//
2120 
2121 template <int N>
2122 static void bindShapeDims(ShapedType shapedType) {}
2123 
2124 template <int N, typename IntTy, typename... IntTy2>
2125 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2126  val = shapedType.getShape()[N];
2127  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2128 }
2129 
2130 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2131 template <typename... IntTy>
2132 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2133  bindShapeDims<0>(shapedType, vals...);
2134 }
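// --- Illustrative usage (not part of the original source) ---
// Sketch mirroring the calls made further below in Conv1DGenerator:
//
//   int64_t nSize, wSize, fSize;
//   bindShapeDims(resShapedType, nSize, wSize, fSize); // binds dims 0, 1, 2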
2135 
2136 namespace {
2137 bool isCastOfBlockArgument(Operation *op) {
2138  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2139  op->getOperand(0).isa<BlockArgument>();
2140 }
2141 
2142 bool isSupportedPoolKind(vector::CombiningKind kind) {
2143  switch (kind) {
2144  case vector::CombiningKind::ADD:
2145  case vector::CombiningKind::MAXF:
2146  case vector::CombiningKind::MAXSI:
2147  case vector::CombiningKind::MAXUI:
2148  case vector::CombiningKind::MINF:
2149  case vector::CombiningKind::MINSI:
2150  case vector::CombiningKind::MINUI:
2151  return true;
2152  default:
2153  return false;
2154  }
2155 }
2156 
2157 /// Generate a vector implementation for either:
2158 /// ```
2159 /// Op def: ( w, kw )
2160 /// Iters: ({Par(), Red()})
2161 /// Layout: {{w + kw}, {kw}, {w}}
2162 /// ```
2163 /// kw is unrolled.
2164 ///
2165 /// or
2166 ///
2167 /// ```
2168 /// Op def: ( n, w, c, kw, f )
2169 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2170 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2171 /// ```
2172 /// kw is unrolled, w is unrolled iff dilationW > 1.
2173 ///
2174 /// or
2175 ///
2176 /// ```
2177 /// Op def: ( n, c, w, f, kw )
2178 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2179 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2180 /// ```
2181 /// kw is unrolled, w is unrolled iff dilationW > 1.
2182 ///
2183 /// or
2184 ///
2185 /// ```
2186 /// Op def: ( n, w, c, kw )
2187 /// Iters: ({Par(), Par(), Par(), Red()})
2188 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2189 /// ```
2190 /// kw is unrolled, w is unrolled iff dilationW > 1.
2191 struct Conv1DGenerator
2192  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2193  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2194  int dilationW)
2195  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2196  strideW(strideW), dilationW(dilationW) {
2197  // Determine whether `linalgOp` can be generated with this generator
2198  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2199  return;
2200  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2201  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2202  resShaped = linalgOp.getDpsInitOperand(0)->get();
2203  lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
2204  rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
2205  resShapedType = resShaped.getType().dyn_cast<ShapedType>();
2206  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2207  return;
2208  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2209  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2210  if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
2211  (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
2212  return;
2213 
2214  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2215  if (!reduceOp)
2216  return;
2217  redOp = reduceOp->getName().getIdentifier();
2218 
2219  if (!setOperKind(reduceOp))
2220  return;
2221  auto maybeKind = getCombinerOpKind(reduceOp);
2222  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2223  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2224  return;
2225  }
2226 
2227  auto rhsRank = rhsShapedType.getRank();
2228  switch (oper) {
2229  case Conv:
2230  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2231  return;
2232  break;
2233  case Pool:
2234  if (rhsRank != 1)
2235  return;
2236  break;
2237  }
2238  // The op is now known to be valid.
2239  valid = true;
2240  }
2241 
2242  /// Generate a vector implementation for:
2243  /// ```
2244  /// Op def: ( w, kw )
2245  /// Iters: ({Par(), Red()})
2246  /// Layout: {{w + kw}, {kw}, {w}}
2247  /// ```
2248  /// kw is always unrolled.
2249  ///
2250  /// or
2251  ///
2252  /// ```
2253  /// Op def: ( n, w, c, kw, f )
2254  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2255  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2256  /// ```
2257  /// kw is always unrolled.
2258  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
2259  /// > 1.
2260  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2261  if (!valid)
2262  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2263 
2264  int64_t nSize, wSize, cSize, kwSize, fSize;
2265  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
2266  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
2267  switch (conv1DOpOrder) {
2268  case Conv1DOpOrder::W:
2269  // Initialize unused dimensions
2270  nSize = fSize = cSize = 0;
2271  // out{W}
2272  bindShapeDims(resShapedType, wSize);
2273  // kernel{kw}
2274  bindShapeDims(rhsShapedType, kwSize);
2275  lhsShape = {// iw = ow + kw - 1
2276  // (i.e. 16 convolved with 3 -> 14)
2277  (wSize + kwSize - 1)};
2278  rhsShape = {kwSize};
2279  resShape = {wSize};
2280  break;
2281  case Conv1DOpOrder::Nwc:
2282  // out{n, w, f}
2283  bindShapeDims(resShapedType, nSize, wSize, fSize);
2284  switch (oper) {
2285  case Conv:
2286  // kernel{kw, c, f}
2287  bindShapeDims(rhsShapedType, kwSize, cSize);
2288  break;
2289  case Pool:
2290  // kernel{kw}
2291  bindShapeDims(rhsShapedType, kwSize);
2292  cSize = fSize;
2293  break;
2294  }
2295  lhsShape = {nSize,
2296  // iw = ow * sw + kw * dw - 1
2297  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2298  // Perform the proper inclusive -> exclusive -> inclusive.
2299  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2300  1,
2301  cSize};
2302  switch (oper) {
2303  case Conv:
2304  rhsShape = {kwSize, cSize, fSize};
2305  break;
2306  case Pool:
2307  rhsShape = {kwSize};
2308  break;
2309  }
2310  resShape = {nSize, wSize, fSize};
2311  break;
2312  case Conv1DOpOrder::Ncw:
2313  // out{n, f, w}
2314  bindShapeDims(resShapedType, nSize, fSize, wSize);
2315  switch (oper) {
2316  case Conv:
2317  // kernel{f, c, kw}
2318  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
2319  break;
2320  case Pool:
2321  // kernel{kw}
2322  bindShapeDims(rhsShapedType, kwSize);
2323  cSize = fSize;
2324  break;
2325  }
2326  lhsShape = {nSize, cSize,
2327  // iw = ow * sw + kw * dw - 1
2328  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2329  // Perform the proper inclusive -> exclusive -> inclusive.
2330  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2331  1};
2332  switch (oper) {
2333  case Conv:
2334  rhsShape = {fSize, cSize, kwSize};
2335  break;
2336  case Pool:
2337  rhsShape = {kwSize};
2338  break;
2339  }
2340  resShape = {nSize, fSize, wSize};
2341  break;
2342  }
2343 
2344  vector::TransferWriteOp write;
2345  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2346 
2347  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2348  // When strideW == 1, we can batch the contiguous loads and avoid
2349  // unrolling
2350  int64_t wSizeStep = strideW == 1 ? wSize : 1;
2351 
2352  Type lhsEltType = lhsShapedType.getElementType();
2353  Type rhsEltType = rhsShapedType.getElementType();
2354  Type resEltType = resShapedType.getElementType();
2355  auto lhsType = VectorType::get(lhsShape, lhsEltType);
2356  auto rhsType = VectorType::get(rhsShape, rhsEltType);
2357  auto resType = VectorType::get(resShape, resEltType);
2358  // Zero padding with the corresponding dimensions for lhs, rhs and res.
2359  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
2360  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
2361  SmallVector<Value> resPadding(resShape.size(), zero);
2362 
2363  // Read the whole lhs, rhs and res in one shot (with zero padding).
2364  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2365  lhsPadding);
2366  // This is needed only for Conv.
2367  Value rhs = nullptr;
2368  if (oper == Conv)
2369  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2370  rhsPadding);
2371  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
2372  resPadding);
2373 
2374  // The base vectorization case for channeled convolution is input: {n,w,c},
2375  // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
2376  // vectorization case, we do pre transpose on input, weight, and output.
2377  switch (conv1DOpOrder) {
2378  case Conv1DOpOrder::W:
2379  case Conv1DOpOrder::Nwc:
2380  // Base case, so no transposes necessary.
2381  break;
2382  case Conv1DOpOrder::Ncw: {
2383  // To match base vectorization case, we pre-transpose current case.
2384  // ncw -> nwc
2385  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
2386  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
2387  // fcw -> wcf
2388  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
2389 
2390  // This is needed only for Conv.
2391  if (oper == Conv)
2392  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
2393  // nfw -> nwf
2394  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
2395  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
2396  break;
2397  }
2398  }
2399 
2400  //===------------------------------------------------------------------===//
2401  // Begin vector-only rewrite part
2402  //===------------------------------------------------------------------===//
2403  // Unroll along kw and read slices of lhs and rhs.
2404  SmallVector<Value> lhsVals, rhsVals, resVals;
2405  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
2406  kwSize, strideW, dilationW, wSizeStep,
2407  isSingleChanneled);
2408  // Do not do for pooling.
2409  if (oper == Conv)
2410  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
2411  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
2412  wSizeStep, isSingleChanneled);
2413 
2414  auto linearIndex = [&](int64_t kw, int64_t w) {
2415  return kw * (wSize / wSizeStep) + w;
2416  };
2417 
2418  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
2419  // perform outerproduct for non-channeled convolution or
2420  // perform simple arith operation for pooling
2421  for (int64_t kw = 0; kw < kwSize; ++kw) {
2422  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2423  switch (oper) {
2424  case Conv:
2425  if (isSingleChanneled) {
2426  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
2427  lhsVals[linearIndex(kw, w)],
2428  rhsVals[kw], resVals[w]);
2429  } else {
2430  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
2431  lhsVals[linearIndex(kw, w)],
2432  rhsVals[kw], resVals[w]);
2433  }
2434  break;
2435  case Pool:
2436  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
2437  resVals[w]);
2438  break;
2439  }
2440  }
2441  }
2442 
2443  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
2444  isSingleChanneled);
2445  //===------------------------------------------------------------------===//
2446  // End vector-only rewrite part
2447  //===------------------------------------------------------------------===//
2448 
2449  // The base vectorization case for channeled convolution is output: {n,w,f}
2450  // To reuse the result from base pattern vectorization case, we post
2451  // transpose the base case result.
2452  switch (conv1DOpOrder) {
2453  case Conv1DOpOrder::W:
2454  case Conv1DOpOrder::Nwc:
2455  // Base case, so no transposes necessary.
2456  break;
2457  case Conv1DOpOrder::Ncw: {
2458  // nwf -> nfw
2459  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
2460  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
2461  break;
2462  }
2463  }
2464 
2465  return rewriter
2466  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
2467  .getOperation();
2468  }
2469 
2470  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
2471  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
2472  Value lhs, Value rhs, Value res) {
2473  vector::IteratorType par = vector::IteratorType::parallel;
2474  vector::IteratorType red = vector::IteratorType::reduction;
2475  AffineExpr n, w, f, c;
2476  bindDims(ctx, n, w, f, c);
2477  return rewriter.create<vector::ContractionOp>(
2478  loc, lhs, rhs, res,
2479  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
2480  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
2481  }
2482 
2483  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
2484  // convolution.
2485  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
2486  Value lhs, Value rhs, Value res) {
2487  return rewriter.create<vector::OuterProductOp>(
2488  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
2489  }
2490 
2491  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
2492  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
2493  Value res) {
2494  if (isPoolExt)
2495  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
2496  return rewriter
2497  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
2498  ->getResult(0);
2499  }
2500 
2501  /// Generate a vector implementation for:
2502  /// ```
2503  /// Op def: ( n, w, c, kw)
2504  /// Iters: ({Par(), Par(), Par(), Red()})
2505  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2506  /// ```
2507  /// kw is always unrolled.
2508  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
2509  /// > 1.
2510  FailureOr<Operation *> depthwiseConv() {
2511  if (!valid)
2512  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
2513 
2514  int64_t nSize, wSize, cSize, kwSize;
2515  // kernel{kw, c}
2516  bindShapeDims(rhsShapedType, kwSize, cSize);
2517  // out{n, w, c}
2518  bindShapeDims(resShapedType, nSize, wSize);
2519 
2520  vector::TransferWriteOp write;
2521  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2522 
2523  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2524  // When strideW == 1, we can batch the contiguous loads and avoid
2525  // unrolling
2526  int64_t wSizeStep = strideW == 1 ? wSize : 1;
2527 
2528  Type lhsEltType = lhsShapedType.getElementType();
2529  Type rhsEltType = rhsShapedType.getElementType();
2530  Type resEltType = resShapedType.getElementType();
2531  VectorType lhsType = VectorType::get(
2532  {nSize,
2533  // iw = ow * sw + kw * dw - 1
2534  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2535  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
2536  cSize},
2537  lhsEltType);
2538  VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
2539  VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
2540 
2541  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
2542  // 0].
2543  Value lhs = rewriter.create<vector::TransferReadOp>(
2544  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
2545  // Read rhs slice of size {kw, c} @ [0, 0].
2546  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2547  ValueRange{zero, zero});
2548  // Read res slice of size {n, w, c} @ [0, 0, 0].
2549  Value res = rewriter.create<vector::TransferReadOp>(
2550  loc, resType, resShaped, ValueRange{zero, zero, zero});
2551 
2552  //===------------------------------------------------------------------===//
2553  // Begin vector-only rewrite part
2554  //===------------------------------------------------------------------===//
2555  // Unroll along kw and read slices of lhs and rhs.
2556  SmallVector<Value> lhsVals, rhsVals, resVals;
2557  // Extract lhs slice of size {n, wSizeStep, c}
2558  // @ [0, sw * w + dw * kw, 0].
2559  for (int64_t kw = 0; kw < kwSize; ++kw) {
2560  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2561  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
2562  loc, lhs,
2563  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
2564  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
2565  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
2566  }
2567  }
2568  // Extract rhs slice of size {c} @ [kw].
2569  for (int64_t kw = 0; kw < kwSize; ++kw) {
2570  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
2571  loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
2572  }
2573  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
2574  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2575  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
2576  loc, res,
2577  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
2578  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
2579  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
2580  }
2581 
2582  auto linearIndex = [&](int64_t kw, int64_t w) {
2583  return kw * (wSize / wSizeStep) + w;
2584  };
2585 
2586  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
2587  for (int64_t kw = 0; kw < kwSize; ++kw) {
2588  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2589  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
2590  lhsVals[linearIndex(kw, w)],
2591  rhsVals[kw], resVals[w]);
2592  }
2593  }
2594 
2595  // It's possible we failed to create the FMA.
2596  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
2597  // Manually revert (in reverse order) to avoid leaving a bad IR state.
2598  for (auto &collection :
2599  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
2600  for (Value v : collection)
2601  rewriter.eraseOp(v.getDefiningOp());
2602  return rewriter.notifyMatchFailure(op, "failed to create FMA");
2603  }
2604 
2605  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
2606  // This does not depend on kw.
2607  for (int64_t w = 0; w < wSize; w += wSizeStep) {
2608  res = rewriter.create<vector::InsertStridedSliceOp>(
2609  loc, resVals[w], res,
2610  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
2611  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
2612  }
2613  //===------------------------------------------------------------------===//
2614  // End vector-only rewrite part
2615  //===------------------------------------------------------------------===//
2616 
2617  // Write back res slice of size {n, w, c} @ [0, 0, 0].
2618  return rewriter
2619  .create<vector::TransferWriteOp>(loc, res, resShaped,
2620  ValueRange{zero, zero, zero})
2621  .getOperation();
2622  }
2623 
2624  // Take a value of element type T and widen to the destination type.
2625  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
2626  if (val.getType() == ty)
2627  return val;
2628 
2629  const int64_t srcWidth =
2630  getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
2631  const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
2632 
2633  if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
2634  return rewriter.create<arith::ExtFOp>(loc, ty, val);
2635 
2636  if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
2637  return rewriter.create<arith::ExtSIOp>(loc, ty, val);
2638 
2639  return nullptr;
2640  }
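// --- Illustrative behavior (not part of the original source) ---
// Sketch of promote() assuming ty = vector<1x4x8xf32>:
//   - val : vector<1x4x8xf16> -> widened with arith.extf
//   - val : vector<1x4x8xi8> and an integer ty -> widened with arith.extsi
//   - val already of type ty  -> returned unchanged
//   - otherwise (e.g. narrowing) -> nullptr; callers must bail out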
2641 
2642  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
2643  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
2644  Value lhs, Value rhs, Value res) {
2645  auto rhsTy = rhs.getType().cast<ShapedType>();
2646  auto resTy = res.getType().cast<ShapedType>();
2647 
2648  // TODO(suderman): Change this to use a vector.fma intrinsic.
2649  lhs = promote(rewriter, loc, lhs, resTy);
2650 
2651  rhs = rewriter.create<vector::BroadcastOp>(
2652  loc, resTy.clone(rhsTy.getElementType()), rhs);
2653  rhs = promote(rewriter, loc, rhs, resTy);
2654 
2655  if (!lhs || !rhs)
2656  return nullptr;
2657 
2658  if (resTy.getElementType().isa<FloatType>())
2659  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
2660 
2661  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
2662  return rewriter.create<arith::AddIOp>(loc, mul, res);
2663  }
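// --- Illustrative example (not part of the original source) ---
// Sketch of the multiply-accumulate emitted above for one (kw, w) slice with
// lhs/res : vector<1x4x8xf32> and rhs : vector<8xf32>:
//
//   %b = vector.broadcast %rhs : vector<8xf32> to vector<1x4x8xf32>
//   %r = vector.fma %lhs, %b, %res : vector<1x4x8xf32>
//
// For integer element types, an arith.muli + arith.addi pair is emitted
// instead of vector.fma.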
2664 
2665  /// Entry point for non-channeled convolution:
2666  /// {{w + kw}, {kw}, {w}}
2667  FailureOr<Operation *> generateNonChanneledConv() {
2668  AffineExpr w, kw;
2669  bindDims(ctx, w, kw);
2670  if (!iters({Par(), Red()}))
2671  return rewriter.notifyMatchFailure(op,
2672  "failed to match conv::W 1-par 1-red");
2673 
2674  // No transposition needed.
2675  if (layout({/*lhsIndex*/ {w + kw},
2676  /*rhsIndex*/ {kw},
2677  /*resIndex*/ {w}}))
2678  return conv(Conv1DOpOrder::W);
2679 
2680  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
2681  }
2682 
2683  /// Entry point that transposes into the common form:
2684  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2685  FailureOr<Operation *> generateNwcConv() {
2686  AffineExpr n, w, f, kw, c;
2687  bindDims(ctx, n, w, f, kw, c);
2688  if (!iters({Par(), Par(), Par(), Red(), Red()}))
2689  return rewriter.notifyMatchFailure(
2690  op, "failed to match conv::Nwc 3-par 2-red");
2691 
2692  // No transposition needed.
2693  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
2694  /*rhsIndex*/ {kw, c, f},
2695  /*resIndex*/ {n, w, f}}))
2696  return conv(Conv1DOpOrder::Nwc);
2697 
2698  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
2699  }
2700 
2701  /// Entry point that transposes into the common form:
2702  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2703  FailureOr<Operation *> generateNcwConv() {
2704  AffineExpr n, w, f, kw, c;
2705  bindDims(ctx, n, f, w, c, kw);
2706  if (!iters({Par(), Par(), Par(), Red(), Red()}))
2707  return rewriter.notifyMatchFailure(
2708  op, "failed to match conv::Ncw 3-par 2-red");
2709 
2710  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
2711  /*rhsIndex*/ {f, c, kw},
2712  /*resIndex*/ {n, f, w}}))
2713  return conv(Conv1DOpOrder::Ncw);
2714 
2715  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
2716  }
2717 
2718  /// Entry point that transposes into the common form:
2719  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
2720  FailureOr<Operation *> generateNwcPooling() {
2721  AffineExpr n, w, c, kw;
2722  bindDims(ctx, n, w, c, kw);
2723  if (!iters({Par(), Par(), Par(), Red()}))
2724  return rewriter.notifyMatchFailure(op,
2725  "failed to match pooling 3-par 1-red");
2726 
2727  // No transposition needed.
2728  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
2729  /*rhsIndex*/ {kw},
2730  /*resIndex*/ {n, w, c}}))
2731  return conv(Conv1DOpOrder::Nwc);
2732 
2733  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
2734  }
2735 
2736  /// Entry point that transposes into the common form:
2737  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
2738  FailureOr<Operation *> generateNcwPooling() {
2739  AffineExpr n, w, c, kw;
2740  bindDims(ctx, n, c, w, kw);
2741  if (!iters({Par(), Par(), Par(), Red()}))
2742  return rewriter.notifyMatchFailure(op,
2743  "failed to match pooling 3-par 1-red");
2744 
2745  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
2746  /*rhsIndex*/ {kw},
2747  /*resIndex*/ {n, c, w}}))
2748  return conv(Conv1DOpOrder::Ncw);
2749 
2750  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
2751  }
2752 
2753  /// Entry point that transposes into the common form:
2754  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2755  FailureOr<Operation *> generateDilatedConv() {
2756  AffineExpr n, w, c, kw;
2757  bindDims(ctx, n, w, c, kw);
2758  if (!iters({Par(), Par(), Par(), Red()}))
2759  return rewriter.notifyMatchFailure(
2760  op, "failed to match depthwise::Nwc conv 3-par 1-red");
2761 
2762  // No transposition needed.
2763  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
2764  /*rhsIndex*/ {kw, c},
2765  /*resIndex*/ {n, w, c}}))
2766  return depthwiseConv();
2767 
2768  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
2769  }
2770 
2771 private:
2772  enum OperKind { Conv, Pool };
2773  bool valid = false;
2774  OperKind oper = Conv;
2775  StringAttr redOp;
2776  StringAttr poolExtOp;
2777  bool isPoolExt = false;
2778  int strideW, dilationW;
2779  Value lhsShaped, rhsShaped, resShaped;
2780  ShapedType lhsShapedType, rhsShapedType, resShapedType;
2781 
2782  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
2783  // Returns true iff it is a valid conv/pooling op.
2784  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2785  // + yield) and rhs is not used) then it is the body of a pooling op.
2786  // If conv, check for single `mul` predecessor. The `mul` operands must be
2787  // block arguments or extension of block arguments.
2788  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2789  // must be block arguments or extension of block arguments.
2790  bool setOperKind(Operation *reduceOp) {
2791  int numBlockArguments =
2792  llvm::count_if(reduceOp->getOperands(),
2793  [](Value v) { return v.isa<BlockArgument>(); });
2794  switch (numBlockArguments) {
2795  case 1: {
2796  // Will be convolution if feeder is a MulOp.
2797  // Otherwise, check whether it can be pooling.
2798  auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
2799  return !v.isa<BlockArgument>();
2800  });
2801  Operation *feedOp = (*feedValIt).getDefiningOp();
2802  if (isCastOfBlockArgument(feedOp)) {
2803  oper = Pool;
2804  isPoolExt = true;
2805  poolExtOp = feedOp->getName().getIdentifier();
2806  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
2807  llvm::all_of(feedOp->getOperands(), [](Value v) {
2808  if (v.isa<BlockArgument>())
2809  return true;
2810  if (Operation *op = v.getDefiningOp())
2811  return isCastOfBlockArgument(op);
2812  return false;
2813  }))) {
2814  return false;
2815  }
2816  return true;
2817  }
2818  case 2:
2819  // Must be pooling
2820  oper = Pool;
2821  isPoolExt = false;
2822  return true;
2823  default:
2824  return false;
2825  }
2826  }
2827 };
2828 } // namespace
2829 
2830 /// Helper function to vectorize a LinalgOp with convolution semantics.
2831 // TODO: extend the generic vectorization to support windows and drop this.
2832 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
2833  LinalgOp op) {
 2834  // The ConvolutionOpInterface gives us guarantees of existence for
 2835  // strides/dilations. However, we do not need to rely on those; we can
 2836  // simply use them if present, otherwise use the defaults and let the
 2837  // generic conv matcher in the Conv1DGenerator succeed or fail.
2838  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
2839  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
2840  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
2841  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
2842  Conv1DGenerator e(rewriter, op, stride, dilation);
2843  auto res = e.generateNonChanneledConv();
2844  if (succeeded(res))
2845  return res;
2846  res = e.generateNwcConv();
2847  if (succeeded(res))
2848  return res;
2849  res = e.generateNcwConv();
2850  if (succeeded(res))
2851  return res;
2852  res = e.generateNwcPooling();
2853  if (succeeded(res))
2854  return res;
2855  res = e.generateNcwPooling();
2856  if (succeeded(res))
2857  return res;
2858  return e.generateDilatedConv();
2859 }
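// For example (values are hypothetical), a convolution carrying
//   {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
// is handed to the Conv1DGenerator with stride == 2 and dilation == 1, while
// an op without these attributes falls back to stride == dilation == 1 and is
// left for the generators above to accept or reject.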
2860 
 2861 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
 2862  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
2863 
 2864  LogicalResult matchAndRewrite(LinalgOp op,
 2865  PatternRewriter &rewriter) const override {
2866  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
2867  if (failed(resultOrFail))
2868  return failure();
2869  Operation *newOp = *resultOrFail;
2870  if (newOp->getNumResults() == 0) {
2871  rewriter.eraseOp(op.getOperation());
2872  return success();
2873  }
2874  assert(newOp->getNumResults() == 1 && "expected single result");
2875  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
2876  return success();
2877  }
2878 };
2879 
 2880 void mlir::linalg::populateConvolutionVectorizationPatterns(
 2881  RewritePatternSet &patterns, PatternBenefit benefit) {
2882  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
2883 }
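// A minimal usage sketch (assuming the greedy rewrite driver declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h"; `funcOp` is a hypothetical
// op being vectorized inside a pass):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();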