1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58  ArrayRef<int64_t> inputVecSizes = {},
59  ArrayRef<bool> inputVecScalableFlags = {},
60  bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if none or more than 1 instances exist.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66  OpType res;
67  block.walk([&](OpType op) {
68  if (res) {
69  res = nullptr;
70  return WalkResult::interrupt();
71  }
72  res = op;
73  return WalkResult::advance();
74  });
75  return res;
76 }
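// Illustrative usage (editorial sketch, not part of the original source): a
// caller might use this helper to check that a block contains exactly one op
// of a given kind before inspecting it; `genericOp` is hypothetical.
//
//   Block &body = genericOp->getRegion(0).front();
//   if (auto yieldOp = getSingleOpOfType<linalg::YieldOp>(body)) {
//     // Exactly one linalg.yield exists in `body`; safe to inspect it here.
//   }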
77 
78 /// Helper function to extract the input slices after filter is unrolled along
79 /// kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82  int64_t nSize, int64_t wSize, int64_t cSize,
83  int64_t kwSize, int strideW, int dilationW,
84  int64_t wSizeStep, bool isSingleChanneled) {
85  SmallVector<Value> result;
86  if (isSingleChanneled) {
87  // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88  // convolution.
89  SmallVector<int64_t> sizes{wSizeStep};
90  SmallVector<int64_t> strides{1};
91  for (int64_t kw = 0; kw < kwSize; ++kw) {
92  for (int64_t w = 0; w < wSize; w += wSizeStep) {
93  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95  }
96  }
97  } else {
98  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99  // for channeled convolution.
100  SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101  SmallVector<int64_t> strides{1, 1, 1};
102  for (int64_t kw = 0; kw < kwSize; ++kw) {
103  for (int64_t w = 0; w < wSize; w += wSizeStep) {
104  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105  loc, input,
106  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107  sizes, strides));
108  }
109  }
110  }
111  return result;
112 }
113 
114 /// Helper function to extract the filter slices after filter is unrolled along
115 /// kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117  Location loc, Value filter,
118  int64_t kwSize) {
119  SmallVector<Value> result;
120  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121  // non-channeled convolution] @ [kw].
122  for (int64_t kw = 0; kw < kwSize; ++kw) {
123  result.push_back(rewriter.create<vector::ExtractOp>(
124  loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125  }
126  return result;
127 }
128 
129 /// Helper function to extract the result slices after filter is unrolled along
130 /// kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133  int64_t nSize, int64_t wSize, int64_t fSize,
134  int64_t wSizeStep, bool isSingleChanneled) {
135  SmallVector<Value> result;
136  if (isSingleChanneled) {
137  // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138  SmallVector<int64_t> sizes{wSizeStep};
139  SmallVector<int64_t> strides{1};
140  for (int64_t w = 0; w < wSize; w += wSizeStep) {
141  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142  loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143  }
144  } else {
145  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146  // convolution.
147  SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148  SmallVector<int64_t> strides{1, 1, 1};
149  for (int64_t w = 0; w < wSize; w += wSizeStep) {
150  result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151  loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152  }
153  }
154  return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159  Value res, int64_t wSize, int64_t wSizeStep,
160  SmallVectorImpl<Value> &resVals,
161  bool isSingleChanneled) {
162 
163  if (isSingleChanneled) {
164  // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165  // This does not depend on kw.
166  SmallVector<int64_t> strides{1};
167  for (int64_t w = 0; w < wSize; w += wSizeStep) {
168  res = rewriter.create<vector::InsertStridedSliceOp>(
169  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170  }
171  } else {
172  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173  // convolution. This does not depend on kw.
174  SmallVector<int64_t> strides{1, 1, 1};
175  for (int64_t w = 0; w < wSize; w += wSizeStep) {
176  res = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178  strides);
179  }
180  }
181  return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189  /// Initializes the vectorization state, including the computation of the
190  /// canonical vector shape for vectorization.
191  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192  ArrayRef<int64_t> inputVectorSizes,
193  ArrayRef<bool> inputScalableVecDims);
194 
195  /// Returns the canonical vector shape used to vectorize the iteration space.
196  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198  /// Returns a vector type of the provided `elementType` with the canonical
199  /// vector shape and the corresponding fixed/scalable dimensions bit. If
200  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
201  /// accordingly.
202  VectorType getCanonicalVecType(
203  Type elementType,
204  std::optional<AffineMap> dimPermutation = std::nullopt) const {
205  SmallVector<int64_t> vectorShape;
206  SmallVector<bool> scalableDims;
207  if (dimPermutation.has_value()) {
208  vectorShape =
209  applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
210  scalableDims =
211  applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
212  } else {
213  vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
214  scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
215  }
216 
217  return VectorType::get(vectorShape, elementType, scalableDims);
218  }
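  // Worked example (editorial sketch, not from the original source): with a
  // canonical vector shape of [4, 8] and no scalable dims, a permutation map
  // `(d0, d1) -> (d1, d0)` yields vector<8x4xf32> for an f32 element type,
  // because applyPermutationMap reorders the canonical sizes according to the
  // map results:
  //
  //   // `state`, `builder`, and `perm` are hypothetical names.
  //   VectorType vt = state.getCanonicalVecType(builder.getF32Type(), perm);
  //   // vt == vector<8x4xf32>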
219 
220  /// Masks an operation with the canonical vector mask if the operation needs
221  /// masking. Returns the masked operation or the original operation if masking
222  /// is not needed. If provided, the canonical mask for this operation is
223  /// permuted using `maybeMaskingMap`.
224  Operation *
225  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
226  std::optional<AffineMap> maybeMaskingMap = std::nullopt);
227 
228 private:
229  /// Initializes the iteration space static sizes using the Linalg op
230  /// information. This may become more complicated in the future.
231  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
232  iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
233  }
234 
235  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
236  /// all the static and dynamic dimensions of the iteration space to be
237  /// vectorized and store them in `iterSpaceValueSizes`.
238  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
239  LinalgOp linalgOp);
240 
241  /// Create or retrieve an existing mask value to mask `opToMask` in the
242  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
243  /// mask is permuted using that permutation map. If a new mask is created, it
244  /// will be cached for future users.
245  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
246  LinalgOp linalgOp,
247  std::optional<AffineMap> maybeMaskingMap);
248 
249  // Holds the compile-time static sizes of the iteration space to vectorize.
250  // Dynamic dimensions are represented using ShapedType::kDynamic.
251  SmallVector<int64_t> iterSpaceStaticSizes;
252 
253  /// Holds the value sizes of the iteration space to vectorize. Static
254  /// dimensions are represented by 'arith.constant' and dynamic
255  /// dimensions by 'tensor/memref.dim'.
256  SmallVector<Value> iterSpaceValueSizes;
257 
258  /// Holds the canonical vector shape used to vectorize the iteration space.
259  SmallVector<int64_t> canonicalVecShape;
260 
261  /// Holds the vector dimensions that are scalable in the canonical vector
262  /// shape.
263  SmallVector<bool> scalableVecDims;
264 
265  /// Holds the active masks for permutations of the canonical vector iteration
266  /// space.
267  DenseMap<AffineMap, Value> activeMaskCache;
268 
269  /// Global vectorization guard for the incoming rewriter. It's initialized
270  /// when the vectorization state is initialized.
271  OpBuilder::InsertionGuard rewriterGuard;
272 };
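// Editorial sketch (not from the original source): typical driver-side use of
// VectorizationState; `rewriter`, `linalgOp`, `inputVectorSizes`, and
// `inputScalableVecDims` are assumed to exist in the caller.
//
//   VectorizationState state(rewriter);
//   if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
//                              inputScalableVecDims)))
//     return failure();
//   VectorType canonicalTy = state.getCanonicalVecType(rewriter.getF32Type());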
273 
274 LogicalResult
275 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276  LinalgOp linalgOp) {
277  // TODO: Support 0-d vectors.
278  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
279  if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
280  // Create constant index op for static dimensions.
281  iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
282  linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
283  continue;
284  }
285 
286  // Find an operand defined on this dimension of the iteration space to
287  // extract the runtime dimension size.
288  Value operand;
289  unsigned operandDimPos;
290  if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
291  operandDimPos)))
292  return failure();
293 
294  Value dynamicDim = linalgOp.hasPureTensorSemantics()
295  ? (Value)rewriter.create<tensor::DimOp>(
296  linalgOp.getLoc(), operand, operandDimPos)
297  : (Value)rewriter.create<memref::DimOp>(
298  linalgOp.getLoc(), operand, operandDimPos);
299  iterSpaceValueSizes.push_back(dynamicDim);
300  }
301 
302  return success();
303 }
304 
305 /// Initializes the vectorization state, including the computation of the
306 /// canonical vector shape for vectorization.
307 // TODO: Move this to the constructor when we can remove the failure cases.
308 LogicalResult
309 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
310  ArrayRef<int64_t> inputVectorSizes,
311  ArrayRef<bool> inputScalableVecDims) {
312  // Initialize the insertion point.
313  rewriter.setInsertionPoint(linalgOp);
314 
315  if (!inputVectorSizes.empty()) {
316  // Get the canonical vector shape from the input vector sizes provided. This
317  // path should be taken to vectorize code with dynamic shapes and when using
318  // vector sizes greater than the iteration space sizes.
319  canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
320  scalableVecDims.append(inputScalableVecDims.begin(),
321  inputScalableVecDims.end());
322  } else {
323  // Compute the canonical vector shape from the operation shape. If there are
324  // dynamic shapes, the operation won't be vectorized. We assume all the
325  // vector dimensions are fixed.
326  canonicalVecShape = linalgOp.getStaticLoopRanges();
327  scalableVecDims.append(linalgOp.getNumLoops(), false);
328  }
329 
330  LDBG("Canonical vector shape: ");
331  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
332  LLVM_DEBUG(llvm::dbgs() << "\n");
333  LDBG("Scalable vector dims: ");
334  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
335  LLVM_DEBUG(llvm::dbgs() << "\n");
336 
337  if (ShapedType::isDynamicShape(canonicalVecShape))
338  return failure();
339 
340  // Initialize iteration space static sizes.
341  initIterSpaceStaticSizes(linalgOp);
342 
343  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
344  // all the static and dynamic dimensions of the iteration space, needed to
345  // compute a mask during vectorization.
346  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
347  return failure();
348 
349  return success();
350 }
351 
352 /// Create or retrieve an existing mask value to mask `opToMask` in the
353 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
354 /// is permuted using that permutation map. If a new mask is created, it will be
355 /// cached for future users.
356 Value VectorizationState::getOrCreateMaskFor(
357  RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
358  std::optional<AffineMap> maybeMaskingMap) {
359  // No mask is needed if the operation is not maskable.
360  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
361  if (!maskableOp)
362  return Value();
363 
364  assert(!maskableOp.isMasked() &&
365  "Masking an operation that is already masked");
366 
367  // If no masking map was provided, use an identity map with the loop dims.
368  assert((!maybeMaskingMap || *maybeMaskingMap) &&
369  "Unexpected null mask permutation map");
370  AffineMap maskingMap =
371  maybeMaskingMap ? *maybeMaskingMap
372  : AffineMap::getMultiDimIdentityMap(
373  linalgOp.getNumLoops(), rewriter.getContext());
374 
375  LDBG("Masking map: " << maskingMap << "\n");
376 
377  // Return the active mask for the masking map of this operation if it was
378  // already created.
379  auto activeMaskIt = activeMaskCache.find(maskingMap);
380  if (activeMaskIt != activeMaskCache.end()) {
381  Value mask = activeMaskIt->second;
382  LDBG("Reusing mask: " << mask << "\n");
383  return mask;
384  }
385 
386  // Compute permuted projection of the iteration space to be masked and the
387  // corresponding mask shape. If the resulting iteration space dimensions are
388  // static and identical to the mask shape, masking is not needed for this
389  // operation.
390  // TODO: Improve this check. Only projected permutation indexing maps are
391  // supported.
392  SmallVector<int64_t> permutedStaticSizes =
393  applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
394  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
395  auto maskShape = maskType.getShape();
396 
397  LDBG("Mask shape: ");
398  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
399  LLVM_DEBUG(llvm::dbgs() << "\n");
400 
401  if (permutedStaticSizes == maskShape) {
402  LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
403  activeMaskCache[maskingMap] = Value();
404  return Value();
405  }
406 
407  // Permute the iteration space value sizes to compute the mask upper bounds.
408  SmallVector<Value> upperBounds =
409  applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
410  assert(!maskShape.empty() && !upperBounds.empty() &&
411  "Masked 0-d vectors are not supported yet");
412 
413  // Create the mask based on the dimension values.
414  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
415  maskType, upperBounds);
416  LDBG("Creating new mask: " << mask << "\n");
417  activeMaskCache[maskingMap] = mask;
418  return mask;
419 }
420 
421 /// Masks an operation with the canonical vector mask if the operation needs
422 /// masking. Returns the masked operation or the original operation if masking
423 /// is not needed. If provided, the canonical mask for this operation is
424 /// permuted using `maybeMaskingMap`.
425 Operation *
426 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
427  LinalgOp linalgOp,
428  std::optional<AffineMap> maybeMaskingMap) {
429  LDBG("Trying to mask: " << *opToMask << "\n");
430 
431  // Create or retrieve mask for this operation.
432  Value mask =
433  getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
434 
435  if (!mask) {
436  LDBG("No mask required\n");
437  return opToMask;
438  }
439 
440  // Wrap the operation with a new `vector.mask` and update D-U chain.
441  assert(opToMask && "Expected a valid operation to mask");
442  auto maskOp = cast<vector::MaskOp>(
443  mlir::vector::maskOperation(rewriter, opToMask, mask));
444  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
445 
446  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
447  rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
448  maskOpTerminator);
449 
450  LDBG("Masked operation: " << *maskOp << "\n");
451  return maskOp;
452 }
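// Editorial example (illustrative, not from the original source): masking a
// transfer_read against a dynamically sized iteration space produces IR
// roughly of the form
//
//   %mask = vector.create_mask %dim0, %dim1 : vector<4x8xi1>
//   %read = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>
//
// All uses of the original transfer_read are redirected to the vector.mask
// result, except the use inside the mask region's terminator.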
453 
454 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
455 /// projectedPermutation, compress the unused dimensions to serve as a
456 /// permutation_map for a vector transfer operation.
457 /// For example, given a linalg op such as:
458 ///
459 /// ```
460 /// %0 = linalg.generic {
461 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
462 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
463 /// }
464 /// ins(%0 : tensor<2x3x4xf32>)
465 /// outs(%1 : tensor<5x6xf32>)
466 /// ```
467 ///
468 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
469 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
470 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
471 static AffineMap reindexIndexingMap(AffineMap map) {
472  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
473  "expected projected permutation");
474  auto res = compressUnusedDims(map);
475  assert(res.getNumDims() == res.getNumResults() &&
476  "expected reindexed map with same number of dims and results");
477  return res;
478 }
479 
480 /// Helper enum to represent conv1d input traversal order.
481 enum class Conv1DOpOrder {
482  W, // Corresponds to non-channeled 1D convolution operation.
483  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
484  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
485 };
486 
487 /// Helper data structure to represent the result of vectorization.
488 /// In certain specific cases, like terminators, we do not want to propagate/
489 enum VectorizationStatus {
490  /// Op failed to vectorize.
491  Failure = 0,
492  /// Op vectorized and custom function took care of replacement logic
493  NoReplace,
494  /// Op vectorized into a new Op whose results will replace original Op's
495  /// results.
496  NewOp
497  // TODO: support values if Op vectorized to Many-Ops whose results we need to
498  // aggregate for replacement.
499 };
500 struct VectorizationResult {
501  /// Return status from vectorizing the current op.
502  enum VectorizationStatus status;
503  /// New vectorized operation to replace the current op.
504  /// Replacement behavior is specified by `status`.
505  Operation *newOp;
506 };
507 
508 std::optional<vector::CombiningKind>
509 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
510  using ::mlir::vector::CombiningKind;
511 
512  if (!combinerOp)
513  return std::nullopt;
514  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
515  .Case<arith::AddIOp, arith::AddFOp>(
516  [&](auto op) { return CombiningKind::ADD; })
517  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
518  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
519  .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
520  .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
521  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
522  .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
523  .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
524  .Case<arith::MulIOp, arith::MulFOp>(
525  [&](auto op) { return CombiningKind::MUL; })
526  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
527  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
528  .Default([&](auto op) { return std::nullopt; });
529 }
530 
531 /// Check whether `outputOperand` is a reduction with a single combiner
532 /// operation. Return the combiner operation of the reduction. Return
533 /// nullptr otherwise. Multiple reduction operations would impose an
534 /// ordering between reduction dimensions and are currently unsupported in
535 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
536 /// max(min(X))
537 // TODO: use in LinalgOp verification, there is a circular dependency atm.
538 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
539  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
540  unsigned outputPos =
541  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
542  // Only single combiner operations are supported for now.
543  SmallVector<Operation *, 4> combinerOps;
544  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
545  combinerOps.size() != 1)
546  return nullptr;
547 
548  // Return the combiner operation.
549  return combinerOps[0];
550 }
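// Editorial example (illustrative, not from the original source): for a
// reduction body such as
//
//   ^bb0(%in: f32, %out: f32):
//     %sum = arith.addf %in, %out : f32
//     linalg.yield %sum : f32
//
// matchLinalgReduction returns the single arith.addf combiner, which
// getCombinerOpKind above maps to vector::CombiningKind::ADD.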
551 
552 /// Broadcast `value` to a vector of `shape` if possible. Return value
553 /// otherwise.
554 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
555  auto dstVecType = dyn_cast<VectorType>(dstType);
556  // If no shape to broadcast to, just return `value`.
557  if (dstVecType.getRank() == 0)
558  return value;
559  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
560  vector::BroadcastableToResult::Success)
561  return value;
562  Location loc = b.getInsertionPoint()->getLoc();
563  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
564 }
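// Editorial example (illustrative, not from the original source): broadcasting
// a scalar f32 operand to the canonical vector shape emits
//
//   %b = vector.broadcast %scalar : f32 to vector<4x8xf32>
//
// while a 0-d destination type, or a source that is not broadcastable to the
// destination, leaves the value unchanged.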
565 
566 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
567 /// assumes that `reductionOp` has two operands and one of them is the reduction
568 /// initial value.
569 // Note: this is a true builder that notifies the OpBuilder listener.
570 // TODO: Consider moving as a static helper on the ReduceOp.
571 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
572  Value valueToReduce, Value acc,
573  ArrayRef<bool> dimsToMask) {
574  auto maybeKind = getCombinerOpKind(reduceOp);
575  assert(maybeKind && "Failed precondition: could not get reduction kind");
576  return b.create<vector::MultiDimReductionOp>(
577  reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
578 }
579 
580 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
581  return llvm::to_vector(
582  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
583 }
584 
585 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
586 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
587 /// currently being vectorized. If `dest` has null rank, build a memref.store.
588 /// Return the produced value or null if no value is produced.
589 // Note: this is a true builder that notifies the OpBuilder listener.
590 // TODO: Consider moving as a static helper on the ReduceOp.
591 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
592  OpOperand *outputOperand,
593  VectorizationState &state) {
594  Location loc = value.getLoc();
595  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
596  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
597 
598  // Compute the vector type of the value to store. This type should be an
599  // identity or projection of the canonical vector type without any permutation
600  // applied, given that any permutation in a transfer write happens as part of
601  // the write itself.
602  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
603  opOperandMap.getContext(), opOperandMap.getNumInputs(),
604  [&](AffineDimExpr dimExpr) -> bool {
605  return llvm::is_contained(opOperandMap.getResults(), dimExpr);
606  });
607  auto vectorType = state.getCanonicalVecType(
608  getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
609 
610  Operation *write;
611  if (vectorType.getRank() > 0) {
612  AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
613  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
614  rewriter.create<arith::ConstantIndexOp>(loc, 0));
615  value = broadcastIfNeeded(rewriter, value, vectorType);
616  assert(value.getType() == vectorType && "Incorrect type");
617  write = rewriter.create<vector::TransferWriteOp>(
618  loc, value, outputOperand->get(), indices, writeMap);
619  } else {
620  // 0-d case is still special: do not invert the reindexing writeMap.
621  if (!isa<VectorType>(value.getType()))
622  value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
623  assert(value.getType() == vectorType && "Incorrect type");
624  write = rewriter.create<vector::TransferWriteOp>(
625  loc, value, outputOperand->get(), ValueRange{});
626  }
627 
628  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
629 
630  // If masked, set in-bounds to true. Masking guarantees that the access will
631  // be in-bounds.
632  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
633  auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
634  SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
635  maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
636  }
637 
638  LDBG("vectorized op: " << *write << "\n");
639  if (!write->getResults().empty())
640  return write->getResult(0);
641  return Value();
642 }
643 
644 // Custom vectorization precondition function type. This is intended to be used
645 // with CustomVectorizationHook. Returns success if the corresponding custom
646 // hook can vectorize the op.
647 using CustomVectorizationPrecondition =
648  std::function<LogicalResult(Operation *, bool)>;
649 
650 // Custom vectorization function type. Produce a vector form of Operation*
651 // assuming all its vectorized operands are already in the IRMapping.
652 // Return nullptr if the Operation cannot be vectorized.
653 using CustomVectorizationHook =
654  std::function<VectorizationResult(Operation *, const IRMapping &)>;
655 
656 /// Helper function to vectorize the terminator of a `linalgOp`. New result
657 /// vector values are appended to `newResults`. Return
658 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
659 /// should not try to map produced operations and instead return the results
660 /// using the `newResults` vector making them available to the vectorization
661 /// algorithm for RAUW. This function is meant to be used as a
662 /// CustomVectorizationHook.
663 static VectorizationResult
664 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
665  const IRMapping &bvm, VectorizationState &state,
666  LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
667  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
668  if (!yieldOp)
669  return VectorizationResult{VectorizationStatus::Failure, nullptr};
670  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
671  // TODO: Scan for an opportunity for reuse.
672  // TODO: use a map.
673  Value vectorValue = bvm.lookup(output.value());
674  Value newResult =
675  buildVectorWrite(rewriter, vectorValue,
676  linalgOp.getDpsInitOperand(output.index()), state);
677  if (newResult)
678  newResults.push_back(newResult);
679  }
680 
681  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
682 }
683 
684 /// Helper function to vectorize the index operations of a `linalgOp`. Return
685 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
686 /// should map the produced operations. This function is meant to be used as a
687 /// CustomVectorizationHook.
688 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
689  VectorizationState &state,
690  Operation *op,
691  LinalgOp linalgOp) {
692  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
693  if (!indexOp)
694  return VectorizationResult{VectorizationStatus::Failure, nullptr};
695  auto loc = indexOp.getLoc();
696  // Compute the static loop sizes of the index op.
697  auto targetShape = state.getCanonicalVecShape();
698  // Compute a one-dimensional index vector for the index op dimension.
699  auto constantSeq =
700  llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
701  auto indexSteps = rewriter.create<arith::ConstantOp>(
702  loc, rewriter.getIndexVectorAttr(constantSeq));
703  // Return the one-dimensional index vector if it lives in the trailing
704  // dimension of the iteration space since the vectorization algorithm in this
705  // case can handle the broadcast.
706  if (indexOp.getDim() == targetShape.size() - 1)
707  return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
708  // Otherwise permute the targetShape to move the index dimension last,
709  // broadcast the one-dimensional index vector to the permuted shape, and
710  // finally transpose the broadcasted index vector to undo the permutation.
711  auto permPattern =
712  llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
713  std::swap(permPattern[indexOp.getDim()], permPattern.back());
714  auto permMap =
715  AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
716 
717  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
718  loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
719  indexSteps);
720  SmallVector<int64_t> transposition =
721  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
722  std::swap(transposition.back(), transposition[indexOp.getDim()]);
723  auto transposeOp =
724  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
725  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
726 }
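// Editorial example (illustrative, not from the original source): for a
// canonical iteration space of 4x8, `linalg.index 1` (the trailing dim)
// vectorizes directly to
//
//   %steps = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
//
// while `linalg.index 0` additionally gets broadcast to the permuted shape and
// transposed back, producing a vector<4x8xindex> whose rows are
// [0 0 ... 0], [1 1 ... 1], [2 2 ... 2], [3 3 ... 3].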
727 
728 /// Helper function to check if the tensor.extract can be vectorized by the
729 /// custom hook vectorizeTensorExtract.
730 static LogicalResult
731 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
732  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
733  if (!extractOp)
734  return failure();
735 
736  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
737  return failure();
738 
739  // Check the index type, but only for non 0-d tensors (for which we do need
740  // access indices).
741  if (not extractOp.getIndices().empty()) {
742  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
743  return failure();
744  }
745 
746  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
747  return !VectorType::isValidElementType(type);
748  })) {
749  return failure();
750  }
751 
752  return success();
753 }
754 
755 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
756 /// generated from `tensor.extract`. The offset is calculated as follows
757 /// (example using scalar values):
758 ///
759 /// offset = extractOp.indices[0]
760 /// for (i = 1; i < numIndices; i++)
761 /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
762 ///
763 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
764 /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
765 static Value calculateGatherOffset(RewriterBase &rewriter,
766  VectorizationState &state,
767  tensor::ExtractOp extractOp,
768  const IRMapping &bvm) {
769  // The vector of indices for GatherOp should be shaped as the output vector.
770  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
771  auto loc = extractOp.getLoc();
772 
773  Value offset = broadcastIfNeeded(
774  rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
775 
776  const size_t numIndices = extractOp.getIndices().size();
777  for (size_t i = 1; i < numIndices; i++) {
778  Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
779 
780  auto dimSize = broadcastIfNeeded(
781  rewriter,
782  rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
783  indexVecType);
784 
785  offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
786 
787  auto extractOpIndex = broadcastIfNeeded(
788  rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
789 
790  offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
791  }
792 
793  return offset;
794 }
795 
796 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
797 
798 /// Checks whether \p val can be used for calculating a loop invariant index.
799 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
800 
801  auto targetShape = linalgOp.getStaticLoopRanges();
802  assert(((llvm::count_if(targetShape,
803  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
804  "n-D vectors are not yet supported");
805  assert(targetShape.back() != 1 &&
806  "1-D vectors with the trailing dim equal 1 are not yet supported");
807 
808  // Blocks outside _this_ linalg.generic are effectively loop invariant.
809  // However, analysing block arguments for _this_ linalg.generic Op is a bit
810  // tricky. Just bail out in the latter case.
811  // TODO: We could try analysing the corresponding affine map here.
812  auto *block = linalgOp.getBlock();
813  if (isa<BlockArgument>(val))
814  return llvm::all_of(block->getArguments(),
815  [&val](Value v) { return (v != val); });
816 
817  Operation *defOp = val.getDefiningOp();
818  assert(defOp && "This is neither a block argument nor an operation result");
819 
820  // IndexOp is loop invariant as long as its result remains constant across
821  // iterations. Given the assumptions on the loop ranges above, only the
822  // trailing loop dim ever changes.
823  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
824  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
825  return (indexOp.getDim() != trailingLoopDim);
826 
827  auto *ancestor = block->findAncestorOpInBlock(*defOp);
828 
829  // Values defined outside `linalgOp` are loop invariant.
830  if (!ancestor)
831  return true;
832 
833  // Values defined inside `linalgOp`, which are constant, are loop invariant.
834  if (isa<arith::ConstantOp>(ancestor))
835  return true;
836 
837  bool result = true;
838  for (auto op : ancestor->getOperands())
839  result &= isLoopInvariantIdx(linalgOp, op);
840 
841  return result;
842 }
843 
844 /// Check whether \p val could be used for calculating the trailing index for a
845 /// contiguous load operation.
846 ///
847 /// There are currently 3 types of values that are allowed here:
848 /// 1. loop-invariant values,
849 /// 2. values that increment by 1 with every loop iteration,
850 /// 3. results of basic arithmetic operations (linear and continuous)
851 /// involving 1., 2. and 3.
852 /// This method returns True if indeed only such values are used in calculating
853 /// \p val.
854 ///
855 /// Additionally, the trailing index for a contiguous load operation should
856 /// increment by 1 with every loop iteration, i.e. be based on:
857 /// * `linalg.index <dim>` ,
858 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
859 /// updated to `true` when such an op is found.
860 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
861  bool &foundIndexOp) {
862 
863  auto targetShape = linalgOp.getStaticLoopRanges();
864  assert(((llvm::count_if(targetShape,
865  [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
866  "n-D vectors are not yet supported");
867  assert(targetShape.back() != 1 &&
868  "1-D vectors with the trailing dim 1 are not yet supported");
869 
870  // Blocks outside _this_ linalg.generic are effectively loop invariant.
871  // However, analysing block arguments for _this_ linalg.generic Op is a bit
872  // tricky. Just bail out in the latter case.
873  // TODO: We could try analysing the corresponding affine map here.
874  auto *block = linalgOp.getBlock();
875  if (isa<BlockArgument>(val))
876  return llvm::all_of(block->getArguments(),
877  [&val](Value v) { return (v != val); });
878 
879  Operation *defOp = val.getDefiningOp();
880  assert(defOp && "This is neither a block argument nor an operation result");
881 
882  // Given the assumption on the loop ranges above, only the trailing loop
883  // index is not constant.
884  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
885  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
886  foundIndexOp = (indexOp.getDim() == trailingLoopDim);
887  return true;
888  }
889 
890  auto *ancestor = block->findAncestorOpInBlock(*defOp);
891 
892  if (!ancestor)
893  return false;
894 
895  // Conservatively reject Ops that could lead to indices with stride other
896  // than 1.
897  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
898  return false;
899 
900  bool result = false;
901  for (auto op : ancestor->getOperands())
902  result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
903 
904  return result;
905 }
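// Editorial example (illustrative, not from the original source): with a
// 1x1x...xN iteration space, a trailing index computed as
//
//   %i = linalg.index 2   // assuming dim 2 is the trailing loop
//   %idx = arith.addi %i, %c1 : index
//
// is accepted as a contiguous-load index (stride 1, `foundIndexOp` is set),
// whereas anything built from, e.g., arith.muli is conservatively rejected.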
906 
907 /// Infer the memory access pattern for the input ExtractOp
908 ///
909 /// Based on the operation shapes and indices (usually based on the iteration
910 /// space of the parent `linalgOp` operation), decides whether the input
911 /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
912 /// gather load.
913 ///
914 /// Note that it is always safe to use gather load operations for contiguous
915 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
916 /// that `extractOp` is a gather load.
917 static VectorMemoryAccessKind
918 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
919  LinalgOp &linalgOp) {
920 
921  auto targetShape = linalgOp.getStaticLoopRanges();
922  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
923 
924  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
925  if (inputShape.getShape().empty())
926  return VectorMemoryAccessKind::ScalarBroadcast;
927 
928  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
929  // gather load.
930  // TODO: Relax this condition.
931  if (linalgOp.hasDynamicShape())
932  return VectorMemoryAccessKind::Gather;
933 
934  // 1. Assume that it's a gather load when reading _into_:
935  // * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
936  // * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
937  // TODO: Relax these conditions.
938  // FIXME: This condition assumes non-dynamic sizes.
939  if ((llvm::count_if(targetShape,
940  [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
941  targetShape.back() == 1)
942  return VectorMemoryAccessKind::Gather;
943 
944  // 2. Assume that it's a gather load when reading _from_ a tensor for which
945  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
946  // TODO: Relax this condition.
947  if (inputShape.getShape().back() == 1)
948  return VectorMemoryAccessKind::Gather;
949 
950  bool leadingIdxsLoopInvariant = true;
951 
952  // 3. Analyze the leading indices of `extractOp`.
953  // Look at the way each index is calculated and decide whether it is suitable
954  // for a contiguous load, i.e. whether it's loop invariant.
955  auto indices = extractOp.getIndices();
956  auto leadIndices = indices.drop_back(1);
957 
958  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
959  if (inputShape.getShape()[i] == 1)
960  continue;
961 
962  leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
963  }
964 
965  if (!leadingIdxsLoopInvariant) {
966  LDBG("Found gather load: " << extractOp);
967  return VectorMemoryAccessKind::Gather;
968  }
969 
970  // 4. Analyze the trailing index for `extractOp`.
971  // At this point we know that the leading indices are loop invariant. This
972  // means that it is potentially a scalar or a contiguous load. We can decide
973  // based on the trailing idx.
974  auto extractOpTrailingIdx = indices.back();
975 
976  // 4a. Scalar broadcast load
977  // If the trailing index is loop invariant then this is a scalar load.
978  if (leadingIdxsLoopInvariant &&
979  isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
980  LDBG("Found scalar broadcast load: " << extractOp);
981 
982  return VectorMemoryAccessKind::ScalarBroadcast;
983  }
984 
985  // 4b. Contiguous loads
986  // The trailing `extractOp` index should increment with every loop iteration.
987  // This effectively means that it must be based on the trailing loop index.
988  // This is what the following bool captures.
989  bool foundIndexOp = false;
990  bool isContiguousLoad =
991  isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
992  isContiguousLoad &= foundIndexOp;
993 
994  if (isContiguousLoad) {
995  LDBG("Found contiguous load: " << extractOp);
996  return VectorMemoryAccessKind::Contiguous;
997  }
998 
999  // 5. Fallback case - gather load.
1000  LDBG("Found gather load: " << extractOp);
1001  return VectorMemoryAccessKind::Gather;
1002 }
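// Editorial example (illustrative, not from the original source): inside a
// linalg.generic with a 1x1x4 iteration space,
//
//   %v = tensor.extract %src[%c0, %i] : tensor<8x8xf32>
//
// is classified as Contiguous when %i follows the trailing `linalg.index`,
// as ScalarBroadcast when both indices are loop invariant, and as Gather
// otherwise (e.g. when a leading index varies with the trailing loop).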
1003 
1004 /// Helper function to vectorize the tensor.extract operations. Returns
1005 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1006 /// should map the produced operations. This function is meant to be used as a
1007 /// CustomVectorizationHook.
1008 static VectorizationResult
1009 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1010  Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1011  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1012  if (!extractOp)
1013  return VectorizationResult{VectorizationStatus::Failure, nullptr};
1014  auto loc = extractOp.getLoc();
1015 
1016  // Compute the static loop sizes of the extract op.
1017  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1018  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1019  loc,
1020  DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1021  /*value=*/true));
1022  auto passThruConstantOp =
1023  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1024 
1025  // Base indices are currently set to 0. We will need to re-visit if more
1026  // generic scenarios are to be supported.
1027  SmallVector<Value> baseIndices(
1028  extractOp.getIndices().size(),
1029  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1030 
1031  VectorMemoryAccessKind memAccessKind =
1032  getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1033 
1034  // 1. Handle gather access
1035  if (memAccessKind == VectorMemoryAccessKind::Gather) {
1036  Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1037 
1038  // Generate the gather load
1039  Operation *gatherOp = rewriter.create<vector::GatherOp>(
1040  loc, resultType, extractOp.getTensor(), baseIndices, offset,
1041  maskConstantOp, passThruConstantOp);
1042  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1043 
1044  LDBG("Vectorised as gather load: " << extractOp << "\n");
1045  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1046  }
1047 
1048  // 2. Handle:
1049  // a. scalar loads + broadcast,
1050  // b. contiguous loads.
1051  // Both cases use vector.transfer_read.
1052 
1053  // Collect indices for `vector.transfer_read`. At this point, the indices will
1054  // either be scalars or would have been broadcast to vectors matching the
1055  // result type. For indices that are vectors, there are two options:
1056  // * for non-trailing indices, all elements are identical (contiguous
1057  // loads are identified by looking for non-trailing indices that are
1058  // invariant with respect to the corresponding linalg.generic), or
1059  // * for trailing indices, the index vector will contain values with stride
1060  // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1061  // needed.
1062  // This means that
1063  // * for scalar indices - just re-use it,
1064  // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1065  // (0th) element and use that.
1066  SmallVector<Value> transferReadIdxs;
1067  auto resTrailingDim = resultType.getShape().back();
1068  auto zero = rewriter.create<arith::ConstantOp>(
1069  loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1070  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1071  auto idx = bvm.lookup(extractOp.getIndices()[i]);
1072  if (idx.getType().isIndex()) {
1073  transferReadIdxs.push_back(idx);
1074  continue;
1075  }
1076 
1077  auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1078  loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
1079  bvm.lookup(extractOp.getIndices()[i]));
1080  transferReadIdxs.push_back(
1081  rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1082  }
1083 
1084  // `tensor.extract` is always in-bounds, hence the following holds.
1085  auto dstRank = resultType.getRank();
1086  auto srcRank = extractOp.getTensor().getType().getRank();
1087  SmallVector<bool> inBounds(dstRank, true);
1088 
1089  // 2a. Handle scalar broadcast access.
1090  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1091  MLIRContext *ctx = rewriter.getContext();
1092  SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1093  auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1094 
1095  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1096  loc, resultType, extractOp.getTensor(), transferReadIdxs,
1097  permutationMap, inBounds);
1098 
1099  LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1100  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1101  }
1102 
1103  // 2b. Handle contiguous access.
1104  auto permutationMap = AffineMap::getMinorIdentityMap(
1105  srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1106 
1107  int32_t rankDiff = dstRank - srcRank;
1108  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1109  // dims so that the ranks match. This is done by extending the map with 0s.
1110  // For example, for dstRank = 3, srcRank = 2, the following map created
1111  // above:
1112  // (d0, d1) --> (d0, d1)
1113  // is extended as:
1114  // (d0, d1) --> (0, d0, d1)
1115  while (rankDiff > 0) {
1116  permutationMap = permutationMap.insertResult(
1117  mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1118  rankDiff--;
1119  }
1120 
1121  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1122  loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1123  inBounds);
1124 
1125  LDBG("Vectorised as contiguous load: " << extractOp);
1126  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1127 }
1128 
1129 /// Emit reduction operations if the shape of the value to reduce is different
1130 /// from the result shape.
1131 // Note: this is a true builder that notifies the OpBuilder listener.
1132 // TODO: Consider moving as a static helper on the ReduceOp.
1133 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1134  Value reduceValue, Value initialValue,
1135  const IRMapping &bvm) {
1136  Value reduceVec = bvm.lookup(reduceValue);
1137  Value outputVec = bvm.lookup(initialValue);
1138  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1139  auto outputType = dyn_cast<VectorType>(outputVec.getType());
1140  // Reduce only if needed as the value may already have been reduced for
1141  // contraction vectorization.
1142  if (!reduceType ||
1143  (outputType && reduceType.getShape() == outputType.getShape()))
1144  return nullptr;
1145  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1146  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1147 }
1148 
1149 /// Generic vectorization for a single operation `op`, given already vectorized
1150 /// operands carried by `bvm`. Vectorization occurs as follows:
1151 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1152 /// result on success.
1153 /// 2. Clone any constant in the current scope without vectorization: each
1154 /// consumer of the constant will later determine the shape to which the
1155 /// constant needs to be broadcast to.
1156 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1157 /// of the `customVectorizationHooks` to cover such cases.
1158 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1159 /// operand of maximal rank. Other operands have smaller rank and are
1160 /// broadcast accordingly. It is assumed this broadcast is always legal,
1161 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1162 ///
1163 /// This function assumes all operands of `op` have been vectorized and are in
1164 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1165 /// a topologically-sorted list of ops.
1166 /// This function does not update `bvm` but returns a VectorizationStatus that
1167 /// instructs the caller what `bvm` update needs to occur.
1168 static VectorizationResult
1169 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1170  LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1171  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1172  LDBG("vectorize op " << *op << "\n");
1173 
1174  // 1. Try to apply any CustomVectorizationHook.
1175  if (!customVectorizationHooks.empty()) {
1176  for (auto &customFunc : customVectorizationHooks) {
1177  VectorizationResult result = customFunc(op, bvm);
1178  if (result.status == VectorizationStatus::Failure)
1179  continue;
1180  return result;
1181  }
1182  }
1183 
1184  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1185  // Clone so that the constant is not confined to the linalgOp block .
1186  if (isa<arith::ConstantOp, func::ConstantOp>(op))
1187  return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1188 
1189  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1190  if (!OpTrait::hasElementwiseMappableTraits(op))
1191  return VectorizationResult{VectorizationStatus::Failure, nullptr};
1192 
1193  // 4 . Check if the operation is a reduction.
1194  SmallVector<std::pair<Value, Value>> reductionOperands;
1195  for (Value operand : op->getOperands()) {
1196  auto blockArg = dyn_cast<BlockArgument>(operand);
1197  if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1198  blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1199  continue;
1200  SmallVector<Operation *> reductionOps;
1201  Value reduceValue = matchReduction(
1202  linalgOp.getRegionOutputArgs(),
1203  blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1204  if (!reduceValue)
1205  continue;
1206  reductionOperands.push_back(std::make_pair(reduceValue, operand));
1207  }
1208  if (!reductionOperands.empty()) {
1209  assert(reductionOperands.size() == 1);
1210  Operation *reduceOp =
1211  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1212  reductionOperands[0].second, bvm);
1213  if (reduceOp)
1214  return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1215  }
1216 
1217  // 5. Generic vectorization path for ElementwiseMappable ops.
1218  // a. Get the first max ranked shape.
1219  VectorType firstMaxRankedType;
1220  for (Value operand : op->getOperands()) {
1221  auto vecOperand = bvm.lookup(operand);
1222  assert(vecOperand && "Vector operand couldn't be found");
1223 
1224  auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1225  if (vecType && (!firstMaxRankedType ||
1226  firstMaxRankedType.getRank() < vecType.getRank()))
1227  firstMaxRankedType = vecType;
1228  }
1229  // b. Broadcast each op if needed.
1230  SmallVector<Value> vecOperands;
1231  for (Value scalarOperand : op->getOperands()) {
1232  Value vecOperand = bvm.lookup(scalarOperand);
1233  assert(vecOperand && "Vector operand couldn't be found");
1234 
1235  if (firstMaxRankedType) {
1236  auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1237  getElementTypeOrSelf(vecOperand.getType()),
1238  firstMaxRankedType.getScalableDims());
1239  vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1240  } else {
1241  vecOperands.push_back(vecOperand);
1242  }
1243  }
1244  // c. for elementwise, the result is the vector with the firstMaxRankedShape
1245  SmallVector<Type> resultTypes;
1246  for (Type resultType : op->getResultTypes()) {
1247  resultTypes.push_back(
1248  firstMaxRankedType
1249  ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1250  firstMaxRankedType.getScalableDims())
1251  : resultType);
1252  }
1253  // d. Build and return the new op.
1254  return VectorizationResult{
1255  VectorizationStatus::NewOp,
1256  rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1257  resultTypes, op->getAttrs())};
1258 }
1259 
1260 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1261 /// vector form. Generic vectorization proceeds as follows:
1262 /// 1. Verify the `linalgOp` has one non-empty region.
1263 /// 2. Values defined above the region are mapped to themselves and will be
1264 /// broadcasted on a per-need basis by their consumers.
1265 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1266 /// load).
1267 /// TODO: Reuse opportunities for RAR dependencies.
1268 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1269 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1270 /// iteration indices.
1271 /// 5. Iteratively call vectorizeOneOp on the region operations.
1272 ///
1273 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1274 /// performed to the maximal common vector size implied by the `linalgOp`
1275 /// iteration space. This eager broadcasting is introduced in the
1276 /// permutation_map of the vector.transfer_read operations. The eager
1277 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1278 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1279 /// the absence of good canonicalizations, the amount of work increases.
1280 /// This is not deemed a problem as we expect canonicalizations and foldings to
1281 /// aggressively clean up the useless work.
1282 static LogicalResult
1283 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1284  LinalgOp linalgOp,
1285  SmallVectorImpl<Value> &newResults) {
1286  LDBG("Vectorizing operation as linalg generic\n");
1287  Block *block = linalgOp.getBlock();
1288 
1289  // 2. Values defined above the region can only be broadcast for now. Make them
1290  // map to themselves.
1291  IRMapping bvm;
1292  SetVector<Value> valuesSet;
1293  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1294  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1295 
1296  if (linalgOp.getNumDpsInits() == 0)
1297  return failure();
1298 
1299  // 3. Turn all BBArgs into vector.transfer_read / load.
1300  Location loc = linalgOp.getLoc();
1301  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1302  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1303  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1304  if (linalgOp.isScalar(opOperand)) {
1305  bvm.map(bbarg, opOperand->get());
1306  continue;
1307  }
1308 
1309  // 3.a. Convert the indexing map for this input/output to a transfer read
1310  // permutation map and masking map.
1311  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1312 
1313  // Remove zeros from indexing map to use it as masking map.
1314  SmallVector<int64_t> zeroPos;
1315  auto results = indexingMap.getResults();
1316  for (const auto &result : llvm::enumerate(results)) {
1317  if (isa<AffineConstantExpr>(result.value())) {
1318  zeroPos.push_back(result.index());
1319  }
1320  }
1321  AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1322 
1323  AffineMap readMap;
1324  VectorType readType;
1325  Type elemType = getElementTypeOrSelf(opOperand->get());
1326  if (linalgOp.isDpsInput(opOperand)) {
1327  // 3.a.i. For input reads we use the canonical vector shape.
1328  readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1329  readType = state.getCanonicalVecType(elemType);
1330  } else {
1331  // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1332  // reductions), the vector shape is computed by mapping the canonical
1333  // vector shape to the output domain and back to the canonical domain.
1334  readMap = inversePermutation(reindexIndexingMap(indexingMap));
1335  readType =
1336  state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1337  }
1338 
1339  SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1340 
1341  Operation *read = rewriter.create<vector::TransferReadOp>(
1342  loc, readType, opOperand->get(), indices, readMap);
1343  read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1344  Value readValue = read->getResult(0);
1345 
1346  // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1347  // will be in-bounds.
1348  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1349  SmallVector<bool> inBounds(readType.getRank(), true);
1350  cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1351  .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1352  }
1353 
1354  // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1355  // TODO: remove this.
1356  if (readType.getRank() == 0)
1357  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1358 
1359  LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1360  << "\n");
1361  bvm.map(bbarg, readValue);
1362  bvm.map(opOperand->get(), readValue);
1363  }
1364 
1365  SmallVector<CustomVectorizationHook> hooks;
1366  // 4a. Register CustomVectorizationHook for yieldOp.
1367  CustomVectorizationHook vectorizeYield =
1368  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1369  return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1370  };
1371  hooks.push_back(vectorizeYield);
1372 
1373  // 4b. Register CustomVectorizationHook for indexOp.
1374  CustomVectorizationHook vectorizeIndex =
1375  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1376  return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1377  };
1378  hooks.push_back(vectorizeIndex);
1379 
1380  // 4c. Register CustomVectorizationHook for extractOp.
1381  CustomVectorizationHook vectorizeExtract =
1382  [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1383  return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1384  };
1385  hooks.push_back(vectorizeExtract);
1386 
1387  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1388  for (Operation &op : block->getOperations()) {
1389  VectorizationResult result =
1390  vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1391  if (result.status == VectorizationStatus::Failure) {
1392  LDBG("failed to vectorize: " << op << "\n");
1393  return failure();
1394  }
1395  if (result.status == VectorizationStatus::NewOp) {
1396  Operation *maybeMaskedOp =
1397  state.maskOperation(rewriter, result.newOp, linalgOp);
1398  LDBG("New vector op: " << *maybeMaskedOp << "\n");
1399  bvm.map(op.getResults(), maybeMaskedOp->getResults());
1400  }
1401  }
1402 
1403  return success();
1404 }
1405 
1406 /// Given a tensor::PackOp, return the `dest` shape before any packing
1407 /// permutations.
1408 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1409  ArrayRef<int64_t> destShape) {
1410  return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1411 }
1412 
1413 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1414 /// create an empty destination tensor and create a TransferWriteOp from the
1415 /// input to the empty tensor. If the destination shape is not the same as the
1416 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1417 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1418 /// inBounds attribute of the transfer write op instead of masking.
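///
/// A rough sketch of the emitted IR, assuming a masked write over a partially
/// dynamic destination (shapes and value names are illustrative only):
///
///   %empty = tensor.empty(%d1) : tensor<8x?xf32>
///   %mask = vector.create_mask %c8, %d1 : vector<8x16xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///       : vector<8x16xf32>, tensor<8x?xf32>
///   } : vector<8x16xi1> -> tensor<8x?xf32>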
1419 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1420  Value input,
1421  SmallVector<OpFoldResult> destSizes,
1422  ArrayRef<int64_t> inputVectorSizes,
1423  bool useInBoundsInsteadOfMasking) {
1424 
1425  auto inputType = cast<VectorType>(input.getType());
1426  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1427  inputType.getElementType());
1428  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1429  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1430  auto destShape = cast<ShapedType>(dest.getType()).getShape();
1431  SmallVector<bool> inBoundsVal(rank, true);
1432  if (useInBoundsInsteadOfMasking) {
1433  // Update the inBounds attribute.
1434  for (unsigned i = 0; i < rank; i++)
1435  inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1436  !ShapedType::isDynamic(destShape[i]);
1437  }
1438  Operation *write = builder.create<vector::TransferWriteOp>(
1439  loc,
1440  /*vector=*/input,
1441  /*source=*/dest,
1442  /*indices=*/SmallVector<Value>(rank, zero),
1443  /*inBounds=*/inBoundsVal);
1444  assert(llvm::none_of(
1445  destShape.drop_front(inputVectorSizes.size()),
1446  [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1447  "Only dims aligned with inputVectorSizes may be dynamic");
1448  if (useInBoundsInsteadOfMasking)
1449  return write;
1450  bool needMaskForWrite = !llvm::equal(
1451  inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1452  if (needMaskForWrite) {
1453  SmallVector<int64_t> writeMaskShape;
1454  writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1455  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1456  destShape.end());
1457  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1458  Value maskForWrite =
1459  builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1460  write = mlir::vector::maskOperation(builder, write, maskForWrite);
1461  }
1462  return write;
1463 }
1464 
1465 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1466 /// padding value and (3) input vector sizes into:
1467 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1468 /// As in the following example:
1469 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1470 /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1471 ///
1472 /// This pack would be vectorized to:
1473 ///
1474 /// %load = vector.mask %mask {
1475 /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1476 /// {in_bounds = [true, true, true]} :
1477 /// tensor<32x7x16xf32>, vector<32x8x16xf32>
1478 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1479 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1480 /// to vector<32x4x2x1x16xf32>
1481 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1482 /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1483 /// %write = vector.transfer_write %transpose,
1484 /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1485 /// {in_bounds = [true, true, true, true, true]}
1486 /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1487 ///
1488 /// If the (3) input vector sizes are not provided, the vector sizes are
1489 /// determined by the result tensor shape. Also, we update the inBounds
1490 /// attribute instead of masking.
1491 static LogicalResult
1492 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1493  ArrayRef<int64_t> inputVectorSizes,
1494  SmallVectorImpl<Value> &newResults) {
1495  OpBuilder::InsertionGuard g(rewriter);
1496  rewriter.setInsertionPoint(packOp);
1497 
1498  Location loc = packOp.getLoc();
1499  auto padValue = packOp.getPaddingValue();
1500  if (!padValue) {
1501  padValue = rewriter.create<arith::ConstantOp>(
1502  loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1503  }
1504  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1505  LogicalResult status =
1506  cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1507  .reifyResultShapes(rewriter, reifiedReturnShapes);
1508  (void)status; // prevent unused variable warning on non-assert builds.
1509  assert(succeeded(status) && "failed to reify result shapes");
1510 
1511  // If the input vector sizes are not provided, the vector sizes are
1512  // determined by the result tensor shape. In that case we also update the
1513  // inBounds attribute instead of masking.
1514  bool useInBoundsInsteadOfMasking = false;
1515  if (inputVectorSizes.empty()) {
1516  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1517  inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1518  useInBoundsInsteadOfMasking = true;
1519  }
1520 
1521  // Create masked TransferReadOp.
1522  SmallVector<int64_t> inputShape(inputVectorSizes);
1523  auto innerTiles = packOp.getStaticInnerTiles();
1524  auto innerDimsPos = packOp.getInnerDimsPos();
1525  auto outerDimsPerm = packOp.getOuterDimsPerm();
1526  if (!outerDimsPerm.empty())
1527  applyPermutationToVector(inputShape,
1528  invertPermutationVector(outerDimsPerm));
1529  for (auto [idx, size] : enumerate(innerTiles))
1530  inputShape[innerDimsPos[idx]] *= size;
1531  auto maskedRead = vector::createReadOrMaskedRead(
1532  rewriter, loc, packOp.getSource(), inputShape, padValue,
1533  useInBoundsInsteadOfMasking);
1534 
1535  // Create ShapeCastOp.
1536  SmallVector<int64_t> destShape(inputVectorSizes);
1537  destShape.append(innerTiles.begin(), innerTiles.end());
1538  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1539  packOp.getDestType().getElementType());
1540  auto shapeCastOp =
1541  rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1542 
1543  // Create TransposeOp.
1544  auto destPermutation =
1545  invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1546  auto transposeOp = rewriter.create<vector::TransposeOp>(
1547  loc, shapeCastOp.getResult(), destPermutation);
1548 
1549  // Create TransferWriteOp.
1550  Operation *write = createWriteOrMaskedWrite(
1551  rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1552  inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1553  newResults.push_back(write->getResult(0));
1554  return success();
1555 }
1556 
1557 /// Vectorize a `tensor::UnPackOp` to these 4 ops:
1558 ///   vector::TransferReadOp  - reads a vector from the source tensor
1559 ///   vector::TransposeOp     - transposes the read vector
1560 ///   vector::ShapeCastOp     - reshapes the data based on the target
1561 ///   vector::TransferWriteOp - writes the result vector back to the
1562 ///                             destination tensor.
1563 /// If the vector sizes are not provided:
1564 /// * the vector sizes are determined by the input operand and attributes,
1565 /// * update the inBounds attribute instead of masking.
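///
/// A rough sketch for a fully static case (illustrative shapes only):
///
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1]
///       inner_tiles = [32, 16] into %dst
///       : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
///
/// would be vectorized, roughly, to:
///
///   %read = vector.transfer_read %src[...] : ..., vector<8x8x32x16xf32>
///   %transpose = vector.transpose %read, [0, 2, 1, 3]
///       : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
///   %shape_cast = vector.shape_cast %transpose
///       : vector<8x32x8x16xf32> to vector<256x128xf32>
///   %write = vector.transfer_write %shape_cast, %empty[...]
///       : vector<256x128xf32>, tensor<256x128xf32>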
1566 static LogicalResult
1567 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1568  ArrayRef<int64_t> inputVectorSizes,
1569  SmallVectorImpl<Value> &newResults) {
1570 
1571  OpBuilder::InsertionGuard g(rewriter);
1572  rewriter.setInsertionPoint(unpackOp);
1573 
1574  RankedTensorType unpackTensorType = unpackOp.getSourceType();
1575 
1576  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1577  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1578  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1579  bool useInBoundsInsteadOfMasking = false;
1580  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1581 
1582  auto destSize = unpackOp.getDestRank();
1583 
1584  if (!inputVectorSizes.empty())
1585  assert(inputVectorSizes.size() == destSize &&
1586  "Incorrect number of input vector sizes");
1587 
1588  // vectorSizes is the shape of the vector that will be used for the final
1589  // write on the destination tensor. It is computed as follows. Let the
1590  // source tensor have rank 'M' and the dest tensor rank 'N', where N <= M.
1591  // Then:
1592  // 1. vectorSizes = sourceShape.take_front(N)
1593  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1594  // 3. multiply every entry of vectorSizes pointed to by innerDimPos by the
1595  // corresponding innerTiles attribute value.
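  // For example (illustrative values): with sourceShape = <4x8x16x2>, a dest
  // of rank N = 2, outer_dims_perm = [1, 0], innerDimPos = [0, 1] and
  // innerTiles = [16, 2]:
  //   1. take_front(2)                 -> [4, 8]
  //   2. apply outer_dims_perm [1, 0]  -> [8, 4]
  //   3. scale by innerTiles           -> [8*16, 4*2] = [128, 8]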
1596  SmallVector<int64_t> vectorSizes(inputVectorSizes);
1597  if (vectorSizes.empty()) {
1598  llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1599  if (!outerDimsPerm.empty())
1600  applyPermutationToVector(vectorSizes, outerDimsPerm);
1601  for (auto [i, pos] : llvm::enumerate(innerDimPos))
1602  vectorSizes[pos] *= innerTiles[i];
1603 
1604  useInBoundsInsteadOfMasking = true;
1605  }
1606 
1607  // readVectorSizes is the shape of the vector used for the read and for the
1608  // mask. It is computed as follows. Let the vectorSizes (VS) array have size
1609  // 'N', the sourceShape (SS) have size 'M' with M >= N, and the inner tile
1610  // sizes (IT) have size M-N.
1611  // Then:
1612  // - initially: readVectorSizes = vectorSizes
1613  // - divide every entry of readVectorSizes pointed to by innerDimPos
1614  //   by the corresponding innerTiles attribute value,
1615  // - if outer_dims_perm is present: apply that permutation to readVectorSizes,
1616  // - append the remaining shape from SS.
1617  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1618  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are [512,
1619  // 128] and outer_dims_perm is [1, 0]; then the read shape is:
1620  // ReadVectorSizes (initial): [512, 128]
1621  // After the innerDim adjustment: [512/32, 128/16]
1622  // = [16, 8]
1623  // After applying outer_dims_perm: [8, 16]
1624  // After appending the rest of the sourceShape: [8, 16, 32, 16]
1625 
1626  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1627 
1628  for (auto [index, size] : enumerate(innerTiles)) {
1629  readVectorSizes[innerDimPos[index]] =
1630  llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1631  }
1632  if (!outerDimsPerm.empty()) {
1633  applyPermutationToVector(readVectorSizes, outerDimsPerm);
1634  }
1635  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1636  sourceShape.end());
1637 
1638  ReifiedRankedShapedTypeDims reifiedRetShapes;
1639  LogicalResult status =
1640  cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1641  .reifyResultShapes(rewriter, reifiedRetShapes);
1642  if (status.failed()) {
1643  LDBG("Unable to reify result shapes of " << unpackOp);
1644  return failure();
1645  }
1646  Location loc = unpackOp->getLoc();
1647 
1648  auto padValue = rewriter.create<arith::ConstantOp>(
1649  loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1650 
1651  // Read the result, masking if necessary: a mask is needed whenever the
1652  // transfer_read shape is not equal to the shape of the source.
1653  Value readResult = vector::createReadOrMaskedRead(
1654  rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1655  /*useInBoundsInsteadOfMasking=*/false);
1656 
1657  PackingMetadata packMetadata;
1658  SmallVector<int64_t> lastDimToInsertPosPerm =
1659  tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1660  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1661  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1662  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1663  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1664  RankedTensorType stripMineTensorType =
1665  RankedTensorType::get(stripMineShape, stripMineElemType);
1666  // Transpose the appropriate rows to match output.
1667  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1668  loc, readResult, lastDimToInsertPosPerm);
1669 
1670  // Collapse the vector to the size required by result.
1671  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1672  stripMineTensorType, packMetadata.reassociations);
1673  mlir::VectorType vecCollapsedType =
1674  VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1675  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1676  loc, vecCollapsedType, transposeOp->getResult(0));
1677 
1678  // writeVectorSizes has to match the shape_cast shape when the dest has
1679  // dynamic sizes; otherwise the verifier complains that the mask size is invalid.
1680  SmallVector<int64_t> writeVectorSizes(
1681  unpackOp.getDestType().hasStaticShape()
1682  ? vectorSizes
1683  : shapeCastOp.getResultVectorType().getShape());
1684  Operation *write = createWriteOrMaskedWrite(
1685  rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1686  writeVectorSizes, useInBoundsInsteadOfMasking);
1687  newResults.push_back(write->getResult(0));
1688  return success();
1689 }
1690 
1691 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1692 /// and (3) all-zero lowPad to
1693 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
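///
/// A rough sketch (illustrative shapes and values only):
///
///   %pad = tensor.pad %src low[0, 0] high[2, 3] {
///     ^bb0(...): tensor.yield %cst : f32
///   } : tensor<6x5xf32> to tensor<8x8xf32>
///
/// becomes, roughly:
///
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<6x5xf32>, vector<8x8xf32>
///   } : vector<8x8xi1> -> vector<8x8xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>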
1694 static LogicalResult
1695 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1696  ArrayRef<int64_t> inputVectorSizes,
1697  SmallVectorImpl<Value> &newResults) {
1698  auto padValue = padOp.getConstantPaddingValue();
1699  Location loc = padOp.getLoc();
1700 
1701  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1702  OpBuilder::InsertionGuard g(rewriter);
1703  rewriter.setInsertionPoint(padOp);
1704 
1705  ReifiedRankedShapedTypeDims reifiedReturnShapes;
1706  LogicalResult status =
1707  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1708  .reifyResultShapes(rewriter, reifiedReturnShapes);
1709  (void)status; // prevent unused variable warning on non-assert builds
1710  assert(succeeded(status) && "failed to reify result shapes");
1711  auto maskedRead = vector::createReadOrMaskedRead(
1712  rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1713  /*useInBoundsInsteadOfMasking=*/false);
1714  Operation *write = createWriteOrMaskedWrite(
1715  rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1716  /*useInBoundsInsteadOfMasking=*/false);
1717  newResults.push_back(write->getResult(0));
1718  return success();
1719 }
1720 
1721 // TODO: probably need some extra checks for reduction followed by consumer
1722 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1723 static LogicalResult reductionPreconditions(LinalgOp op) {
1724  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1725  LDBG("reduction precondition failed: no reduction iterator\n");
1726  return failure();
1727  }
1728  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1729  AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1730  if (indexingMap.isPermutation())
1731  continue;
1732 
1733  Operation *reduceOp = matchLinalgReduction(&opOperand);
1734  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1735  LDBG("reduction precondition failed: reduction detection failed\n");
1736  return failure();
1737  }
1738  }
1739  return success();
1740 }
1741 
1742 static LogicalResult
1743 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1744  bool flatten1DDepthwiseConv) {
1745  if (flatten1DDepthwiseConv) {
1746  LDBG("Vectorization of flattened convs with dynamic shapes is not "
1747  "supported\n");
1748  return failure();
1749  }
1750 
1751  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1752  LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1753  return failure();
1754  }
1755 
1756  // Support dynamic shapes in 1D depthwise convolution, but only in the
1757  // _channel_ dimension.
1758  Value lhs = conv.getDpsInputOperand(0)->get();
1759  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1760  auto shapeWithoutCh = lhsShape.drop_back(1);
1761  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1762  LDBG("Dynamically-shaped op vectorization precondition failed: only "
1763  "channel dim can be dynamic\n");
1764  return failure();
1765  }
1766 
1767  return success();
1768 }
1769 
1770 static LogicalResult
1771 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1772  bool flatten1DDepthwiseConv) {
1773  if (isa<ConvolutionOpInterface>(op.getOperation()))
1774  return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1775 
1776  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1777  // linalg.copy ops and ops that implement ContractionOpInterface for now.
1778  if (!isElementwise(op) &&
1779  !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1780  op.getOperation()))
1781  return failure();
1782 
1783  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1784  return success();
1785 }
1786 
1787 /// Need to check if the inner-tiles are static/constant.
1788 static LogicalResult
1789 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1790  ArrayRef<int64_t> inputVectorSizes) {
1791 
1792  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1793  return !getConstantIntValue(res).has_value();
1794  })) {
1795  LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1796  return failure();
1797  }
1798  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1799  bool satisfyEmptyCond = inputVectorSizes.empty() &&
1800  unpackOp.getDestType().hasStaticShape() &&
1801  unpackOp.getSourceType().hasStaticShape();
1802  if (!satisfyEmptyCond &&
1803  failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1804  return failure();
1805 
1806  return success();
1807 }
1808 
1809 static LogicalResult vectorizeLinalgOpPrecondition(
1810  LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1811  bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1812  // Tensors with a zero-sized dimension cannot be vectorized.
1813  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1814  return failure();
1815  // Check API contract for input vector sizes.
1816  if (!inputVectorSizes.empty() &&
1817  failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1818  inputVectorSizes)))
1819  return failure();
1820 
1821  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1822  linalgOp, flatten1DDepthwiseConv))) {
1823  LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1824  return failure();
1825  }
1826 
1827  SmallVector<CustomVectorizationPrecondition> customPreconditions;
1828 
1829  // Register CustomVectorizationPrecondition for extractOp.
1830  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1831 
1832  // All operand and result types in the body must be valid element types for VectorType.
1833  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1834  // Check if any custom hook can vectorize the inner op.
1835  if (llvm::any_of(
1836  customPreconditions,
1837  [&](const CustomVectorizationPrecondition &customPrecondition) {
1838  return succeeded(
1839  customPrecondition(&innerOp, vectorizeNDExtract));
1840  })) {
1841  continue;
1842  }
1843  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1844  return !VectorType::isValidElementType(type);
1845  })) {
1846  return failure();
1847  }
1848  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1849  return !VectorType::isValidElementType(type);
1850  })) {
1851  return failure();
1852  }
1853  }
1854  if (isElementwise(linalgOp))
1855  return success();
1856 
1857  // TODO: isaConvolutionOpInterface that can also infer from generic
1858  // features. But we will still need stride/dilation attributes that will be
1859  // annoying to reverse-engineer...
1860  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1861  return success();
1862  // TODO: the common vector shape is equal to the static loop sizes only when
1863  // all indexing maps are projected permutations. For convs and stencils the
1864  // logic will need to evolve.
1865  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1866  LDBG("precondition failed: not projected permutations\n");
1867  return failure();
1868  }
1869  if (failed(reductionPreconditions(linalgOp))) {
1870  LDBG("precondition failed: reduction preconditions\n");
1871  return failure();
1872  }
1873  return success();
1874 }
1875 
1876 static LogicalResult
1877 vectorizePackOpPrecondition(tensor::PackOp packOp,
1878  ArrayRef<int64_t> inputVectorSizes) {
1879  auto padValue = packOp.getPaddingValue();
1880  Attribute cstAttr;
1881  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1882  LDBG("pad value is not constant: " << packOp << "\n");
1883  return failure();
1884  }
1885  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1886  bool satisfyEmptyCond = true;
1887  if (inputVectorSizes.empty()) {
1888  if (!packOp.getDestType().hasStaticShape() ||
1889  !packOp.getSourceType().hasStaticShape())
1890  satisfyEmptyCond = false;
1891  }
1892 
1893  if (!satisfyEmptyCond &&
1894  failed(vector::isValidMaskedInputVector(
1895  resultTensorShape.take_front(packOp.getSourceRank()),
1896  inputVectorSizes)))
1897  return failure();
1898 
1899  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1900  return !getConstantIntValue(v).has_value();
1901  })) {
1902  LDBG("inner_tiles must be constant: " << packOp << "\n");
1903  return failure();
1904  }
1905 
1906  return success();
1907 }
1908 
1909 static LogicalResult
1910 vectorizePadOpPrecondition(tensor::PadOp padOp,
1911  ArrayRef<int64_t> inputVectorSizes) {
1912  auto padValue = padOp.getConstantPaddingValue();
1913  if (!padValue) {
1914  LDBG("pad value is not constant: " << padOp << "\n");
1915  return failure();
1916  }
1917 
1918  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1919  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1920  inputVectorSizes)))
1921  return failure();
1922 
1923  if (llvm::any_of(padOp.getLow(), [](Value v) {
1924  std::optional<int64_t> res = getConstantIntValue(v);
1925  return !res.has_value() || res.value() != 0;
1926  })) {
1927  LDBG("low pad must all be zero: " << padOp << "\n");
1928  return failure();
1929  }
1930 
1931  return success();
1932 }
1933 
1934 /// Preconditions for scalable vectors.
1935 static LogicalResult
1936 vectorizeScalableVectorPrecondition(Operation *op,
1937  ArrayRef<int64_t> inputVectorSizes,
1938  ArrayRef<bool> inputScalableVecDims) {
1939  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1940  "Number of input vector sizes and scalable dims doesn't match");
1941 
1942  if (inputVectorSizes.empty())
1943  return success();
1944 
1945  bool isScalable = inputScalableVecDims.back();
1946  if (!isScalable)
1947  return success();
1948 
1949  // Only element-wise and 1d depthwise conv ops are supported in the presence
1950  // of scalable dims.
1951  auto linalgOp = dyn_cast<LinalgOp>(op);
1952  return success(linalgOp && (isElementwise(linalgOp) ||
1953  isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
1954 }
1955 
1956 LogicalResult mlir::linalg::vectorizeOpPrecondition(
1957  Operation *op, ArrayRef<int64_t> inputVectorSizes,
1958  ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
1959  bool flatten1DDepthwiseConv) {
1960  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
1961  inputScalableVecDims)))
1962  return failure();
1963 
1964  return TypeSwitch<Operation *, LogicalResult>(op)
1965  .Case<linalg::LinalgOp>([&](auto linalgOp) {
1966  return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1967  vectorizeNDExtract,
1968  flatten1DDepthwiseConv);
1969  })
1970  .Case<tensor::PadOp>([&](auto padOp) {
1971  return vectorizePadOpPrecondition(padOp, inputVectorSizes);
1972  })
1973  .Case<tensor::PackOp>([&](auto packOp) {
1974  return vectorizePackOpPrecondition(packOp, inputVectorSizes);
1975  })
1976  .Case<tensor::UnPackOp>([&](auto unpackOp) {
1977  return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
1978  })
1979  .Default([](auto) { return failure(); });
1980 }
1981 
1982 /// Converts affine.apply Ops to arithmetic operations.
1983 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
1984  OpBuilder::InsertionGuard g(rewriter);
1985  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
1986 
1987  for (auto op : make_early_inc_range(toReplace)) {
1988  rewriter.setInsertionPoint(op);
1989  auto expanded = affine::expandAffineExpr(
1990  rewriter, op->getLoc(), op.getAffineMap().getResult(0),
1991  op.getOperands().take_front(op.getAffineMap().getNumDims()),
1992  op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
1993  rewriter.replaceOp(op, expanded);
1994  }
1995 }
1996 
1997 /// Emit a suitable vector form for an operation. If provided,
1998 /// `inputVectorSizes` are used to vectorize this operation.
1999 /// `inputVectorSizes` must match the rank of the iteration space of the
2000 /// operation and the input vector sizes must be greater than or equal to
2001 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2002 /// also allows the vectorization of operations with dynamic shapes.
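///
/// A minimal usage sketch (illustrative; assumes `rewriter` and a rank-2
/// `linalgOp` are in scope and that the trailing flags keep their defaults):
///
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false})))
///     return failure();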
2003 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2004  ArrayRef<int64_t> inputVectorSizes,
2005  ArrayRef<bool> inputScalableVecDims,
2006  bool vectorizeNDExtract,
2007  bool flatten1DDepthwiseConv) {
2008  LDBG("Attempting to vectorize:\n" << *op << "\n");
2009  LDBG("Input vector sizes: ");
2010  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2011  LLVM_DEBUG(llvm::dbgs() << "\n");
2012  LDBG("Input scalable vector dims: ");
2013  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2014  LLVM_DEBUG(llvm::dbgs() << "\n");
2015 
2016  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2017  vectorizeNDExtract,
2018  flatten1DDepthwiseConv))) {
2019  LDBG("Vectorization pre-conditions failed\n");
2020  return failure();
2021  }
2022 
2023  // Initialize vectorization state.
2024  VectorizationState state(rewriter);
2025  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2026  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2027  inputScalableVecDims))) {
2028  LDBG("Vectorization state couldn't be initialized\n");
2029  return failure();
2030  }
2031  }
2032 
2033  SmallVector<Value> results;
2034  auto vectorizeResult =
2035  TypeSwitch<Operation *, LogicalResult>(op)
2036  .Case<linalg::LinalgOp>([&](auto linalgOp) {
2037  // TODO: isaConvolutionOpInterface that can also infer from
2038  // generic features. Will require stride/dilation attributes
2039  // inference.
2040  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2041  FailureOr<Operation *> convOr = vectorizeConvolution(
2042  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2043  flatten1DDepthwiseConv);
2044  if (succeeded(convOr)) {
2045  llvm::append_range(results, (*convOr)->getResults());
2046  return success();
2047  }
2048 
2049  LDBG("Unsupported convolution can't be vectorized.\n");
2050  return failure();
2051  }
2052 
2053  LDBG("Vectorize generic by broadcasting to the canonical vector "
2054  "shape\n");
2055 
2056  // Pre-process before proceeding.
2057  convertAffineApply(rewriter, linalgOp);
2058 
2059  // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2060  // to 'OpBuilder' when it is passed over to some methods like
2061  // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2062  // erase an op within these methods, the actual rewriter won't be
2063  // notified and we will end up with read-after-free issues!
2064  return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2065  })
2066  .Case<tensor::PadOp>([&](auto padOp) {
2067  return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2068  results);
2069  })
2070  .Case<tensor::PackOp>([&](auto packOp) {
2071  return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2072  results);
2073  })
2074  .Case<tensor::UnPackOp>([&](auto unpackOp) {
2075  return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2076  inputVectorSizes, results);
2077  })
2078  .Default([](auto) { return failure(); });
2079 
2080  if (failed(vectorizeResult)) {
2081  LDBG("Vectorization failed\n");
2082  return failure();
2083  }
2084 
2085  if (!results.empty())
2086  rewriter.replaceOp(op, results);
2087  else
2088  rewriter.eraseOp(op);
2089 
2090  return success();
2091 }
2092 
2093 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2094  memref::CopyOp copyOp) {
2095  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2096  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2097  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2098  return failure();
2099 
2100  auto srcElementType = getElementTypeOrSelf(srcType);
2101  auto dstElementType = getElementTypeOrSelf(dstType);
2102  if (!VectorType::isValidElementType(srcElementType) ||
2103  !VectorType::isValidElementType(dstElementType))
2104  return failure();
2105 
2106  auto readType = VectorType::get(srcType.getShape(), srcElementType);
2107  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2108 
2109  Location loc = copyOp->getLoc();
2110  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2111  SmallVector<Value> indices(srcType.getRank(), zero);
2112 
2113  Value readValue = rewriter.create<vector::TransferReadOp>(
2114  loc, readType, copyOp.getSource(), indices,
2115  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2116  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2117  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2118  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2119  }
2120  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2121  loc, readValue, copyOp.getTarget(), indices,
2122  rewriter.getMultiDimIdentityMap(srcType.getRank()));
2123  rewriter.replaceOp(copyOp, writeValue->getResults());
2124  return success();
2125 }
2126 
2127 //----------------------------------------------------------------------------//
2128 // Misc. vectorization patterns.
2129 //----------------------------------------------------------------------------//
2130 
2131 /// Helper function that retrieves the value of an IntegerAttr.
2132 static int64_t getIntFromAttr(Attribute attr) {
2133  return cast<IntegerAttr>(attr).getInt();
2134 }
2135 
2136 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2137 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2138 /// not supported.
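/// For example (illustrative): {%v, IntegerAttr 3} becomes
/// {%v, arith.constant 3 : index}.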
2139 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2140  ArrayRef<OpFoldResult> ofrs) {
2141  SmallVector<Value> result;
2142  for (auto o : ofrs) {
2143  if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2144  result.push_back(val);
2145  } else {
2146  result.push_back(rewriter.create<arith::ConstantIndexOp>(
2147  loc, getIntFromAttr(o.template get<Attribute>())));
2148  }
2149  }
2150  return result;
2151 }
2152 
2153 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2154 /// InsertSliceOp. For now, only constant padding values are supported.
2155 /// If there is enough static type information, TransferReadOps and
2156 /// TransferWriteOps may be generated instead of InsertSliceOps.
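///
/// A rough sketch of the generalized form (illustrative shapes only):
///
///   %init = tensor.empty() : tensor<8x8xf32>
///   %fill = linalg.fill ins(%pad_value : f32)
///             outs(%init : tensor<8x8xf32>) -> tensor<8x8xf32>
///   %res  = tensor.insert_slice %src into %fill[0, 0] [6, 5] [1, 1]
///             : tensor<6x5xf32> into tensor<8x8xf32>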
2157 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2158  GenericPadOpVectorizationPattern(MLIRContext *context,
2159  PatternBenefit benefit = 1)
2160  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2161  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2162  /// each dimension size is statically known in the source type or the result
2163  /// type (or both).
2164  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2165  tensor::PadOp padOp, Value dest) {
2166  auto sourceType = padOp.getSourceType();
2167  auto resultType = padOp.getResultType();
2168  if (!VectorType::isValidElementType(sourceType.getElementType()))
2169  return failure();
2170 
2171  // Copy cannot be vectorized if pad value is non-constant and source shape
2172  // is dynamic. In case of a dynamic source shape, padding must be appended
2173  // by TransferReadOp, but TransferReadOp supports only constant padding.
2174  auto padValue = padOp.getConstantPaddingValue();
2175  if (!padValue) {
2176  if (!sourceType.hasStaticShape())
2177  return failure();
2178  // Create dummy padding value.
2179  auto elemType = sourceType.getElementType();
2180  padValue = rewriter.create<arith::ConstantOp>(
2181  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2182  }
2183 
2184  SmallVector<int64_t> vecShape;
2185  SmallVector<bool> readInBounds;
2186  SmallVector<bool> writeInBounds;
2187  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2188  if (!sourceType.isDynamicDim(i)) {
2189  vecShape.push_back(sourceType.getDimSize(i));
2190  // Source shape is statically known: neither read nor write is
2191  // out-of-bounds.
2192  readInBounds.push_back(true);
2193  writeInBounds.push_back(true);
2194  } else if (!resultType.isDynamicDim(i)) {
2195  // Source shape is not statically known, but result shape is.
2196  // Vectorize with size of result shape. This may be larger than the
2197  // source size.
2198  vecShape.push_back(resultType.getDimSize(i));
2199  // Read may be out-of-bounds because the result size could be larger
2200  // than the source size.
2201  readInBounds.push_back(false);
2202  // Write is out-of-bounds if low padding > 0.
2203  writeInBounds.push_back(
2204  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2205  static_cast<int64_t>(0));
2206  } else {
2207  // Neither source nor result dim of padOp is static. Cannot vectorize
2208  // the copy.
2209  return failure();
2210  }
2211  }
2212  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2213 
2214  // Generate TransferReadOp.
2215  SmallVector<Value> readIndices(
2216  vecType.getRank(),
2217  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2218  auto read = rewriter.create<vector::TransferReadOp>(
2219  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2220  ArrayRef<bool>{readInBounds});
2221 
2222  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2223  // entire tensor, write directly to the FillOp's operand.
2224  if (llvm::equal(vecShape, resultType.getShape()) &&
2225  llvm::all_of(writeInBounds, [](bool b) { return b; }))
2226  if (auto fill = dest.getDefiningOp<FillOp>())
2227  dest = fill.output();
2228 
2229  // Generate TransferWriteOp.
2230  auto writeIndices =
2231  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2232  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2233  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2234 
2235  return success();
2236  }
2237 };
2238 
2239 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2240 /// given operation type OpTy.
2241 template <typename OpTy>
2242 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2243  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2244 
2245  LogicalResult matchAndRewrite(tensor::PadOp padOp,
2246  PatternRewriter &rewriter) const final {
2247  bool changed = false;
2248  // Collect users in a vector first, because some users may be replaced/removed.
2249  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2250  if (auto op = dyn_cast<OpTy>(user))
2251  changed |= rewriteUser(rewriter, padOp, op).succeeded();
2252  return success(changed);
2253  }
2254 
2255 protected:
2256  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2257  tensor::PadOp padOp, OpTy op) const = 0;
2258 };
2259 
2260 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2261 /// ```
2262 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2263 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2264 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2265 /// ```
2266 /// is rewritten to:
2267 /// ```
2268 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2269 /// {in_bounds = [true, true]}
2270 /// : tensor<?x?xf32>, vector<17x5xf32>
2271 /// ```
2272 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2273 /// sure that the original padding value %cst was never used.
2274 ///
2275 /// This rewrite is possible if:
2276 /// - `xferOp` has no out-of-bounds dims or mask.
2277 /// - Low padding is static 0.
2278 /// - Single, scalar padding value.
2279 struct PadOpVectorizationWithTransferReadPattern
2280  : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2281  using VectorizePadOpUserPattern<
2282  vector::TransferReadOp>::VectorizePadOpUserPattern;
2283 
2284  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2285  vector::TransferReadOp xferOp) const override {
2286  // Low padding must be static 0.
2287  if (!padOp.hasZeroLowPad())
2288  return failure();
2289  // Pad value must be a constant.
2290  auto padValue = padOp.getConstantPaddingValue();
2291  if (!padValue)
2292  return failure();
2293  // Padding value of existing `xferOp` is unused.
2294  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2295  return failure();
2296 
2297  rewriter.modifyOpInPlace(xferOp, [&]() {
2298  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2299  xferOp->setAttr(xferOp.getInBoundsAttrName(),
2300  rewriter.getBoolArrayAttr(inBounds));
2301  xferOp.getSourceMutable().assign(padOp.getSource());
2302  xferOp.getPaddingMutable().assign(padValue);
2303  });
2304 
2305  return success();
2306  }
2307 };
2308 
2309 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2310 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2311 /// value, where the same amount of padding is immediately removed again after
2312 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2313 /// tensor value and apply out-of-bounds masking. E.g.:
2314 /// ```
2315 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2316 /// : tensor<...> to tensor<?x?xf32>
2317 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2318 /// %2 = vector.transfer_write %vec, %1[...]
2319 /// : vector<17x5xf32>, tensor<17x5xf32>
2320 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2321 /// : tensor<17x5xf32> to tensor<?x?xf32>
2322 /// ```
2323 /// is rewritten to:
2324 /// ```
2325 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2326 /// : tensor<...> to tensor<?x?xf32>
2327 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2328 /// tensor<?x?xf32>
2329 /// ```
2330 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2331 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2332 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2333 /// from %r's old dimensions.
2334 ///
2335 /// This rewrite is possible if:
2336 /// - Low padding is static 0.
2337 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2338 /// ExtractSliceOp trims the same amount of padding that was added
2339 /// beforehand.
2340 /// - Single, scalar padding value.
2341 struct PadOpVectorizationWithTransferWritePattern
2342  : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2343  using VectorizePadOpUserPattern<
2344  vector::TransferWriteOp>::VectorizePadOpUserPattern;
2345 
2346  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2347  vector::TransferWriteOp xferOp) const override {
2348  // TODO: support 0-d corner case.
2349  if (xferOp.getTransferRank() == 0)
2350  return failure();
2351 
2352  // Low padding must be static 0.
2353  if (!padOp.hasZeroLowPad())
2354  return failure();
2355  // Pad value must be a constant.
2356  auto padValue = padOp.getConstantPaddingValue();
2357  if (!padValue)
2358  return failure();
2359  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2360  if (!xferOp->hasOneUse())
2361  return failure();
2362  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2363  if (!trimPadding)
2364  return failure();
2365  // Only static zero offsets supported when trimming padding.
2366  if (!trimPadding.hasZeroOffset())
2367  return failure();
2368  // trimPadding must remove the amount of padding that was added earlier.
2369  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2370  return failure();
2371 
2372  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2373  rewriter.setInsertionPoint(xferOp);
2374 
2375  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2376  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2377  xferOp, padOp.getSource().getType(), xferOp.getVector(),
2378  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2379  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2380  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2381 
2382  return success();
2383  }
2384 
2385  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2386  /// i.e., same dimensions.
2387  ///
2388  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2389  /// dimensions, this function tries to infer the (static) tensor size by
2390  /// looking at the defining op and utilizing op-specific knowledge.
2391  ///
2392  /// This is a conservative analysis. In case equal tensor sizes cannot be
2393  /// proven statically, this analysis returns `false` even though the tensor
2394  /// sizes may turn out to be equal at runtime.
2395  bool hasSameTensorSize(Value beforePadding,
2396  tensor::ExtractSliceOp afterTrimming) const {
2397  // If the input to tensor::PadOp is a CastOp, try with both CastOp
2398  // result and CastOp operand.
2399  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2400  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2401  return true;
2402 
2403  auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2404  auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2405  // Only RankedTensorType supported.
2406  if (!t1 || !t2)
2407  return false;
2408  // Rank of both values must be the same.
2409  if (t1.getRank() != t2.getRank())
2410  return false;
2411 
2412  // All static dimensions must be the same. Mixed cases (e.g., dimension
2413  // static in `t1` but dynamic in `t2`) are not supported.
2414  for (unsigned i = 0; i < t1.getRank(); ++i) {
2415  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2416  return false;
2417  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2418  return false;
2419  }
2420 
2421  // Nothing more to check if all dimensions are static.
2422  if (t1.getNumDynamicDims() == 0)
2423  return true;
2424 
2425  // All dynamic sizes must be the same. The only supported case at the
2426  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2427  // thereof).
2428 
2429  // Apart from CastOp, only ExtractSliceOp is supported.
2430  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2431  if (!beforeSlice)
2432  return false;
2433 
2434  assert(static_cast<size_t>(t1.getRank()) ==
2435  beforeSlice.getMixedSizes().size());
2436  assert(static_cast<size_t>(t2.getRank()) ==
2437  afterTrimming.getMixedSizes().size());
2438 
2439  for (unsigned i = 0; i < t1.getRank(); ++i) {
2440  // Skip static dimensions.
2441  if (!t1.isDynamicDim(i))
2442  continue;
2443  auto size1 = beforeSlice.getMixedSizes()[i];
2444  auto size2 = afterTrimming.getMixedSizes()[i];
2445 
2446  // Case 1: Same value or same constant int.
2447  if (isEqualConstantIntOrValue(size1, size2))
2448  continue;
2449 
2450  // Other cases: Take a deeper look at defining ops of values.
2451  auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2452  auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2453  if (!v1 || !v2)
2454  return false;
2455 
2456  // Case 2: Both values are identical AffineMinOps. (Should not happen if
2457  // CSE is run.)
2458  auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2459  auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2460  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2461  minOp1.getOperands() == minOp2.getOperands())
2462  continue;
2463 
2464  // Add additional cases as needed.
2465  }
2466 
2467  // All tests passed.
2468  return true;
2469  }
2470 };
2471 
2472 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2473 /// ```
2474 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2475 /// %r = tensor.insert_slice %0
2476 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2477 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2478 /// ```
2479 /// is rewritten to:
2480 /// ```
2481 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2482 /// : tensor<?x?xf32>, vector<17x5xf32>
2483 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2484 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2485 /// ```
2486 ///
2487 /// This rewrite is possible if:
2488 /// - Low padding is static 0.
2489 /// - `padOp` result shape is static.
2490 /// - The entire padded tensor is inserted.
2491 /// (Implies that sizes of `insertOp` are all static.)
2492 /// - Only unit strides in `insertOp`.
2493 /// - Single, scalar padding value.
2494 /// - `padOp` result not used as destination.
2495 struct PadOpVectorizationWithInsertSlicePattern
2496  : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2497  using VectorizePadOpUserPattern<
2498  tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2499 
2500  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2501  tensor::InsertSliceOp insertOp) const override {
2502  // Low padding must be static 0.
2503  if (!padOp.hasZeroLowPad())
2504  return failure();
2505  // Only unit stride supported.
2506  if (!insertOp.hasUnitStride())
2507  return failure();
2508  // Pad value must be a constant.
2509  auto padValue = padOp.getConstantPaddingValue();
2510  if (!padValue)
2511  return failure();
2512  // Dynamic shapes not supported.
2513  if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2514  return failure();
2515  // Pad result not used as destination.
2516  if (insertOp.getDest() == padOp.getResult())
2517  return failure();
2518 
2519  auto vecType = VectorType::get(padOp.getType().getShape(),
2520  padOp.getType().getElementType());
2521  unsigned vecRank = vecType.getRank();
2522  unsigned tensorRank = insertOp.getType().getRank();
2523 
2524  // Check if sizes match: Insert the entire tensor into most minor dims.
2525  // (No permutations allowed.)
2526  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2527  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2528  if (!llvm::all_of(
2529  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2530  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2531  }))
2532  return failure();
2533 
2534  // Insert the TransferReadOp and TransferWriteOp at the position of the
2535  // InsertSliceOp.
2536  rewriter.setInsertionPoint(insertOp);
2537 
2538  // Generate TransferReadOp: Read entire source tensor and add high
2539  // padding.
2540  SmallVector<Value> readIndices(
2541  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2542  auto read = rewriter.create<vector::TransferReadOp>(
2543  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2544 
2545  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2546  // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2547  // source must fit into the destination at the specified offsets.
2548  auto writeIndices =
2549  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2550  SmallVector<bool> inBounds(vecRank, true);
2551  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2552  insertOp, read, insertOp.getDest(), writeIndices,
2553  ArrayRef<bool>{inBounds});
2554 
2555  return success();
2556  }
2557 };
2558 
2559 void mlir::linalg::populatePadOpVectorizationPatterns(
2560  RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2561  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2562  baseBenefit);
2563  // Try these specialized patterns first before resorting to the generic one.
2564  patterns.add<PadOpVectorizationWithTransferReadPattern,
2565  PadOpVectorizationWithTransferWritePattern,
2566  PadOpVectorizationWithInsertSlicePattern>(
2567  patterns.getContext(), baseBenefit.getBenefit() + 1);
2568 }
2569 
2570 //----------------------------------------------------------------------------//
2571 // Forwarding patterns
2572 //----------------------------------------------------------------------------//
2573 
2574 /// Check whether there is any interleaved use of any `values` between
2575 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2576 /// is in a different block.
2577 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2578  ValueRange values) {
2579  if (firstOp->getBlock() != secondOp->getBlock() ||
2580  !firstOp->isBeforeInBlock(secondOp)) {
2581  LDBG("interleavedUses precondition failed, firstOp: "
2582  << *firstOp << ", second op: " << *secondOp << "\n");
2583  return true;
2584  }
2585  for (auto v : values) {
2586  for (auto &u : v.getUses()) {
2587  Operation *owner = u.getOwner();
2588  if (owner == firstOp || owner == secondOp)
2589  continue;
2590  // TODO: this is too conservative, use dominance info in the future.
2591  if (owner->getBlock() == firstOp->getBlock() &&
2592  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2593  continue;
2594  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2595  << ", second op: " << *secondOp << "\n");
2596  return true;
2597  }
2598  }
2599  return false;
2600 }
2601 
2602 /// Return the unique subview use of `v` if it is indeed unique, null
2603 /// otherwise.
2604 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2605  memref::SubViewOp subViewOp;
2606  for (auto &u : v.getUses()) {
2607  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2608  if (subViewOp)
2609  return memref::SubViewOp();
2610  subViewOp = newSubViewOp;
2611  }
2612  }
2613  return subViewOp;
2614 }
2615 
2616 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2617 /// when available.
2618 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2619  vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2620 
2621  // TODO: support mask.
2622  if (xferOp.getMask())
2623  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2624 
2625  // Transfer into `viewOrAlloc`.
2626  Value viewOrAlloc = xferOp.getSource();
2627  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2628  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2629  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2630 
2631  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2632  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2633  if (!subViewOp)
2634  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2635  Value subView = subViewOp.getResult();
2636 
2637  // Find the copy into `subView` without interleaved uses.
2638  memref::CopyOp copyOp;
2639  for (auto &u : subView.getUses()) {
2640  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2641  assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2642  if (newCopyOp.getTarget() != subView)
2643  continue;
2644  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2645  continue;
2646  copyOp = newCopyOp;
2647  break;
2648  }
2649  }
2650  if (!copyOp)
2651  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2652 
2653  // Find the fill into `viewOrAlloc` without interleaved uses before the
2654  // copy.
2655  FillOp maybeFillOp;
2656  for (auto &u : viewOrAlloc.getUses()) {
2657  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2658  assert(isa<MemRefType>(newFillOp.output().getType()));
2659  if (newFillOp.output() != viewOrAlloc)
2660  continue;
2661  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2662  continue;
2663  maybeFillOp = newFillOp;
2664  break;
2665  }
2666  }
2667  // Ensure padding matches.
2668  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2669  return rewriter.notifyMatchFailure(xferOp,
2670  "padding value does not match fill");
2671 
2672  // `in` is the subview that memref.copy reads. Replace it.
2673  Value in = copyOp.getSource();
2674 
2675  // memref.copy + linalg.fill can be used to create a padded local buffer.
2676  // The `masked` attribute is only valid on this padded buffer.
2677  // When forwarding to vector.transfer_read, the attribute must be reset
2678  // conservatively.
2679  Value res = rewriter.create<vector::TransferReadOp>(
2680  xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2681  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2682  // in_bounds is explicitly reset
2683  /*inBoundsAttr=*/ArrayAttr());
2684 
2685  if (maybeFillOp)
2686  rewriter.eraseOp(maybeFillOp);
2687  rewriter.eraseOp(copyOp);
2688  rewriter.replaceOp(xferOp, res);
2689 
2690  return success();
2691 }
2692 
2693 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2694 /// when available.
2695 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2696  vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2697  // TODO: support mask.
2698  if (xferOp.getMask())
2699  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2700 
2701  // Transfer into `viewOrAlloc`.
2702  Value viewOrAlloc = xferOp.getSource();
2703  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2704  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2705  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2706 
2707  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2708  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2709  if (!subViewOp)
2710  return rewriter.notifyMatchFailure(xferOp, "no subview found");
2711  Value subView = subViewOp.getResult();
2712 
2713  // Find the copy from `subView` without interleaved uses.
2714  memref::CopyOp copyOp;
2715  for (auto &u : subViewOp.getResult().getUses()) {
2716  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2717  if (newCopyOp.getSource() != subView)
2718  continue;
2719  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2720  continue;
2721  copyOp = newCopyOp;
2722  break;
2723  }
2724  }
2725  if (!copyOp)
2726  return rewriter.notifyMatchFailure(xferOp, "no copy found");
2727 
2728  // `out` is the subview that the copy writes into; forward the write to it.
2729  assert(isa<MemRefType>(copyOp.getTarget().getType()));
2730  Value out = copyOp.getTarget();
2731 
2732  // Forward vector.transfer into copy.
2733  // memref.copy + linalg.fill can be used to create a padded local buffer.
2734  // The `masked` attribute is only valid on this padded buffer.
2735  // When forwarding to vector.transfer_write, the attribute must be reset
2736  // conservatively.
2737  rewriter.create<vector::TransferWriteOp>(
2738  xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2739  xferOp.getPermutationMapAttr(), xferOp.getMask(),
2740  // in_bounds is explicitly reset
2741  /*inBoundsAttr=*/ArrayAttr());
2742 
2743  rewriter.eraseOp(copyOp);
2744  rewriter.eraseOp(xferOp);
2745 
2746  return success();
2747 }
2748 
2749 //===----------------------------------------------------------------------===//
2750 // Convolution vectorization patterns
2751 //===----------------------------------------------------------------------===//
2752 
2753 template <int N>
2754 static void bindShapeDims(ShapedType shapedType) {}
2755 
2756 template <int N, typename IntTy, typename... IntTy2>
2757 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2758  val = shapedType.getShape()[N];
2759  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2760 }
2761 
2762 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2763 template <typename... IntTy>
2764 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2765  bindShapeDims<0>(shapedType, vals...);
2766 }
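// For example, given a ShapedType with shape {4, 8, 16}:
//   int64_t a, b;
//   bindShapeDims(shapedType, a, b); // binds a = 4 and b = 8 (leading dims).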
2767 
2768 namespace {
2769 bool isCastOfBlockArgument(Operation *op) {
2770  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2771  isa<BlockArgument>(op->getOperand(0));
2772 }
2773 
2774 bool isSupportedPoolKind(vector::CombiningKind kind) {
2775  switch (kind) {
2776  case vector::CombiningKind::ADD:
2777  case vector::CombiningKind::MAXNUMF:
2778  case vector::CombiningKind::MAXIMUMF:
2779  case vector::CombiningKind::MAXSI:
2780  case vector::CombiningKind::MAXUI:
2781  case vector::CombiningKind::MINNUMF:
2782  case vector::CombiningKind::MINIMUMF:
2783  case vector::CombiningKind::MINSI:
2784  case vector::CombiningKind::MINUI:
2785  return true;
2786  default:
2787  return false;
2788  }
2789 }
2790 
2791 /// Generate a vector implementation for either:
2792 /// ```
2793 /// Op def: ( w, kw )
2794 /// Iters: ({Par(), Red()})
2795 /// Layout: {{w + kw}, {kw}, {w}}
2796 /// ```
2797 /// kw is unrolled.
2798 ///
2799 /// or
2800 ///
2801 /// ```
2802 /// Op def: ( n, w, c, kw, f )
2803 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2804 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2805 /// ```
2806 /// kw is unrolled, w is unrolled iff strideW > 1.
2807 ///
2808 /// or
2809 ///
2810 /// ```
2811 /// Op def: ( n, c, w, f, kw )
2812 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2813 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2814 /// ```
2815 /// kw is unrolled, w is unrolled iff strideW > 1.
2816 ///
2817 /// or
2818 ///
2819 /// ```
2820 /// Op def: ( n, w, c, kw )
2821 /// Iters: ({Par(), Par(), Par(), Red()})
2822 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2823 /// ```
2824 /// kw is unrolled, w is unrolled iff strideW > 1.
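/// These layouts correspond to, for example, linalg.conv_1d,
/// linalg.conv_1d_nwc_wcf, linalg.conv_1d_ncw_fcw and
/// linalg.depthwise_conv_1d_nwc_wc, respectively; the pooling entry points
/// below accept analogous layouts with a {kw}-shaped kernel.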
2825 struct Conv1DGenerator
2826  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2827  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2828  int dilationW)
2829  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2830  strideW(strideW), dilationW(dilationW) {
2831  // Determine whether `linalgOp` can be handled by this generator.
2832  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2833  return;
2834  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2835  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2836  resShaped = linalgOp.getDpsInitOperand(0)->get();
2837  lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2838  rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2839  resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2840  if (!lhsShapedType || !rhsShapedType || !resShapedType)
2841  return;
2842  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2843  // (non-channeled convolution -> LHS and RHS both have single dimensions).
2844  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2845  (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2846  return;
2847 
2848  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2849  if (!reduceOp)
2850  return;
2851  redOp = reduceOp->getName().getIdentifier();
2852 
2853  if (!setOperKind(reduceOp))
2854  return;
2855  auto maybeKind = getCombinerOpKind(reduceOp);
2856  if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2857  (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2858  return;
2859  }
2860 
2861  auto rhsRank = rhsShapedType.getRank();
2862  switch (oper) {
2863  case Conv:
2864  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2865  return;
2866  break;
2867  case Pool:
2868  if (rhsRank != 1)
2869  return;
2870  break;
2871  }
2872  // The op is now known to be valid.
2873  valid = true;
2874  }
2875 
2876  /// Generate a vector implementation for:
2877  /// ```
2878  /// Op def: ( w, kw )
2879  /// Iters: ({Par(), Red()})
2880  /// Layout: {{w + kw}, {kw}, {w}}
2881  /// ```
2882  /// kw is always unrolled.
2883  ///
2884  /// or
2885  ///
2886  /// ```
2887  /// Op def: ( n, w, c, kw, f )
2888  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2889  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2890  /// ```
2891  /// kw is always unrolled.
2892  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
2893  /// > 1.
2894  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2895  if (!valid)
2896  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2897 
2898  int64_t nSize, wSize, cSize, kwSize, fSize;
2899  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
2900  bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
2901  switch (conv1DOpOrder) {
2902  case Conv1DOpOrder::W:
2903  // Initialize unused dimensions
2904  nSize = fSize = cSize = 0;
2905  // out{W}
2906  bindShapeDims(resShapedType, wSize);
2907  // kernel{kw}
2908  bindShapeDims(rhsShapedType, kwSize);
2909  lhsShape = {// iw = ow + kw - 1
2910  // (i.e. 16 convolved with 3 -> 14)
2911  (wSize + kwSize - 1)};
2912  rhsShape = {kwSize};
2913  resShape = {wSize};
2914  break;
2915  case Conv1DOpOrder::Nwc:
2916  // out{n, w, f}
2917  bindShapeDims(resShapedType, nSize, wSize, fSize);
2918  switch (oper) {
2919  case Conv:
2920  // kernel{kw, c, f}
2921  bindShapeDims(rhsShapedType, kwSize, cSize);
2922  break;
2923  case Pool:
2924  // kernel{kw}
2925  bindShapeDims(rhsShapedType, kwSize);
2926  cSize = fSize;
2927  break;
2928  }
2929  lhsShape = {nSize,
2930  // iw = ow * sw + kw * dw - 1
2931  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2932  // Perform the proper inclusive -> exclusive -> inclusive.
2933  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2934  1,
2935  cSize};
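 // For example, the width term above with wSize (ow) = 14, strideW = 2,
 // kwSize = 3 and dilationW = 2 gives
 //   ((14 - 1) * 2 + 1) + ((3 - 1) * 2 + 1) - 1 = 27 + 5 - 1 = 31.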
2936  switch (oper) {
2937  case Conv:
2938  rhsShape = {kwSize, cSize, fSize};
2939  break;
2940  case Pool:
2941  rhsShape = {kwSize};
2942  break;
2943  }
2944  resShape = {nSize, wSize, fSize};
2945  break;
2946  case Conv1DOpOrder::Ncw:
2947  // out{n, f, w}
2948  bindShapeDims(resShapedType, nSize, fSize, wSize);
2949  switch (oper) {
2950  case Conv:
2951  // kernel{f, c, kw}
2952  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
2953  break;
2954  case Pool:
2955  // kernel{kw}
2956  bindShapeDims(rhsShapedType, kwSize);
2957  cSize = fSize;
2958  break;
2959  }
2960  lhsShape = {nSize, cSize,
2961  // iw = ow * sw + kw * dw - 1
2962  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2963  // Perform the proper inclusive -> exclusive -> inclusive.
2964  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2965  1};
2966  switch (oper) {
2967  case Conv:
2968  rhsShape = {fSize, cSize, kwSize};
2969  break;
2970  case Pool:
2971  rhsShape = {kwSize};
2972  break;
2973  }
2974  resShape = {nSize, fSize, wSize};
2975  break;
2976  }
2977 
2978  vector::TransferWriteOp write;
2979  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2980 
2981  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2982  // When strideW == 1, we can batch the contiguous loads and avoid
2983  // unrolling
2984  int64_t wSizeStep = strideW == 1 ? wSize : 1;
2985 
2986  Type lhsEltType = lhsShapedType.getElementType();
2987  Type rhsEltType = rhsShapedType.getElementType();
2988  Type resEltType = resShapedType.getElementType();
2989  auto lhsType = VectorType::get(lhsShape, lhsEltType);
2990  auto rhsType = VectorType::get(rhsShape, rhsEltType);
2991  auto resType = VectorType::get(resShape, resEltType);
2992  // Zero padding with the corresponding dimensions for lhs, rhs and res.
2993  SmallVector<Value> lhsPadding(lhsShape.size(), zero);
2994  SmallVector<Value> rhsPadding(rhsShape.size(), zero);
2995  SmallVector<Value> resPadding(resShape.size(), zero);
2996 
2997  // Read the whole lhs, rhs and res in one shot (with zero padding).
2998  Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2999  lhsPadding);
3000  // This is needed only for Conv.
3001  Value rhs = nullptr;
3002  if (oper == Conv)
3003  rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3004  rhsPadding);
3005  Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3006  resPadding);
3007 
3008  // The base vectorization case for channeled convolution is input:
3009  // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3010  // vectorization case, we pre-transpose the input, weight, and output.
3011  switch (conv1DOpOrder) {
3012  case Conv1DOpOrder::W:
3013  case Conv1DOpOrder::Nwc:
3014  // Base case, so no transposes necessary.
3015  break;
3016  case Conv1DOpOrder::Ncw: {
3017  // To match base vectorization case, we pre-transpose current case.
3018  // ncw -> nwc
3019  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3020  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3021  // fcw -> wcf
3022  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3023 
3024  // This is needed only for Conv.
3025  if (oper == Conv)
3026  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3027  // nfw -> nwf
3028  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3029  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3030  break;
3031  }
3032  }
3033 
3034  //===------------------------------------------------------------------===//
3035  // Begin vector-only rewrite part
3036  //===------------------------------------------------------------------===//
3037  // Unroll along kw and read slices of lhs and rhs.
3038  SmallVector<Value> lhsVals, rhsVals, resVals;
3039  lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3040  kwSize, strideW, dilationW, wSizeStep,
3041  isSingleChanneled);
3042  // Do not do for pooling.
3043  if (oper == Conv)
3044  rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3045  resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3046  wSizeStep, isSingleChanneled);
3047 
3048  auto linearIndex = [&](int64_t kw, int64_t w) {
3049  return kw * (wSize / wSizeStep) + w;
3050  };
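 // For example, with wSize = 8: when strideW == 1, wSizeStep == 8, the w loop
 // below only visits w = 0 and linearIndex(kw, 0) == kw; when strideW > 1,
 // wSizeStep == 1 and linearIndex(kw, w) == kw * 8 + w.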
3051 
3052  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3053  // or perform an outer product for non-channeled convolution, or a simple
3054  // arith operation for pooling.
3055  for (int64_t kw = 0; kw < kwSize; ++kw) {
3056  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3057  switch (oper) {
3058  case Conv:
3059  if (isSingleChanneled) {
3060  resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3061  lhsVals[linearIndex(kw, w)],
3062  rhsVals[kw], resVals[w]);
3063  } else {
3064  resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3065  lhsVals[linearIndex(kw, w)],
3066  rhsVals[kw], resVals[w]);
3067  }
3068  break;
3069  case Pool:
3070  resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3071  resVals[w]);
3072  break;
3073  }
3074  }
3075  }
3076 
3077  res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3078  isSingleChanneled);
3079  //===------------------------------------------------------------------===//
3080  // End vector-only rewrite part
3081  //===------------------------------------------------------------------===//
3082 
3083  // The base vectorization case for channeled convolution is output:
3084  // {n,w,f}. To reuse the result from the base pattern vectorization case,
3085  // we post-transpose the base case result.
3086  switch (conv1DOpOrder) {
3087  case Conv1DOpOrder::W:
3088  case Conv1DOpOrder::Nwc:
3089  // Base case, so no transposes necessary.
3090  break;
3091  case Conv1DOpOrder::Ncw: {
3092  // nwf -> nfw
3093  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3094  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3095  break;
3096  }
3097  }
3098 
3099  return rewriter
3100  .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3101  .getOperation();
3102  }
3103 
3104  // Take a value and widen it to have the same element type as `ty`.
3105  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3106  const Type srcElementType = getElementTypeOrSelf(val.getType());
3107  const Type dstElementType = getElementTypeOrSelf(ty);
3108  assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3109  if (srcElementType == dstElementType)
3110  return val;
3111 
3112  const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3113  const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3114  const Type dstType =
3115  cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3116 
3117  if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3118  return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3119  }
3120 
3121  if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3122  srcWidth < dstWidth)
3123  return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3124 
3125  if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3126  srcWidth < dstWidth)
3127  return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3128 
3129  assert(false && "unhandled promotion case");
3130  return nullptr;
3131  }
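 // For example, promoting a vector<4x8xi8> value to match an f32-typed
 // accumulator emits arith.sitofp to vector<4x8xf32>; promoting it to an
 // i32-typed accumulator emits arith.extsi to vector<4x8xi32>.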
3132 
3133  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3134  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3135  Value lhs, Value rhs, Value res) {
3136  vector::IteratorType par = vector::IteratorType::parallel;
3137  vector::IteratorType red = vector::IteratorType::reduction;
3138  AffineExpr n, w, f, c;
3139  bindDims(ctx, n, w, f, c);
3140  lhs = promote(rewriter, loc, lhs, res.getType());
3141  rhs = promote(rewriter, loc, rhs, res.getType());
3142  return rewriter.create<vector::ContractionOp>(
3143  loc, lhs, rhs, res,
3144  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3145  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3146  }
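 // The generated vector.contract uses the indexing maps
 //   (n, w, f, c) -> (n, w, c), (n, w, f, c) -> (c, f), (n, w, f, c) -> (n, w, f)
 // with iterator types [parallel, parallel, parallel, reduction], i.e. the c
 // dimension is reduced while n, w and f stay parallel.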
3147 
3148  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3149  // convolution.
3150  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3151  Value lhs, Value rhs, Value res) {
3152  return rewriter.create<vector::OuterProductOp>(
3153  loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3154  }
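 // With a scalar rhs this is the AXPY form of vector.outerproduct, i.e.
 // res[w] = lhs[w] * rhs + res[w] with an ADD combining kind.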
3155 
3156  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3157  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3158  Value res) {
3159  if (isPoolExt)
3160  lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3161  return rewriter
3162  .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3163  ->getResult(0);
3164  }
3165 
3166  /// Generate a vector implementation for:
3167  /// ```
3168  /// Op def: ( n, w, c, kw)
3169  /// Iters: ({Par(), Par(), Par(), Red()})
3170  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3171  /// ```
3172  /// kw is always unrolled.
3173  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3174  /// > 1.
3175  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3176  bool channelDimScalableFlag,
3177  bool flatten) {
3178  if (!valid)
3179  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3180 
3181  bool scalableChDim = false;
3182  bool useMasking = false;
3183  int64_t nSize, wSize, cSize, kwSize;
3184  // kernel{kw, c}
3185  bindShapeDims(rhsShapedType, kwSize, cSize);
3186  if (ShapedType::isDynamic(cSize)) {
3187  assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3188  cSize = channelDimVecSize;
3189  // Scalable vectors are only used when both conditions are met:
3190  // 1. channel dim is dynamic
3191  // 2. channelDimScalableFlag is set
3192  scalableChDim = channelDimScalableFlag;
3193  useMasking = true;
3194  }
3195 
3196  assert(!(useMasking && flatten) &&
3197  "Unsupported flattened conv with dynamic shapes");
3198 
3199  // out{n, w, c}
3200  bindShapeDims(resShapedType, nSize, wSize);
3201 
3202  vector::TransferWriteOp write;
3203  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3204 
3205  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3206  // When strideW == 1, we can batch the contiguous loads and avoid
3207  // unrolling
3208  int64_t wSizeStep = strideW == 1 ? wSize : 1;
3209 
3210  Type lhsEltType = lhsShapedType.getElementType();
3211  Type rhsEltType = rhsShapedType.getElementType();
3212  Type resEltType = resShapedType.getElementType();
3213  VectorType lhsType = VectorType::get(
3214  {nSize,
3215  // iw = ow * sw + kw * dw - 1
3216  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3217  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3218  cSize},
3219  lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3220  VectorType rhsType =
3221  VectorType::get({kwSize, cSize}, rhsEltType,
3222  /*scalableDims=*/{false, scalableChDim});
3223  VectorType resType =
3224  VectorType::get({nSize, wSize, cSize}, resEltType,
3225  /*scalableDims=*/{false, false, scalableChDim});
3226 
3227  // Masks the input xfer Op along the channel dim, iff the channel dim is
3228  // dynamic and masking is therefore required.
3229  auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3230  ArrayRef<bool> scalableDims,
3231  Operation *opToMask) {
3232  if (!useMasking)
3233  return opToMask;
3234  auto maskType =
3235  VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3236 
3237  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3238  cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3239 
3240  Value maskOp =
3241  rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3242 
3243  return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3244  };
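 // For example, with nSize = 1, an input width of 10 and cSize = 4 with a
 // scalable channel dim, the lhs read below is wrapped as
 //   %mask = vector.create_mask %d0, %d1, %d2 : vector<1x10x[4]xi1>
 //   vector.mask %mask { vector.transfer_read ... }
 // where %d0, %d1 and %d2 (illustrative names) are the source's mixed sizes.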
3245 
3246  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3247  // 0].
3248  Value lhs = rewriter.create<vector::TransferReadOp>(
3249  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3250  auto maybeMaskedLhs = maybeMaskXferOp(
3251  lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3252 
3253  // Read rhs slice of size {kw, c} @ [0, 0].
3254  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3255  ValueRange{zero, zero});
3256  auto maybeMaskedRhs = maybeMaskXferOp(
3257  rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3258 
3259  // Read res slice of size {n, w, c} @ [0, 0, 0].
3260  Value res = rewriter.create<vector::TransferReadOp>(
3261  loc, resType, resShaped, ValueRange{zero, zero, zero});
3262  auto maybeMaskedRes = maybeMaskXferOp(
3263  resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3264 
3265  //===------------------------------------------------------------------===//
3266  // Begin vector-only rewrite part
3267  //===------------------------------------------------------------------===//
3268  // Unroll along kw and read slices of lhs and rhs.
3269  SmallVector<Value> lhsVals, rhsVals, resVals;
3270  auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3271  auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3272 
3273  // Extract lhs slice of size {n, wSizeStep, c}
3274  // @ [0, sw * w + dw * kw, 0].
3275  for (int64_t kw = 0; kw < kwSize; ++kw) {
3276  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3277  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3278  loc, maybeMaskedLhs->getResult(0),
3279  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3280  inOutSliceSizes, inOutStrides));
3281  }
3282  }
3283  // Extract rhs slice of size {c} @ [kw].
3284  for (int64_t kw = 0; kw < kwSize; ++kw) {
3285  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3286  loc, maybeMaskedRhs->getResult(0),
3287  /*offsets=*/ArrayRef<int64_t>{kw}));
3288  }
3289  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3290  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3291  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3292  loc, maybeMaskedRes->getResult(0),
3293  /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3294  inOutStrides));
3295  }
3296 
3297  auto linearIndex = [&](int64_t kw, int64_t w) {
3298  return kw * (wSize / wSizeStep) + w;
3299  };
3300 
3301  // Note - the scalable flags are ignored as flattening combined with
3302  // scalable vectorization is not supported.
3303  auto inOutFlattenSliceSizes =
3304  SmallVector<int64_t>{nSize, wSizeStep * cSize};
3305  auto lhsTypeAfterFlattening =
3306  VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3307  auto resTypeAfterFlattening =
3308  VectorType::get(inOutFlattenSliceSizes, resEltType);
3309 
3310  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3311  for (int64_t kw = 0; kw < kwSize; ++kw) {
3312  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3313  Value lhsVal = lhsVals[linearIndex(kw, w)];
3314  Value resVal = resVals[w];
3315  if (flatten) {
3316  // Flatten the input and output vectors (collapse the channel
3317  // dimension)
3318  lhsVal = rewriter.create<vector::ShapeCastOp>(
3319  loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3320  resVal = rewriter.create<vector::ShapeCastOp>(
3321  loc, resTypeAfterFlattening, resVals[w]);
3322  }
3323  resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3324  rhsVals[kw], resVal, flatten);
3325  if (flatten) {
3326  // Un-flatten the output vector (restore the channel dimension)
3327  resVals[w] = rewriter.create<vector::ShapeCastOp>(
3328  loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3329  }
3330  }
3331  }
3332 
3333  // It's possible we failed to create the FMA.
3334  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3335  // Manually revert (in reverse order) to avoid leaving a bad IR state.
3336  for (auto &collection :
3337  {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3338  for (Value v : collection)
3339  rewriter.eraseOp(v.getDefiningOp());
3340  return rewriter.notifyMatchFailure(op, "failed to create FMA");
3341  }
3342 
3343  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3344  // This does not depend on kw.
3345  for (int64_t w = 0; w < wSize; w += wSizeStep) {
3346  maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3347  loc, resVals[w], maybeMaskedRes->getResult(0),
3348  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3349  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3350  }
3351  //===------------------------------------------------------------------===//
3352  // End vector-only rewrite part
3353  //===------------------------------------------------------------------===//
3354 
3355  // Write back res slice of size {n, w, c} @ [0, 0, 0].
3356  Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3357  loc, maybeMaskedRes->getResult(0), resShaped,
3358  ValueRange{zero, zero, zero});
3359  return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3360  resOut);
3361  }
3362 
3363  /// Lower:
3364  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3365  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3366  /// to MulAcc.
3367  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3368  Value lhs, Value rhs, Value res,
3369  bool flatten) {
3370  auto rhsTy = cast<ShapedType>(rhs.getType());
3371  auto resTy = cast<ShapedType>(res.getType());
3372 
3373  // TODO(suderman): Change this to use a vector.ima intrinsic.
3374  lhs = promote(rewriter, loc, lhs, resTy);
3375 
3376  if (flatten) {
3377  // NOTE: This following logic won't work for scalable vectors. For this
3378  // reason, "flattening" is not supported when shapes are dynamic (this
3379  // should be captured by one of the pre-conditions).
3380 
3381  // There are two options for handling the filter:
3382  // * shape_cast(broadcast(filter))
3383  // * broadcast(shuffle(filter))
3384  // Opt for the option without shape_cast to simplify the codegen.
3385  auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3386  auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3387 
3388  SmallVector<int64_t, 16> indices;
3389  for (int i = 0; i < resSize / rhsSize; ++i) {
3390  for (int j = 0; j < rhsSize; ++j)
3391  indices.push_back(j);
3392  }
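 // For example, with rhsSize = 4 and resSize = 8 this builds
 // indices = [0, 1, 2, 3, 0, 1, 2, 3], i.e. the filter is repeated to
 // cover the flattened w * c dimension before broadcasting.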
3393 
3394  rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3395  }
3396  // Broadcast the filter to match the output vector
3397  rhs = rewriter.create<vector::BroadcastOp>(
3398  loc, resTy.clone(rhsTy.getElementType()), rhs);
3399 
3400  rhs = promote(rewriter, loc, rhs, resTy);
3401 
3402  if (!lhs || !rhs)
3403  return nullptr;
3404 
3405  if (isa<FloatType>(resTy.getElementType()))
3406  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3407 
3408  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3409  return rewriter.create<arith::AddIOp>(loc, mul, res);
3410  }
3411 
3412  /// Entry point for non-channeled convolution:
3413  /// {{w + kw}, {kw}, {w}}
3414  FailureOr<Operation *> generateNonChanneledConv() {
3415  AffineExpr w, kw;
3416  bindDims(ctx, w, kw);
3417  if (!iters({Par(), Red()}))
3418  return rewriter.notifyMatchFailure(op,
3419  "failed to match conv::W 1-par 1-red");
3420 
3421  // No transposition needed.
3422  if (layout({/*lhsIndex*/ {w + kw},
3423  /*rhsIndex*/ {kw},
3424  /*resIndex*/ {w}}))
3425  return conv(Conv1DOpOrder::W);
3426 
3427  return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3428  }
3429 
3430  /// Entry point that transposes into the common form:
3431  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3432  FailureOr<Operation *> generateNwcConv() {
3433  AffineExpr n, w, f, kw, c;
3434  bindDims(ctx, n, w, f, kw, c);
3435  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3436  return rewriter.notifyMatchFailure(
3437  op, "failed to match conv::Nwc 3-par 2-red");
3438 
3439  // No transposition needed.
3440  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3441  /*rhsIndex*/ {kw, c, f},
3442  /*resIndex*/ {n, w, f}}))
3443  return conv(Conv1DOpOrder::Nwc);
3444 
3445  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3446  }
3447 
3448  /// Entry point that transposes into the common form:
3449  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3450  FailureOr<Operation *> generateNcwConv() {
3451  AffineExpr n, w, f, kw, c;
3452  bindDims(ctx, n, f, w, c, kw);
3453  if (!iters({Par(), Par(), Par(), Red(), Red()}))
3454  return rewriter.notifyMatchFailure(
3455  op, "failed to match conv::Ncw 3-par 2-red");
3456 
3457  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3458  /*rhsIndex*/ {f, c, kw},
3459  /*resIndex*/ {n, f, w}}))
3460  return conv(Conv1DOpOrder::Ncw);
3461 
3462  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3463  }
3464 
3465  /// Entry point that transposes into the common form:
3466  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3467  FailureOr<Operation *> generateNwcPooling() {
3468  AffineExpr n, w, c, kw;
3469  bindDims(ctx, n, w, c, kw);
3470  if (!iters({Par(), Par(), Par(), Red()}))
3471  return rewriter.notifyMatchFailure(op,
3472  "failed to match pooling 3-par 1-red");
3473 
3474  // No transposition needed.
3475  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3476  /*rhsIndex*/ {kw},
3477  /*resIndex*/ {n, w, c}}))
3478  return conv(Conv1DOpOrder::Nwc);
3479 
3480  return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3481  }
3482 
3483  /// Entry point that transposes into the common form:
3484  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3485  FailureOr<Operation *> generateNcwPooling() {
3486  AffineExpr n, w, c, kw;
3487  bindDims(ctx, n, c, w, kw);
3488  if (!iters({Par(), Par(), Par(), Red()}))
3489  return rewriter.notifyMatchFailure(op,
3490  "failed to match pooling 3-par 1-red");
3491 
3492  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3493  /*rhsIndex*/ {kw},
3494  /*resIndex*/ {n, c, w}}))
3495  return conv(Conv1DOpOrder::Ncw);
3496 
3497  return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3498  }
3499 
3500  /// Entry point that transposes into the common form:
3501  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3502  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3503  bool vecChDimScalableFlag = false,
3504  bool flatten = false) {
3505  AffineExpr n, w, c, kw;
3506  bindDims(ctx, n, w, c, kw);
3507  if (!iters({Par(), Par(), Par(), Red()}))
3508  return rewriter.notifyMatchFailure(
3509  op, "failed to match depthwise::Nwc conv 3-par 1-red");
3510 
3511  // No transposition needed.
3512  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3513  /*rhsIndex*/ {kw, c},
3514  /*resIndex*/ {n, w, c}}))
3515  return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3516 
3517  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3518  }
3519 
3520 private:
3521  enum OperKind { Conv, Pool };
3522  bool valid = false;
3523  OperKind oper = Conv;
3524  StringAttr redOp;
3525  StringAttr poolExtOp;
3526  bool isPoolExt = false;
3527  int strideW, dilationW;
3528  Value lhsShaped, rhsShaped, resShaped;
3529  ShapedType lhsShapedType, rhsShapedType, resShapedType;
3530 
3531  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3532  // Returns true iff it is a valid conv/pooling op.
3533  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3534  // reduction + yield) and rhs is not used, then it is the body of a pooling.
3535  // If it is a conv, check for a single `mul` predecessor. The `mul` operands
3536  // must be block arguments or extensions of block arguments.
3537  // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
3538  // must be block arguments or extensions of block arguments.
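 // Schematically (types and value names are illustrative), a max-pooling body
 // looks like:
 //   %0 = arith.maximumf %in, %out : f32
 //   linalg.yield %0 : f32
 // (both reduction operands are block arguments), whereas a convolution body
 // looks like:
 //   %0 = arith.mulf %in, %filter : f32
 //   %1 = arith.addf %out, %0 : f32
 //   linalg.yield %1 : f32
 // (a single block argument feeds the reduction, through the `mul`).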
3539  bool setOperKind(Operation *reduceOp) {
3540  int numBlockArguments =
3541  llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3542  switch (numBlockArguments) {
3543  case 1: {
3544  // This is a convolution if the feeder is a MulOp;
3545  // otherwise, it may be a pooling.
3546  auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3547  llvm::IsaPred<BlockArgument>);
3548  Operation *feedOp = (*feedValIt).getDefiningOp();
3549  if (isCastOfBlockArgument(feedOp)) {
3550  oper = Pool;
3551  isPoolExt = true;
3552  poolExtOp = feedOp->getName().getIdentifier();
3553  } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3554  llvm::all_of(feedOp->getOperands(), [](Value v) {
3555  if (isa<BlockArgument>(v))
3556  return true;
3557  if (Operation *op = v.getDefiningOp())
3558  return isCastOfBlockArgument(op);
3559  return false;
3560  }))) {
3561  return false;
3562  }
3563  return true;
3564  }
3565  case 2:
3566  // Must be pooling
3567  oper = Pool;
3568  isPoolExt = false;
3569  return true;
3570  default:
3571  return false;
3572  }
3573  }
3574 };
3575 } // namespace
3576 
3577 /// Helper function to vectorize a LinalgOp with convolution semantics.
3578 // TODO: extend the generic vectorization to support windows and drop this.
3579 static FailureOr<Operation *> vectorizeConvolution(
3580  RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3581  ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3582  // The ConvolutionOpInterface gives us guarantees of existence for
3583  // strides/dilations. However, we do not need to rely on those, we can
3584  // simply use them if present, otherwise use the default and let the generic
3585  // conv. matcher in the ConvGenerator succeed or fail.
3586  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3587  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3588  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3589  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3590  Conv1DGenerator e(rewriter, op, stride, dilation);
3591  auto res = e.generateNonChanneledConv();
3592  if (succeeded(res))
3593  return res;
3594  res = e.generateNwcConv();
3595  if (succeeded(res))
3596  return res;
3597  res = e.generateNcwConv();
3598  if (succeeded(res))
3599  return res;
3600  res = e.generateNwcPooling();
3601  if (succeeded(res))
3602  return res;
3603  res = e.generateNcwPooling();
3604  if (succeeded(res))
3605  return res;
3606 
3607  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3608  // and scalable vectors. Note that at the moment the only dim that can be
3609  // dynamic (i.e. masked/scalable) is the channel dim (i.e. the trailing dim).
3610  uint64_t vecChDimSize = ShapedType::kDynamic;
3611  bool vecChDimScalableFlag = false;
3612  if (!inputVecSizes.empty()) {
3613  // Only use the input vector size corresponding to the channel dim. Other
3614  // vector dims will be inferred from the Ops.
3615  assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3616  isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3617  "Not a 1D depthwise conv!");
3618  size_t chDimIdx =
3619  TypeSwitch<Operation *, size_t>(op)
3620  .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3621  .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3622 
3623  vecChDimSize = inputVecSizes[chDimIdx];
3624  vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3625  }
3626  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3627  flatten1DDepthwiseConv);
3628 }
3629 
3630 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3631  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3632 
3633  LogicalResult matchAndRewrite(LinalgOp op,
3634  PatternRewriter &rewriter) const override {
3635  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3636  if (failed(resultOrFail))
3637  return failure();
3638  Operation *newOp = *resultOrFail;
3639  if (newOp->getNumResults() == 0) {
3640  rewriter.eraseOp(op.getOperation());
3641  return success();
3642  }
3643  assert(newOp->getNumResults() == 1 && "expected single result");
3644  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3645  return success();
3646  }
3647 };
3648 
3649 void mlir::linalg::populateConvolutionVectorizationPatterns(
3650  RewritePatternSet &patterns, PatternBenefit benefit) {
3651  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3652 }
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MINUI(lhs, rhs)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether /p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
#define LDBG(X)
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:237
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:136
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:292
MLIRContext * getContext() const
Definition: AffineMap.cpp:331
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:322
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:583
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:395
unsigned getNumInputs() const
Definition: AffineMap.cpp:391
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
Definition: AffineMap.cpp:143
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:252
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:544
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:613
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:290
OpListType & getOperations()
Definition: Block.h:135
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
IntegerType getI32Type()
Definition: Builders.cpp:83
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:277
DenseIntElementsAttr getIndexVectorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:150
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Definition: Utils.cpp:215
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition: Utils.cpp:149
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition: File.h:43
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2267
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:650
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
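A short sketch (not from this file) of these two OpFoldResult helpers; the predicate names are illustrative:

  #include "mlir/Dialect/Utils/StaticValueUtils.h"
  #include <optional>

  // True if `ofr` folds to the constant 1 (e.g. a unit stride or unit size).
  static bool isUnit(mlir::OpFoldResult ofr) {
    std::optional<int64_t> cst = mlir::getConstantIntValue(ofr);
    return cst && *cst == 1;
  }

  // True if the two sizes are the same constant or the same SSA value.
  static bool sameSize(mlir::OpFoldResult a, mlir::OpFoldResult b) {
    return mlir::isEqualConstantIntOrValue(a, b);
  }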
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
Definition: AffineMap.cpp:781
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. N).
Definition: AffineExpr.h:349
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:757
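A small worked sketch (not from this file) combining bindDims, applyPermutationMap and inversePermutation on the permutation map (d0, d1, d2) -> (d2, d0, d1):

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"

  static void permutationMapExample(mlir::MLIRContext *ctx) {
    mlir::AffineExpr d0, d1, d2;
    mlir::bindDims(ctx, d0, d1, d2);
    // (d0, d1, d2) -> (d2, d0, d1).
    mlir::AffineMap perm = mlir::AffineMap::get(/*dimCount=*/3,
                                                /*symbolCount=*/0,
                                                {d2, d0, d1}, ctx);
    // Each result dimension selects the corresponding source entry:
    // {2, 4, 8} becomes {8, 2, 4}.
    llvm::SmallVector<int64_t> shape = {2, 4, 8};
    llvm::SmallVector<int64_t> permuted =
        mlir::applyPermutationMap(perm, llvm::ArrayRef<int64_t>(shape));
    // The inverse map, (d0, d1, d2) -> (d1, d2, d0), undoes the permutation.
    mlir::AffineMap inverse = mlir::inversePermutation(perm);
    (void)permuted;
    (void)inverse;
  }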
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
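A tiny sketch (not from this file) of the LogicalResult idiom these helpers support; the checkRank helper is illustrative:

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Support/LogicalResult.h"

  // Wrap a boolean condition into a LogicalResult.
  static mlir::LogicalResult checkRank(mlir::ShapedType type, int64_t rank) {
    return mlir::success(type.hasRank() && type.getRank() == rank);
  }

  // Callers branch with succeeded() rather than on raw booleans.
  static bool isMatrixLike(mlir::ShapedType type) {
    return mlir::succeeded(checkRank(type, /*rank=*/2));
  }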
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:62
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:687
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:627
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
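A worked sketch (not from this file) of the permutation-vector helpers; the values are chosen to show that applying the inverse restores the original order:

  #include "mlir/Dialect/Utils/IndexingUtils.h"

  static void permutationVectorExample() {
    llvm::SmallVector<int64_t> shape = {2, 4, 8};
    llvm::SmallVector<int64_t> perm = {2, 0, 1};
    // result[i] = shape[perm[i]], i.e. {8, 2, 4}.
    llvm::SmallVector<int64_t> permuted =
        mlir::applyPermutation(llvm::ArrayRef<int64_t>(shape), perm);
    // The inverse of {2, 0, 1} is {1, 2, 0}; applying it in place restores
    // the original {2, 4, 8} ordering.
    llvm::SmallVector<int64_t> inv = mlir::invertPermutationVector(perm);
    mlir::applyPermutationToVector(permuted, inv);
  }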
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
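A minimal OpRewritePattern sketch (not from this file), anchored on tensor::PadOp; the pattern name and the exact fold condition are assumptions made for illustration:

  #include "mlir/Dialect/Tensor/IR/Tensor.h"
  #include "mlir/IR/PatternMatch.h"

  // Fold a pad whose low/high padding is all zero and whose result type
  // matches the source type (illustrative pattern, not part of this file).
  struct FoldNoopPadPattern
      : public mlir::OpRewritePattern<mlir::tensor::PadOp> {
    using mlir::OpRewritePattern<mlir::tensor::PadOp>::OpRewritePattern;

    mlir::LogicalResult
    matchAndRewrite(mlir::tensor::PadOp padOp,
                    mlir::PatternRewriter &rewriter) const override {
      if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad() ||
          padOp.getResultType() != padOp.getSourceType())
        return mlir::failure();
      rewriter.replaceOp(padOp, padOp.getSource());
      return mlir::success();
    }
  };

Such a pattern would be registered the same way as the populate* entries above, e.g. patterns.add<FoldNoopPadPattern>(ctx) before handing the set to a rewrite driver.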
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
Definition: Transforms.h:1412
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.