Vectorization.cpp
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Support/LLVM.h"
32 #include "llvm/ADT/ScopeExit.h"
33 #include "llvm/ADT/Sequence.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <type_traits>
39 
40 using namespace mlir;
41 using namespace mlir::linalg;
42 
43 #define DEBUG_TYPE "linalg-vectorization"
44 
45 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
46 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
47 
48 /// Try to vectorize `convOp` as a convolution.
49 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
50                                                    LinalgOp convOp);
51 
52 /// Return the unique instance of OpType in `block` if it is indeed unique.
53 /// Return null if none or more than 1 instances exist.
54 template <typename OpType>
55 static OpType getSingleOpOfType(Block &block) {
56  OpType res;
57  block.walk([&](OpType op) {
58  if (res) {
59  res = nullptr;
60  return WalkResult::interrupt();
61  }
62  res = op;
63  return WalkResult::advance();
64  });
65  return res;
66 }
67 
68 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
69 /// projectedPermutation, compress the unused dimensions to serve as a
70 /// permutation_map for a vector transfer operation.
71 /// For example, given a linalg op such as:
72 ///
73 /// ```
74 /// %0 = linalg.generic {
75 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
76 /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
77 /// }
78 /// ins(%0 : tensor<2x3x4xf32>)
79 /// outs(%1 : tensor<5x6xf32>)
80 /// ```
81 ///
82 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
83 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
84 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
85 static AffineMap reindexIndexingMap(AffineMap map) {
86   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
87  "expected projected permutation");
88  auto res = compressUnusedDims(map);
89  assert(res.getNumDims() == res.getNumResults() &&
90  "expected reindexed map with same number of dims and results");
91  return res;
92 }
93 
94 /// Helper enum to represent conv1d input traversal order.
95 enum class Conv1DOpOrder {
96  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
97  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
98 };
99 
100 /// Helper data structure to represent the result of vectorization.
101 /// In certain specific cases, like terminators, we do not want to propagate.
102 enum VectorizationStatus {
103   /// Op failed to vectorize.
104   Failure = 0,
105   /// Op vectorized and custom function took care of replacement logic
106   NoReplace,
107   /// Op vectorized into a new Op whose results will replace original Op's
108   /// results.
109   NewOp
110   // TODO: support values if Op vectorized to Many-Ops whose results we need to
111   // aggregate for replacement.
112 };
113 struct VectorizationResult {
114   /// Return status from vectorizing the current op.
115   enum VectorizationStatus status = VectorizationStatus::Failure;
116   /// New vectorized operation to replace the current op.
117   /// Replacement behavior is specified by `status`.
118   Operation *newOp = nullptr;
119 };
120 
121 static std::optional<vector::CombiningKind>
122 getCombinerOpKind(Operation *combinerOp) {
123   using ::mlir::vector::CombiningKind;
124 
125   if (!combinerOp)
126     return std::nullopt;
127   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
128       combinerOp)
129  .Case<arith::AddIOp, arith::AddFOp>(
130  [&](auto op) { return CombiningKind::ADD; })
131  .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
132  .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
133  .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
134  .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
135  .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
136  .Case<arith::MulIOp, arith::MulFOp>(
137  [&](auto op) { return CombiningKind::MUL; })
138  .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
139  .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
140  .Default([&](auto op) { return std::nullopt; });
141 }
142 
143 /// Check whether `outputOperand` is a reduction with a single combiner
144 /// operation. Return the combiner operation of the reduction. Return
145 /// nullptr otherwise. Multiple reduction operations would impose an
146 /// ordering between reduction dimensions and is currently unsupported in
147 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
148 /// max(min(X))
149 // TODO: use in LinalgOp verification, there is a circular dependency atm.
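// For example (illustrative): for a sum reduction whose region terminates in
//   %0 = arith.addf %in, %out : f32
//   linalg.yield %0 : f32
// this returns the `arith.addf` op as the single combiner for that output.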
150 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
151  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
152  unsigned outputPos =
153  outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
154  // Only single combiner operations are supported for now.
155  SmallVector<Operation *, 4> combinerOps;
156  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
157  combinerOps.size() != 1)
158  return nullptr;
159 
160  // Return the combiner operation.
161  return combinerOps[0];
162 }
163 
164 /// Broadcast `value` to a vector of `shape` if possible. Return value
165 /// otherwise.
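/// For example (illustrative): broadcasting a scalar `f32` value to shape
/// {4, 8} produces `vector.broadcast %v : f32 to vector<4x8xf32>`.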
166 static Value broadcastIfNeeded(OpBuilder &b, Value value,
167                                ArrayRef<int64_t> shape) {
168  // If no shape to broadcast to, just return `value`.
169  if (shape.empty())
170  return value;
171  VectorType targetVectorType =
172  VectorType::get(shape, getElementTypeOrSelf(value));
173  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
174       vector::BroadcastableToResult::Success)
175     return value;
176  Location loc = b.getInsertionPoint()->getLoc();
177  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
178  value);
179 }
180 
181 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
182 /// assumes that `reductionOp` has two operands and one of them is the reduction
183 /// initial value.
184 // Note: this is a true builder that notifies the OpBuilder listener.
185 // TODO: Consider moving as a static helper on the ReduceOp.
186 static Operation *buildMultiDimReduce(OpBuilder &b,
187                                       Operation *reduceOp, Value valueToReduce,
188  Value acc,
189  const SmallVector<bool> &reductionMask) {
190  auto maybeKind = getCombinerOpKind(reduceOp);
191  assert(maybeKind && "Failed precondition: could not get reduction kind");
192  return b.create<vector::MultiDimReductionOp>(
193  reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
194 }
195 
196 static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
197  return llvm::to_vector(
198  llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
199 }
200 
201 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
202 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
203 /// currently being vectorized. If `dest` has null rank, build a memref.store.
204 /// Return the produced value or null if no value is produced.
205 // Note: this is a true builder that notifies the OpBuilder listener.
206 // TODO: Consider moving as a static helper on the ReduceOp.
207 static Value buildVectorWrite(OpBuilder &b, Value value,
208                               OpOperand *outputOperand) {
209  Operation *write;
210  Location loc = value.getLoc();
211  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
212  ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
213  auto vectorType = VectorType::get(
214  shape, getElementTypeOrSelf(outputOperand->get().getType()));
215  if (vectorType.getRank() > 0) {
216  // 0-d case is still special: do not invert the reindexing map.
217  AffineMap map =
218  reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand));
219  SmallVector<int64_t> transposeShape =
220  applyPermutationMap(inversePermutation(map), vectorType.getShape());
221  assert(!transposeShape.empty() && "unexpected empty transpose shape");
222  vectorType = VectorType::get(transposeShape, vectorType.getElementType());
223  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
224  b.create<arith::ConstantIndexOp>(loc, 0));
225  value = broadcastIfNeeded(b, value, vectorType.getShape());
226  write = b.create<vector::TransferWriteOp>(
227  loc, value, outputOperand->get(), indices, map);
228  } else {
229  if (!value.getType().isa<VectorType>())
230  value = b.create<vector::BroadcastOp>(loc, vectorType, value);
231  assert(value.getType() == vectorType && "incorrect type");
232  write = b.create<vector::TransferWriteOp>(
233  loc, value, outputOperand->get(), ValueRange{});
234  }
235  LDBG("vectorized op: " << *write);
236  if (!write->getResults().empty())
237  return write->getResult(0);
238  return Value();
239 }
240 
241 // Custom vectorization precondition function type. This is intended to be used
242 // with CustomVectorizationHook. Returns success if the corresponding custom
243 // hook can vectorize the op.
244 using CustomVectorizationPrecondition =
245     std::function<LogicalResult(Operation *)>;
246 
247 // Custom vectorization function type. Produce a vector form of Operation*
248 // assuming all its vectorized operands are already in the BlockAndValueMapping.
249 // Return nullptr if the Operation cannot be vectorized.
250 using CustomVectorizationHook = std::function<VectorizationResult(
251     Operation *, const BlockAndValueMapping &)>;
252 
253 /// Helper function to vectorize the terminator of a `linalgOp`. New result
254 /// vector values are appended to `newResults`. Return
255 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
256 /// should not try to map produced operations and instead return the results
257 /// using the `newResults` vector making them available to the vectorization
258 /// algorithm for RAUW. This function is meant to be used as a
259 /// CustomVectorizationHook.
260 static VectorizationResult
261 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
262                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
263  SmallVectorImpl<Value> &newResults) {
264  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
265  if (!yieldOp)
266     return VectorizationResult{VectorizationStatus::Failure, nullptr};
267   for (const auto &outputs : llvm::enumerate(yieldOp.getValues())) {
268  // TODO: Scan for an opportunity for reuse.
269  // TODO: use a map.
270  Value vectorValue = bvm.lookup(outputs.value());
271  Value newResult = buildVectorWrite(
272  rewriter, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
273  if (newResult)
274  newResults.push_back(newResult);
275  }
276   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
277 }
278 
279 /// Helper function to vectorize the index operations of a `linalgOp`. Return
280 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
281 /// should map the produced operations. This function is meant to be used as a
282 /// CustomVectorizationHook.
283 static VectorizationResult
284 vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
285  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
286  if (!indexOp)
287     return VectorizationResult{VectorizationStatus::Failure, nullptr};
288   auto loc = indexOp.getLoc();
289  // Compute the static loop sizes of the index op.
290  auto targetShape = linalgOp.computeStaticLoopSizes();
291  // Compute a one-dimensional index vector for the index op dimension.
292  SmallVector<int64_t> constantSeq =
293  llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
294  auto constantOp = rewriter.create<arith::ConstantOp>(
295  loc, rewriter.getIndexVectorAttr(constantSeq));
296  // Return the one-dimensional index vector if it lives in the trailing
297  // dimension of the iteration space since the vectorization algorithm in this
298  // case can handle the broadcast.
299  if (indexOp.getDim() == targetShape.size() - 1)
300     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
301   // Otherwise permute the targetShape to move the index dimension last,
302  // broadcast the one-dimensional index vector to the permuted shape, and
303  // finally transpose the broadcasted index vector to undo the permutation.
304  std::swap(targetShape[indexOp.getDim()], targetShape.back());
305  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
306  loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp);
307  SmallVector<int64_t> transposition =
308  llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
309  std::swap(transposition.back(), transposition[indexOp.getDim()]);
310  auto transposeOp =
311  rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
312  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
313 }
314 
315 /// Helper function to check if the tensor.extract can be vectorized by the
316 /// custom hook vectorizeTensorExtract.
317 static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
318   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
319  if (!extractOp)
320  return failure();
321 
322  // Currently only supports extraction with an 1-D index.
323  if (extractOp.getIndices().size() != 1)
324  return failure();
325 
326  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
327  return failure();
328 
329  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
330  return !VectorType::isValidElementType(type);
331  })) {
332  return failure();
333  }
334 
335  return success();
336 }
337 
338 /// Helper function to vectorize the tensor.extract operations. Returns
339 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
340 /// should map the produced operations. This function is meant to be used as a
341 /// CustomVectorizationHook.
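/// For example (illustrative sketch; the mask and pass-through constants are
/// built by this hook, and `%indexVec` is the already-vectorized 1-D index):
///
/// ```
///   %mask = arith.constant dense<true> : vector<4x8xi1>
///   %pass = arith.constant dense<0.000000e+00> : vector<4x8xf32>
///   %0 = vector.gather %src[%c0] [%indexVec], %mask, %pass
///       : tensor<16xf32>, vector<4x8xindex>, vector<4x8xi1>, vector<4x8xf32>
///         into vector<4x8xf32>
/// ```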
342 static VectorizationResult
343 vectorizeTensorExtract(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp,
344  const BlockAndValueMapping &bvm) {
345  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
346  if (!extractOp)
347     return VectorizationResult{VectorizationStatus::Failure, nullptr};
348   auto loc = extractOp.getLoc();
349 
350  // Currently only supports extraction with an 1-D index. Checked in the
351  // tensorExtractVectorizationPrecondition.
352  assert(extractOp.getIndices().size() == 1);
353 
354  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
355  // Compute the static loop sizes of the extract op.
356  auto targetShape = linalgOp.computeStaticLoopSizes();
357 
358  SmallVector<Value> gatherIndices;
359  gatherIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
360 
361  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
362       loc, DenseIntElementsAttr::get(
363                VectorType::get(targetShape, rewriter.getI1Type()),
364  /*value=*/true));
365 
366  auto resultType =
367  VectorType::get(targetShape, extractOp.getResult().getType());
368  auto passThruConstantOp =
369  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
370 
371  auto gatherOp = rewriter.create<vector::GatherOp>(
372  loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
373  maskConstantOp, passThruConstantOp);
374 
375   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
376 }
377 
378 /// Emit reduction operations if the shape of the value to reduce is different
379 /// from the result shape.
380 // Note: this is a true builder that notifies the OpBuilder listener.
381 // TODO: Consider moving as a static helper on the ReduceOp.
382 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
383  Operation *op, Value reduceValue,
384  Value initialValue,
385  const BlockAndValueMapping &bvm) {
386  Value reduceVec = bvm.lookup(reduceValue);
387  Value outputVec = bvm.lookup(initialValue);
388  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
389  auto outputType = outputVec.getType().dyn_cast<VectorType>();
390   // Reduce only if needed as the value may already have been reduced for
391  // contraction vectorization.
392  if (!reduceType ||
393  (outputType && reduceType.getShape() == outputType.getShape()))
394  return nullptr;
395  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
396  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
397 }
398 
399 /// Generic vectorization for a single operation `op`, given already vectorized
400 /// operands carried by `bvm`. Vectorization occurs as follows:
401 /// 1. Try to apply any of the `customVectorizationHooks` and return its
402 /// result on success.
403 /// 2. Clone any constant in the current scope without vectorization: each
404 /// consumer of the constant will later determine the shape to which the
405 /// constant needs to be broadcast to.
406 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
407 /// of the `customVectorizationHooks` to cover such cases.
408 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
409 /// operand of maximal rank. Other operands have smaller rank and are
410 /// broadcast accordingly. It is assumed this broadcast is always legal,
411 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
412 ///
413 /// This function assumes all operands of `op` have been vectorized and are in
414 /// the `bvm` mapping. As a consequence, this function is meant to be called on
415 /// a topologically-sorted list of ops.
416 /// This function does not update `bvm` but returns a VectorizationStatus that
417 /// instructs the caller what `bvm` update needs to occur.
418 static VectorizationResult
419 vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
420  const BlockAndValueMapping &bvm,
421  ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
422  LDBG("vectorize op " << *op);
423 
424  // 1. Try to apply any CustomVectorizationHook.
425  if (!customVectorizationHooks.empty()) {
426  for (auto &customFunc : customVectorizationHooks) {
427  VectorizationResult result = customFunc(op, bvm);
428  if (result.status == VectorizationStatus::Failure)
429  continue;
430  return result;
431  }
432  }
433 
434  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
435   // Clone so that the constant is not confined to the linalgOp block.
436   if (isa<arith::ConstantOp, func::ConstantOp>(op))
437     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
438 
439   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
440   if (!OpTrait::hasElementwiseMappableTraits(op))
441     return VectorizationResult{VectorizationStatus::Failure, nullptr};
442 
443   // 4. Check if the operation is a reduction.
444  SmallVector<std::pair<Value, Value>> reductionOperands;
445  for (Value operand : op->getOperands()) {
446  auto arg = operand.dyn_cast<BlockArgument>();
447  if (!arg || arg.getArgNumber() < linalgOp.getNumDpsInputs())
448  continue;
449  SmallVector<Operation *> reductionOps;
450  Value reduceValue = matchReduction(
451  linalgOp.getRegionOutputArgs(),
452  arg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
453  if (!reduceValue)
454  continue;
455  reductionOperands.push_back(std::make_pair(reduceValue, operand));
456  }
457  if (!reductionOperands.empty()) {
458  assert(reductionOperands.size() == 1);
459  Operation *reduceOp =
460  reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
461  reductionOperands[0].second, bvm);
462  if (reduceOp)
463     return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
464   }
465 
466  // 5. Generic vectorization path for ElementwiseMappable ops.
467  // a. first get the first max ranked shape.
468  SmallVector<int64_t, 4> firstMaxRankedShape;
469  for (Value operand : op->getOperands()) {
470  auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
471  if (vt && firstMaxRankedShape.size() < vt.getShape().size())
472  firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
473  }
474   // b. broadcast each op if needed.
475  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
476  return firstMaxRankedShape.empty()
477  ? bvm.lookup(v)
478  : broadcastIfNeeded(rewriter, bvm.lookup(v),
479  firstMaxRankedShape);
480  });
481  // c. for elementwise, the result is the vector with the firstMaxRankedShape
482  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
483  return firstMaxRankedShape.empty()
484  ? t
485  : VectorType::get(firstMaxRankedShape, t);
486  });
487 
488  // Build and return the new op.
489  return VectorizationResult{
490       VectorizationStatus::NewOp,
491       rewriter.create(op->getLoc(), op->getName().getIdentifier(),
492  llvm::to_vector<4>(vectorizedOperands),
493  llvm::to_vector<4>(returnTypes), op->getAttrs())};
494 }
495 
496 /// Generic vectorization function that rewrites the body of a `linalgOp` into
497 /// vector form. Generic vectorization proceeds as follows:
498 /// 1. Verify the `linalgOp` has one non-empty region.
499 /// 2. Values defined above the region are mapped to themselves and will be
500 /// broadcasted on a per-need basis by their consumers.
501 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
502 /// load).
503 /// TODO: Reuse opportunities for RAR dependencies.
504 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
505 /// 4b. Register CustomVectorizationHook for IndexOp to access the
506 /// iteration indices.
507 /// 5. Iteratively call vectorizeOneOp on the region operations.
508 ///
509 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
510 /// performed to the maximal common vector size implied by the `linalgOp`
511 /// iteration space. This eager broadcasting is introduced in the
512 /// permutation_map of the vector.transfer_read operations. The eager
513 /// broadcasting makes it trivial to determine where broadcasts, transposes and
514 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
515 /// the absence of good canonicalizations, the amount of work increases.
516 /// This is not deemed a problem as we expect canonicalizations and foldings to
517 /// aggressively clean up the useless work.
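/// For example (illustrative sketch for a static elementwise addition; indexing
/// maps, iterator types and some indices elided):
///
/// ```
///   %0 = linalg.generic ... ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
///                           outs(%c : tensor<4x8xf32>) { arith.addf ... }
/// ```
/// is rewritten, roughly, to:
/// ```
///   %va = vector.transfer_read %a[%c0, %c0], %cst : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %b[%c0, %c0], %cst : tensor<4x8xf32>, vector<4x8xf32>
///   %vadd = arith.addf %va, %vb : vector<4x8xf32>
///   %r = vector.transfer_write %vadd, %c[%c0, %c0] : vector<4x8xf32>, tensor<4x8xf32>
/// ```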
518 static LogicalResult
519 vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
520  SmallVectorImpl<Value> &newResults) {
521  Block *block = linalgOp.getBlock();
522 
523  // 2. Values defined above the region can only be broadcast for now. Make them
524  // map to themselves.
525   BlockAndValueMapping bvm;
526   SetVector<Value> valuesSet;
527  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
528  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
529 
530  if (linalgOp.getNumDpsInits() == 0)
531  return failure();
532 
533  // TODO: the common vector shape is equal to the static loop sizes only when
534  // all indexing maps are projected permutations. For convs and stencils the
535  // logic will need to evolve.
536  SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
537 
538  // 3. Turn all BBArgs into vector.transfer_read / load.
539  Location loc = linalgOp.getLoc();
540  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
541  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
542  BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
543  if (linalgOp.isScalar(opOperand)) {
544  bvm.map(bbarg, opOperand->get());
545  continue;
546  }
547  VectorType readType;
548  AffineMap map;
549  // TODO: can we keep this simplification?
550  // if (linalgOp.getShape(&opOperand).empty()) {
551  // readType = VectorType::get({}, bbarg.getType());
552  // } else {
553  if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) {
554       map = inverseAndBroadcastProjectedPermutation(
555           linalgOp.getMatchingIndexingMap(opOperand));
556  readType = VectorType::get(commonVectorShape,
557  getElementTypeOrSelf(opOperand->get()));
558  } else {
559  map = inversePermutation(
560  reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
561  readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
562  getElementTypeOrSelf(opOperand->get()));
563  }
564  // }
565 
566  auto shape = linalgOp.getShape(opOperand);
567  SmallVector<Value> indices(shape.size(), zero);
568  Value readValue = rewriter.create<vector::TransferReadOp>(
569  loc, readType, opOperand->get(), indices, map);
570  // Not all ops support 0-d vectors, extract the scalar for now.
571  // TODO: remove this.
572  if (readValue.getType().cast<VectorType>().getRank() == 0)
573  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
574 
575  LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
576  bvm.map(bbarg, readValue);
577  bvm.map(opOperand->get(), readValue);
578  }
579 
580   SmallVector<CustomVectorizationHook> hooks;
581   // 4a. Register CustomVectorizationHook for yieldOp.
582  CustomVectorizationHook vectorizeYield =
583  [&](Operation *op,
584           const BlockAndValueMapping &bvm) -> VectorizationResult {
585         return vectorizeLinalgYield(rewriter, op, bvm, linalgOp, newResults);
586  };
587  hooks.push_back(vectorizeYield);
588 
589   // 4b. Register CustomVectorizationHook for indexOp.
590  CustomVectorizationHook vectorizeIndex =
591  [&](Operation *op,
592           const BlockAndValueMapping &bvm) -> VectorizationResult {
593         return vectorizeLinalgIndex(rewriter, op, linalgOp);
594  };
595  hooks.push_back(vectorizeIndex);
596 
597  // 4c. Register CustomVectorizationHook for extractOp.
598  CustomVectorizationHook vectorizeExtract =
599  [&](Operation *op,
600           const BlockAndValueMapping &bvm) -> VectorizationResult {
601         return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
602  };
603  hooks.push_back(vectorizeExtract);
604 
605   // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
606  for (Operation &op : block->getOperations()) {
607  VectorizationResult result =
608  vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
609  if (result.status == VectorizationStatus::Failure) {
610  LDBG("failed to vectorize: " << op);
611  return failure();
612  }
613  if (result.status == VectorizationStatus::NewOp) {
614       LDBG("new vector op: " << *result.newOp);
615  bvm.map(op.getResults(), result.newOp->getResults());
616  }
617  }
618 
619  return success();
620 }
621 
622 // TODO: probably need some extra checks for reduction followed by consumer
623 // ops that may not commute (e.g. linear reduction + non-linear instructions).
624 static LogicalResult reductionPreconditions(LinalgOp op) {
625   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
626  LDBG("reduction precondition failed: no reduction iterator");
627  return failure();
628  }
629  for (OpOperand *opOperand : op.getDpsInitOperands()) {
630  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
631  if (indexingMap.isPermutation())
632  continue;
633 
634  Operation *reduceOp = matchLinalgReduction(opOperand);
635  if (!reduceOp || !getCombinerOpKind(reduceOp)) {
636  LDBG("reduction precondition failed: reduction detection failed");
637  return failure();
638  }
639  }
640  return success();
641 }
642 
643 static LogicalResult vectorizeStaticLinalgOpPrecondition(
644     linalg::LinalgOp op,
645  ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
646 
647  // All types in the body should be a supported element type for VectorType.
648  for (Operation &innerOp : op->getRegion(0).front()) {
649  // Check if any custom hook can vectorize the inner op.
650  if (llvm::any_of(
651  customPreconditions,
652  [&](const CustomVectorizationPrecondition &customPrecondition) {
653  return succeeded(customPrecondition(&innerOp));
654  })) {
655  continue;
656  }
657  if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
658  return !VectorType::isValidElementType(type);
659  })) {
660  return failure();
661  }
662  if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
663  return !VectorType::isValidElementType(type);
664  })) {
665  return failure();
666  }
667  }
668  if (isElementwise(op))
669  return success();
670  // TODO: isaConvolutionOpInterface that can also infer from generic features.
671  // But we will still need stride/dilation attributes that will be annoying to
672  // reverse-engineer...
673  if (isa<ConvolutionOpInterface>(op.getOperation()))
674  return success();
675  // TODO: the common vector shape is equal to the static loop sizes only when
676  // all indexing maps are projected permutations. For convs and stencils the
677  // logic will need to evolve.
678   if (!allIndexingsAreProjectedPermutation(op)) {
679     LDBG("precondition failed: not projected permutations");
680  return failure();
681  }
682  if (failed(reductionPreconditions(op))) {
683  LDBG("precondition failed: reduction preconditions");
684  return failure();
685  }
686  return success();
687 }
688 
689 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
690   // All types must be static shape to go to vector.
691  if (linalgOp.hasDynamicShape()) {
692  LDBG("precondition failed: dynamic shape");
693  return failure();
694  }
695 
696   SmallVector<CustomVectorizationPrecondition> customPreconditions;
697 
698  // Register CustomVectorizationPrecondition for extractOp.
699  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
700 
701  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
702 }
703 
704 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
705                                       LinalgOp linalgOp) {
706  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
707  return failure();
708 
709  SmallVector<Value> results;
710  // TODO: isaConvolutionOpInterface that can also infer from generic
711  // features. Will require stride/dilation attributes inference.
712  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
713  if (succeeded(convOr)) {
714  llvm::append_range(results, (*convOr)->getResults());
715  } else {
716  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
717  return failure();
718  LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
719  if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
720  return failure();
721  }
722 
723  if (!results.empty())
724  rewriter.replaceOp(linalgOp, results);
725  else
726  rewriter.eraseOp(linalgOp);
727 
728  return success();
729 }
730 
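/// Vectorize a statically-shaped memref.copy as a full-size transfer_read /
/// transfer_write pair. Illustrative sketch (assuming static, identical shapes;
/// names such as %cst are made up):
///
/// ```
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// ```
/// becomes, roughly:
/// ```
///   %v = vector.transfer_read %src[%c0, %c0], %cst : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>
/// ```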
731 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
732                                           memref::CopyOp copyOp) {
733 
734  auto srcType = copyOp.getSource().getType().cast<MemRefType>();
735  auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
736  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
737  return failure();
738 
739  auto readType =
740  VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
741  auto writeType =
742  VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
743 
744  Location loc = copyOp->getLoc();
745  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
746  SmallVector<Value> indices(srcType.getRank(), zero);
747 
748  Value readValue = rewriter.create<vector::TransferReadOp>(
749  loc, readType, copyOp.getSource(), indices,
750  rewriter.getMultiDimIdentityMap(srcType.getRank()));
751  if (readValue.getType().cast<VectorType>().getRank() == 0) {
752  readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
753  readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
754  }
755  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
756  loc, readValue, copyOp.getTarget(), indices,
757  rewriter.getMultiDimIdentityMap(srcType.getRank()));
758  rewriter.replaceOp(copyOp, writeValue->getResults());
759  return success();
760 }
761 
762 //----------------------------------------------------------------------------//
763 // Misc. vectorization patterns.
764 //----------------------------------------------------------------------------//
765 
766 /// Helper function that retrieves the value of an IntegerAttr.
767 static int64_t getIntFromAttr(Attribute attr) {
768  return attr.cast<IntegerAttr>().getInt();
769 }
770 
771 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
772 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
773 /// not supported.
774 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
775                                            ArrayRef<OpFoldResult> ofrs) {
776  SmallVector<Value> result;
777  for (auto o : ofrs) {
778  if (auto val = o.template dyn_cast<Value>()) {
779  result.push_back(val);
780  } else {
781  result.push_back(rewriter.create<arith::ConstantIndexOp>(
782  loc, getIntFromAttr(o.template get<Attribute>())));
783  }
784  }
785  return result;
786 }
787 
788 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
789 /// InsertSliceOp. For now, only constant padding values are supported.
790 /// If there is enough static type information, TransferReadOps and
791 /// TransferWriteOps may be generated instead of InsertSliceOps.
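/// For example (illustrative sketch; the base GeneralizePadOpPattern produces
/// the empty/fill sequence, and tryVectorizeCopy below replaces the copy of the
/// source with a transfer_read / transfer_write pair when shapes are static):
///
/// ```
///   %0 = tensor.pad %src low[0, 0] high[5, 7] { ... } : tensor<12x9xf32> to tensor<17x16xf32>
/// ```
/// may become, roughly:
/// ```
///   %init = tensor.empty() : tensor<17x16xf32>
///   %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<17x16xf32>) -> tensor<17x16xf32>
///   %v = vector.transfer_read %src[%c0, %c0], %pad : tensor<12x9xf32>, vector<12x9xf32>
///   %r = vector.transfer_write %v, %fill[%c0, %c0] : vector<12x9xf32>, tensor<17x16xf32>
/// ```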
792 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
793   GenericPadOpVectorizationPattern(MLIRContext *context,
794                                    PatternBenefit benefit = 1)
795  : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
796  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
797   /// each dimension size is statically known in the source type or the result
798   /// type (or both).
799   static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
800                                         tensor::PadOp padOp, Value dest) {
801  auto sourceType = padOp.getSourceType();
802  auto resultType = padOp.getResultType();
803 
804  // Copy cannot be vectorized if pad value is non-constant and source shape
805  // is dynamic. In case of a dynamic source shape, padding must be appended
806  // by TransferReadOp, but TransferReadOp supports only constant padding.
807  auto padValue = padOp.getConstantPaddingValue();
808  if (!padValue) {
809  if (!sourceType.hasStaticShape())
810  return failure();
811  // Create dummy padding value.
812  auto elemType = sourceType.getElementType();
813  padValue = rewriter.create<arith::ConstantOp>(
814  padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
815  }
816 
817  SmallVector<int64_t> vecShape;
818  SmallVector<bool> readInBounds;
819  SmallVector<bool> writeInBounds;
820  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
821  if (!sourceType.isDynamicDim(i)) {
822  vecShape.push_back(sourceType.getDimSize(i));
823       // Source shape is statically known: Neither read nor write are
824       // out-of-bounds.
825  readInBounds.push_back(true);
826  writeInBounds.push_back(true);
827  } else if (!resultType.isDynamicDim(i)) {
828  // Source shape is not statically known, but result shape is.
829  // Vectorize with size of result shape. This may be larger than the
830  // source size.
831  vecShape.push_back(resultType.getDimSize(i));
832  // Read may be out-of-bounds because the result size could be larger
833  // than the source size.
834  readInBounds.push_back(false);
835  // Write is out-of-bounds if low padding > 0.
836  writeInBounds.push_back(
837  getConstantIntValue(padOp.getMixedLowPad()[i]) ==
838  static_cast<int64_t>(0));
839  } else {
840  // Neither source nor result dim of padOp is static. Cannot vectorize
841  // the copy.
842  return failure();
843  }
844  }
845  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
846 
847  // Generate TransferReadOp.
848  SmallVector<Value> readIndices(
849  vecType.getRank(),
850  rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
851  auto read = rewriter.create<vector::TransferReadOp>(
852  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
853  ArrayRef<bool>{readInBounds});
854 
855  // If `dest` is a FillOp and the TransferWriteOp would overwrite the
856  // entire tensor, write directly to the FillOp's operand.
857  if (llvm::equal(vecShape, resultType.getShape()) &&
858  llvm::all_of(writeInBounds, [](bool b) { return b; }))
859  if (auto fill = dest.getDefiningOp<FillOp>())
860  dest = fill.output();
861 
862  // Generate TransferWriteOp.
863  auto writeIndices =
864  ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
865  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
866  padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
867 
868  return success();
869  }
870 };
871 
872 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
873 /// given operation type OpTy.
874 template <typename OpTy>
875 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
876   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
877 
878  LogicalResult matchAndRewrite(tensor::PadOp padOp,
879  PatternRewriter &rewriter) const final {
880  bool changed = false;
881  // Insert users in vector, because some users may be replaced/removed.
882  for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
883  if (auto op = dyn_cast<OpTy>(user))
884  changed |= rewriteUser(rewriter, padOp, op).succeeded();
885  return success(changed);
886  }
887 
888 protected:
889   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
890                                     tensor::PadOp padOp, OpTy op) const = 0;
891 };
892 
893 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
894 /// ```
895 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
896 /// %r = vector.transfer_read %0[%c0, %c0], %cst
897 /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
898 /// ```
899 /// is rewritten to:
900 /// ```
901 /// %r = vector.transfer_read %src[%c0, %c0], %padding
902 /// {in_bounds = [true, true]}
903 /// : tensor<?x?xf32>, vector<17x5xf32>
904 /// ```
905 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
906 /// sure that the original padding value %cst was never used.
907 ///
908 /// This rewrite is possible if:
909 /// - `xferOp` has no out-of-bounds dims or mask.
910 /// - Low padding is static 0.
911 /// - Single, scalar padding value.
912 struct PadOpVectorizationWithTransferReadPattern
913     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
914   using VectorizePadOpUserPattern<
915       vector::TransferReadOp>::VectorizePadOpUserPattern;
916 
917  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
918  vector::TransferReadOp xferOp) const override {
919  // Low padding must be static 0.
920  if (!padOp.hasZeroLowPad())
921  return failure();
922  // Pad value must be a constant.
923  auto padValue = padOp.getConstantPaddingValue();
924  if (!padValue)
925  return failure();
926  // Padding value of existing `xferOp` is unused.
927  if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
928  return failure();
929 
930  rewriter.updateRootInPlace(xferOp, [&]() {
931  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
932  xferOp->setAttr(xferOp.getInBoundsAttrName(),
933  rewriter.getBoolArrayAttr(inBounds));
934  xferOp.getSourceMutable().assign(padOp.getSource());
935  xferOp.getPaddingMutable().assign(padValue);
936  });
937 
938  return success();
939  }
940 };
941 
942 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
943 /// This pattern rewrites TransferWriteOps that write to a padded tensor
944 /// value, where the same amount of padding is immediately removed again after
945 /// the write. In such cases, the TransferWriteOp can write to the non-padded
946 /// tensor value and apply out-of-bounds masking. E.g.:
947 /// ```
948 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
949 /// : tensor<...> to tensor<?x?xf32>
950 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
951 /// %2 = vector.transfer_write %vec, %1[...]
952 /// : vector<17x5xf32>, tensor<17x5xf32>
953 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
954 /// : tensor<17x5xf32> to tensor<?x?xf32>
955 /// ```
956 /// is rewritten to:
957 /// ```
958 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
959 /// : tensor<...> to tensor<?x?xf32>
960 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
961 /// tensor<?x?xf32>
962 /// ```
963 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
964 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
965 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
966 /// from %r's old dimensions.
967 ///
968 /// This rewrite is possible if:
969 /// - Low padding is static 0.
970 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
971 /// ExtractSliceOp trims the same amount of padding that was added
972 /// beforehand.
973 /// - Single, scalar padding value.
974 struct PadOpVectorizationWithTransferWritePattern
975     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
976   using VectorizePadOpUserPattern<
977       vector::TransferWriteOp>::VectorizePadOpUserPattern;
978 
979  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
980  vector::TransferWriteOp xferOp) const override {
981  // TODO: support 0-d corner case.
982  if (xferOp.getTransferRank() == 0)
983  return failure();
984 
985  // Low padding must be static 0.
986  if (!padOp.hasZeroLowPad())
987  return failure();
988  // Pad value must be a constant.
989  auto padValue = padOp.getConstantPaddingValue();
990  if (!padValue)
991  return failure();
992  // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
993  if (!xferOp->hasOneUse())
994  return failure();
995  auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
996  if (!trimPadding)
997  return failure();
998  // Only static zero offsets supported when trimming padding.
999  if (!trimPadding.hasZeroOffset())
1000  return failure();
1001  // trimPadding must remove the amount of padding that was added earlier.
1002  if (!hasSameTensorSize(padOp.getSource(), trimPadding))
1003  return failure();
1004 
1005  // Insert the new TransferWriteOp at position of the old TransferWriteOp.
1006  rewriter.setInsertionPoint(xferOp);
1007 
1008  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
1009  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1010  xferOp, padOp.getSource().getType(), xferOp.getVector(),
1011  padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
1012  xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
1013  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
1014 
1015  return success();
1016  }
1017 
1018  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
1019  /// i.e., same dimensions.
1020  ///
1021  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
1022  /// dimensions, this function tries to infer the (static) tensor size by
1023  /// looking at the defining op and utilizing op-specific knowledge.
1024  ///
1025  /// This is a conservative analysis. In case equal tensor sizes cannot be
1026  /// proven statically, this analysis returns `false` even though the tensor
1027  /// sizes may turn out to be equal at runtime.
1028  bool hasSameTensorSize(Value beforePadding,
1029  tensor::ExtractSliceOp afterTrimming) const {
1030     // If the input to tensor::PadOp is a CastOp, try with both CastOp
1031  // result and CastOp operand.
1032  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
1033  if (hasSameTensorSize(castOp.getSource(), afterTrimming))
1034  return true;
1035 
1036  auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
1037  auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
1038  // Only RankedTensorType supported.
1039  if (!t1 || !t2)
1040  return false;
1041  // Rank of both values must be the same.
1042  if (t1.getRank() != t2.getRank())
1043  return false;
1044 
1045  // All static dimensions must be the same. Mixed cases (e.g., dimension
1046  // static in `t1` but dynamic in `t2`) are not supported.
1047  for (unsigned i = 0; i < t1.getRank(); ++i) {
1048  if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
1049  return false;
1050  if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
1051  return false;
1052  }
1053 
1054  // Nothing more to check if all dimensions are static.
1055  if (t1.getNumDynamicDims() == 0)
1056  return true;
1057 
1058  // All dynamic sizes must be the same. The only supported case at the
1059  // moment is when `beforePadding` is an ExtractSliceOp (or a cast
1060  // thereof).
1061 
1062  // Apart from CastOp, only ExtractSliceOp is supported.
1063  auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
1064  if (!beforeSlice)
1065  return false;
1066 
1067  assert(static_cast<size_t>(t1.getRank()) ==
1068  beforeSlice.getMixedSizes().size());
1069  assert(static_cast<size_t>(t2.getRank()) ==
1070  afterTrimming.getMixedSizes().size());
1071 
1072  for (unsigned i = 0; i < t1.getRank(); ++i) {
1073  // Skip static dimensions.
1074  if (!t1.isDynamicDim(i))
1075  continue;
1076  auto size1 = beforeSlice.getMixedSizes()[i];
1077  auto size2 = afterTrimming.getMixedSizes()[i];
1078 
1079  // Case 1: Same value or same constant int.
1080  if (isEqualConstantIntOrValue(size1, size2))
1081  continue;
1082 
1083  // Other cases: Take a deeper look at defining ops of values.
1084  auto v1 = size1.dyn_cast<Value>();
1085  auto v2 = size2.dyn_cast<Value>();
1086  if (!v1 || !v2)
1087  return false;
1088 
1089  // Case 2: Both values are identical AffineMinOps. (Should not happen if
1090  // CSE is run.)
1091  auto minOp1 = v1.getDefiningOp<AffineMinOp>();
1092  auto minOp2 = v2.getDefiningOp<AffineMinOp>();
1093  if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
1094  minOp1.getOperands() == minOp2.getOperands())
1095  continue;
1096 
1097  // Add additional cases as needed.
1098  }
1099 
1100  // All tests passed.
1101  return true;
1102  }
1103 };
1104 
1105 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
1106 /// ```
1107 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
1108 /// %r = tensor.insert_slice %0
1109 /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
1110 /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
1111 /// ```
1112 /// is rewritten to:
1113 /// ```
1114 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
1115 /// : tensor<?x?xf32>, vector<17x5xf32>
1116 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
1117 /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
1118 /// ```
1119 ///
1120 /// This rewrite is possible if:
1121 /// - Low padding is static 0.
1122 /// - `padOp` result shape is static.
1123 /// - The entire padded tensor is inserted.
1124 /// (Implies that sizes of `insertOp` are all static.)
1125 /// - Only unit strides in `insertOp`.
1126 /// - Single, scalar padding value.
1127 /// - `padOp` result not used as destination.
1128 struct PadOpVectorizationWithInsertSlicePattern
1129     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
1130   using VectorizePadOpUserPattern<
1131       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
1132 
1133  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1134  tensor::InsertSliceOp insertOp) const override {
1135  // Low padding must be static 0.
1136  if (!padOp.hasZeroLowPad())
1137  return failure();
1138  // Only unit stride supported.
1139  if (!insertOp.hasUnitStride())
1140  return failure();
1141  // Pad value must be a constant.
1142  auto padValue = padOp.getConstantPaddingValue();
1143  if (!padValue)
1144  return failure();
1145  // Dynamic shapes not supported.
1146  if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
1147  return failure();
1148  // Pad result not used as destination.
1149  if (insertOp.getDest() == padOp.getResult())
1150  return failure();
1151 
1152  auto vecType = VectorType::get(padOp.getType().getShape(),
1153  padOp.getType().getElementType());
1154  unsigned vecRank = vecType.getRank();
1155  unsigned tensorRank = insertOp.getType().getRank();
1156 
1157  // Check if sizes match: Insert the entire tensor into most minor dims.
1158  // (No permutations allowed.)
1159  SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
1160  expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
1161  if (!llvm::all_of(
1162  llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
1163  return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
1164  }))
1165  return failure();
1166 
1167  // Insert the TransferReadOp and TransferWriteOp at the position of the
1168  // InsertSliceOp.
1169  rewriter.setInsertionPoint(insertOp);
1170 
1171  // Generate TransferReadOp: Read entire source tensor and add high
1172  // padding.
1173  SmallVector<Value> readIndices(
1174  vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
1175  auto read = rewriter.create<vector::TransferReadOp>(
1176  padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
1177 
1178  // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
1179  // specified offsets. Write is fully in-bounds because a InsertSliceOp's
1180  // source must fit into the destination at the specified offsets.
1181  auto writeIndices =
1182  ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
1183  SmallVector<bool> inBounds(vecRank, true);
1184  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1185  insertOp, read, insertOp.getDest(), writeIndices,
1186  ArrayRef<bool>{inBounds});
1187 
1188  return success();
1189  }
1190 };
1191 
1192 void mlir::linalg::populatePadOpVectorizationPatterns(
1193     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
1194  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
1195  baseBenefit);
1196  // Try these specialized patterns first before resorting to the generic one.
1197   patterns.add<PadOpVectorizationWithTransferReadPattern,
1198                PadOpVectorizationWithTransferWritePattern,
1199                PadOpVectorizationWithInsertSlicePattern>(
1200       patterns.getContext(), baseBenefit.getBenefit() + 1);
1201 }
1202 
1203 //----------------------------------------------------------------------------//
1204 // Forwarding patterns
1205 //----------------------------------------------------------------------------//
1206 
1207 /// Check whether there is any interleaved use of any `values` between
1208 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
1209 /// is in a different block.
1210 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
1211  ValueRange values) {
1212  if (firstOp->getBlock() != secondOp->getBlock() ||
1213  !firstOp->isBeforeInBlock(secondOp)) {
1214  LDBG("interleavedUses precondition failed, firstOp: "
1215  << *firstOp << ", second op: " << *secondOp);
1216  return true;
1217  }
1218  for (auto v : values) {
1219  for (auto &u : v.getUses()) {
1220  Operation *owner = u.getOwner();
1221  if (owner == firstOp || owner == secondOp)
1222  continue;
1223  // TODO: this is too conservative, use dominance info in the future.
1224  if (owner->getBlock() == firstOp->getBlock() &&
1225  (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
1226  continue;
1227  LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
1228  << ", second op: " << *secondOp);
1229  return true;
1230  }
1231  }
1232  return false;
1233 }
1234 
1235 /// Return the unique subview use of `v` if it is indeed unique, null
1236 /// otherwise.
1237 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
1238  memref::SubViewOp subViewOp;
1239  for (auto &u : v.getUses()) {
1240  if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
1241  if (subViewOp)
1242  return memref::SubViewOp();
1243  subViewOp = newSubViewOp;
1244  }
1245  }
1246  return subViewOp;
1247 }
1248 
1249 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1250 /// when available.
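/// Illustrative sketch of the forwarding performed below (assuming the
/// fill / copy / transfer_read chain matches; names are made up):
///
/// ```
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
///   %sv = memref.subview %alloc[...] : memref<32x32xf32> to memref<?x?xf32, ...>
///   memref.copy %in, %sv : memref<?x?xf32> to memref<?x?xf32, ...>
///   %v = vector.transfer_read %alloc[%c0, %c0], %cst
///       : memref<32x32xf32>, vector<32x32xf32>
/// ```
/// is rewritten so that the transfer_read reads directly from %in, and the fill
/// and copy are erased.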
1251 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
1252     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
1253 
1254  // TODO: support mask.
1255  if (xferOp.getMask())
1256  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
1257 
1258  // Transfer into `view`.
1259  Value viewOrAlloc = xferOp.getSource();
1260  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1261  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1262  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
1263 
1264  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1265  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1266  if (!subViewOp)
1267  return rewriter.notifyMatchFailure(xferOp, "no subview found");
1268  Value subView = subViewOp.getResult();
1269 
1270  // Find the copy into `subView` without interleaved uses.
1271  memref::CopyOp copyOp;
1272  for (auto &u : subView.getUses()) {
1273  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
1274  assert(newCopyOp.getTarget().getType().isa<MemRefType>());
1275  if (newCopyOp.getTarget() != subView)
1276  continue;
1277  if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
1278  continue;
1279  copyOp = newCopyOp;
1280  break;
1281  }
1282  }
1283  if (!copyOp)
1284  return rewriter.notifyMatchFailure(xferOp, "no copy found");
1285 
1286  // Find the fill into `viewOrAlloc` without interleaved uses before the
1287  // copy.
1288  FillOp maybeFillOp;
1289  for (auto &u : viewOrAlloc.getUses()) {
1290  if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
1291  assert(newFillOp.output().getType().isa<MemRefType>());
1292  if (newFillOp.output() != viewOrAlloc)
1293  continue;
1294  if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
1295  continue;
1296  maybeFillOp = newFillOp;
1297  break;
1298  }
1299  }
1300  // Ensure padding matches.
1301  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
1302  return rewriter.notifyMatchFailure(xferOp,
1303  "padding value does not match fill");
1304 
1305  // `in` is the subview that memref.copy reads. Replace it.
1306  Value in = copyOp.getSource();
1307 
1308  // memref.copy + linalg.fill can be used to create a padded local buffer.
1309  // The `masked` attribute is only valid on this padded buffer.
1310  // When forwarding to vector.transfer_read, the attribute must be reset
1311  // conservatively.
1312  Value res = rewriter.create<vector::TransferReadOp>(
1313  xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
1314  xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
1315  // in_bounds is explicitly reset
1316  /*inBoundsAttr=*/ArrayAttr());
1317 
1318  if (maybeFillOp)
1319  rewriter.eraseOp(maybeFillOp);
1320  rewriter.eraseOp(copyOp);
1321  rewriter.replaceOp(xferOp, res);
1322 
1323  return success();
1324 }
1325 
1326 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1327 /// when available.
1328 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
1329     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
1330  // TODO: support mask.
1331  if (xferOp.getMask())
1332  return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
1333 
1334  // Transfer into `viewOrAlloc`.
1335  Value viewOrAlloc = xferOp.getSource();
1336  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1337  !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1338  return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
1339 
1340  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1341  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1342  if (!subViewOp)
1343  return rewriter.notifyMatchFailure(xferOp, "no subview found");
1344  Value subView = subViewOp.getResult();
1345 
1346  // Find the copy from `subView` without interleaved uses.
1347  memref::CopyOp copyOp;
1348  for (auto &u : subViewOp.getResult().getUses()) {
1349  if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
1350  if (newCopyOp.getSource() != subView)
1351  continue;
1352  if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
1353  continue;
1354  copyOp = newCopyOp;
1355  break;
1356  }
1357  }
1358  if (!copyOp)
1359  return rewriter.notifyMatchFailure(xferOp, "no copy found");
1360 
1361  // `out` is the subview copied into that we replace.
1362  assert(copyOp.getTarget().getType().isa<MemRefType>());
1363  Value out = copyOp.getTarget();
1364 
1365  // Forward vector.transfer into copy.
1366  // memref.copy + linalg.fill can be used to create a padded local buffer.
1367  // The `masked` attribute is only valid on this padded buffer.
1368  // When forwarding to vector.transfer_write, the attribute must be reset
1369  // conservatively.
1370  rewriter.create<vector::TransferWriteOp>(
1371  xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
1372  xferOp.getPermutationMapAttr(), xferOp.getMask(),
1373  // in_bounds is explicitly reset
1374  /*inBoundsAttr=*/ArrayAttr());
1375 
1376  rewriter.eraseOp(copyOp);
1377  rewriter.eraseOp(xferOp);
1378 
1379  return success();
1380 }
1381 
1382 //===----------------------------------------------------------------------===//
1383 // Convolution vectorization patterns
1384 //===----------------------------------------------------------------------===//
1385 
1386 template <int N>
1387 static void bindShapeDims(ShapedType shapedType) {}
1388 
1389 template <int N, typename IntTy, typename... IntTy2>
1390 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
1391  val = shapedType.getShape()[N];
1392  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
1393 }
1394 
1395 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
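/// For example (illustrative): `bindShapeDims(resShapedType, nSize, wSize)`
/// binds `nSize` and `wSize` to dimensions 0 and 1 of the result shape.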
1396 template <typename... IntTy>
1397 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
1398  bindShapeDims<0>(shapedType, vals...);
1399 }
1400 
1401 namespace {
1402 /// Generate a vector implementation for either:
1403 /// ```
1404 /// Op def: ( n, w, c, kw, f )
1405 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
1406 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1407 /// ```
1408 /// kw is unrolled, w is unrolled iff dilationW > 1.
1409 ///
1410 /// or
1411 ///
1412 /// ```
1413 /// Op def: ( n, c, w, f, kw )
1414 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
1415 /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
1416 /// ```
1417 /// kw is unrolled, w is unrolled iff dilationW > 1.
1418 ///
1419 /// or
1420 ///
1421 /// ```
1422 /// Op def: ( n, w, c, kw )
1423 /// Iters: ({Par(), Par(), Par(), Red()})
1424 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1425 /// ```
1426 /// kw is unrolled, w is unrolled iff dilationW > 1.
1427 struct Conv1DGenerator
1428  : public StructuredGenerator<LinalgOp, utils::IteratorType> {
1429  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
1430  int dilationW)
1431  : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
1432  strideW(strideW), dilationW(dilationW) {
1433  // Determine whether `linalgOp` can be generated with this generator
1434  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
1435  return;
1436  lhsShaped = linalgOp.getDpsInputOperand(0)->get();
1437  rhsShaped = linalgOp.getDpsInputOperand(1)->get();
1438  resShaped = linalgOp.getDpsInitOperand(0)->get();
1439  lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
1440  rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
1441  resShapedType = resShaped.getType().dyn_cast<ShapedType>();
1442  if (!lhsShapedType || !rhsShapedType || !resShapedType)
1443  return;
1444  if (lhsShapedType.getRank() != 3 ||
1445  (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
1446  resShapedType.getRank() != 3)
1447  return;
1448 
1449  // Check for reduction `add` preceded by `mul`.
1450  Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
1451  if (!reduceOp)
1452  return;
1453  llvm::Optional<vector::CombiningKind> maybeKind;
1454  maybeKind = getCombinerOpKind(reduceOp);
1455  if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
1456  return;
1457  // Check for single `mul` predecessor. The `mul` operands must be block
1458  // arguments or extensions of block arguments.
1459  Operation *mulOp = nullptr;
1460  for (Value operand : reduceOp->getOperands()) {
1461  if (operand.isa<BlockArgument>())
1462  continue;
1463  if (mulOp)
1464  return;
1465  mulOp = operand.getDefiningOp();
1466  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
1467  return;
1468  }
1469  if (!mulOp)
1470  return;
1471  for (Value operand : mulOp->getOperands()) {
1472  if (Operation *def = operand.getDefiningOp()) {
1473  if (!isa<CastOpInterface>(def))
1474  return;
1475  operand = def->getOperand(0);
1476  }
1477  if (!operand.isa<BlockArgument>())
1478  return;
1479  }
1480  // The op is now known to be valid.
1481  valid = true;
1482  }
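  // Example of an op accepted here (hypothetical shapes): a
  // linalg.conv_1d_nwc_wcf with lhs tensor<1x16x4xf32>, filter
  // tensor<3x4x8xf32> and init tensor<1x14x8xf32> has two inputs, one init,
  // ranks 3/3/3 and an add-of-mul reduction on the init operand, so `valid`
  // is set to true.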
1483 
1484  /// Generate a vector implementation for:
1485  /// ```
1486  /// Op def: ( n, w, c, kw, f )
1487  /// Iters: ({Par(), Par(), Par(), Red(), Red()})
1488  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1489  /// ```
1490  /// kw is always unrolled.
1491  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
1492  /// > 1.
1493  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
1494  if (!valid)
1495  return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv");
1496 
1497  int64_t nSize, wSize, cSize, kwSize, fSize;
1498  SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
1499  switch (conv1DOpOrder) {
1500  case Conv1DOpOrder::Nwc:
1501  // kernel{kw, c, f}
1502  bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
1503  // out{n, w, f}
1504  bindShapeDims(resShapedType, nSize, wSize);
1505  lhsShape = {nSize,
1506  // iw = ow * sw + kw * dw - 1
1507  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1508  // Perform the proper inclusive -> exclusive -> inclusive.
1509  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
1510  1,
1511  cSize};
1512  rhsShape = {kwSize, cSize, fSize};
1513  resShape = {nSize, wSize, fSize};
1514  break;
1515  case Conv1DOpOrder::Ncw:
1516  // kernel{f, c, kw}
1517  bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
1518  // out{n, f, w}
1519  bindShapeDims(resShapedType, nSize, fSize, wSize);
1520  lhsShape = {nSize, cSize,
1521  // iw = ow * sw + kw * dw - 1
1522  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1523  // Perform the proper inclusive -> exclusive -> inclusive.
1524  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
1525  1};
1526  rhsShape = {fSize, cSize, kwSize};
1527  resShape = {nSize, fSize, wSize};
1528  break;
1529  }
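    // Worked example (hypothetical sizes): wSize = 14, strideW = 1, kwSize = 3,
    // dilationW = 1 gives an input width of
    // ((14 - 1) * 1 + 1) + ((3 - 1) * 1 + 1) - 1 = 16, i.e. a 16-wide input
    // convolved with a 3-wide kernel yields a 14-wide output.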
1530 
1531  vector::TransferWriteOp write;
1532  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1533 
1534  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
1535  // When strideW == 1, we can batch the contiguous loads and avoid
1536  // unrolling
1537  int64_t wSizeStep = strideW == 1 ? wSize : 1;
1538 
1539  Type lhsEltType = lhsShapedType.getElementType();
1540  Type rhsEltType = rhsShapedType.getElementType();
1541  Type resEltType = resShapedType.getElementType();
1542  auto lhsType = VectorType::get(lhsShape, lhsEltType);
1543  auto rhsType = VectorType::get(rhsShape, rhsEltType);
1544  auto resType = VectorType::get(resShape, resEltType);
1545  // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
1546  // 0].
1547  Value lhs = rewriter.create<vector::TransferReadOp>(
1548  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
1549  // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
1550  Value rhs = rewriter.create<vector::TransferReadOp>(
1551  loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
1552  // Read res slice of size {n, w, f} @ [0, 0, 0].
1553  Value res = rewriter.create<vector::TransferReadOp>(
1554  loc, resType, resShaped, ValueRange{zero, zero, zero});
1555 
1556  // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
1557  // {n,w,f}. To reuse the base pattern vectorization case, we pre-transpose
1558  // the input, weight, and output.
1559  switch (conv1DOpOrder) {
1560  case Conv1DOpOrder::Nwc:
1561  // Base case, so no transposes necessary.
1562  break;
1563  case Conv1DOpOrder::Ncw: {
1564  // To match base vectorization case, we pre-transpose current case.
1565  // ncw -> nwc
1566  static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
1567  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
1568  // fcw -> wcf
1569  static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
1570  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
1571  // nfw -> nwf
1572  static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
1573  res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
1574  break;
1575  }
1576  }
1577 
1578  //===------------------------------------------------------------------===//
1579  // Begin vector-only rewrite part
1580  //===------------------------------------------------------------------===//
1581  // Unroll along kw and read slices of lhs and rhs.
1582  SmallVector<Value> lhsVals, rhsVals, resVals;
1583  // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
1584  for (int64_t kw = 0; kw < kwSize; ++kw) {
1585  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1586  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
1587  loc, lhs,
1588  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
1589  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
1590  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1591  }
1592  }
1593  // Extract rhs slice of size {c, f} @ [kw].
1594  for (int64_t kw = 0; kw < kwSize; ++kw) {
1595  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
1596  loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
1597  }
1598  // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
1599  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1600  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
1601  loc, res,
1602  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
1603  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
1604  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1605  }
1606 
1607  auto linearIndex = [&](int64_t kw, int64_t w) {
1608  return kw * (wSize / wSizeStep) + w;
1609  };
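    // For example, with wSizeStep == 1 (unrolled w) and wSize = 14,
    // linearIndex(kw, w) = kw * 14 + w; with wSizeStep == wSize (batched w),
    // w only takes the value 0 and linearIndex degenerates to kw.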
1610 
1611  // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
1612  for (int64_t kw = 0; kw < kwSize; ++kw) {
1613  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1614  resVals[w] = conv1dSliceAsContraction(
1615  rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
1616  }
1617  }
1618 
1619  // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
1620  // This does not depend on kw.
1621  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1622  res = rewriter.create<vector::InsertStridedSliceOp>(
1623  loc, resVals[w], res,
1624  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
1625  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
1626  }
1627  //===------------------------------------------------------------------===//
1628  // End vector-only rewrite part
1629  //===------------------------------------------------------------------===//
1630 
1631  // The base vectorization case produces output {n,w,f}.
1632  // To reuse the result from the base pattern vectorization case, we
1633  // post-transpose the base case result.
1634  switch (conv1DOpOrder) {
1635  case Conv1DOpOrder::Nwc:
1636  // Base case, so no transposes necessary.
1637  break;
1638  case Conv1DOpOrder::Ncw: {
1639  // nwf -> nfw
1640  static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
1641  res = rewriter.create<vector::TransposeOp>(loc, res, perm);
1642  break;
1643  }
1644  }
1645 
1646  // Write back res slice of size {n, w, f} @ [0, 0, 0].
1647  return rewriter
1648  .create<vector::TransferWriteOp>(loc, res, resShaped,
1649  ValueRange{zero, zero, zero})
1650  .getOperation();
1651  }
1652 
1653  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
1654  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
1655  Value lhs, Value rhs, Value res) {
1656  vector::IteratorType par = vector::IteratorType::parallel;
1657  vector::IteratorType red = vector::IteratorType::reduction;
1658  AffineExpr n, w, f, c;
1659  bindDims(ctx, n, w, f, c);
1660  return rewriter.create<vector::ContractionOp>(
1661  loc, lhs, rhs, res,
1662  /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
1663  /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
1664  }
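  // Illustrative example (hypothetical vector types): for lhs
  // vector<1x14x4xf32> laid out as {n, w, c}, rhs vector<4x8xf32> as {c, f}
  // and res vector<1x14x8xf32> as {n, w, f}, the emitted vector.contract
  // keeps n, w and f parallel and reduces over c, accumulating into res.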
1665 
1666  /// Generate a vector implementation for:
1667  /// ```
1668  /// Op def: ( n, w, c, kw)
1669  /// Iters: ({Par(), Par(), Par(), Red()})
1670  /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1671  /// ```
1672  /// kw is always unrolled.
1673  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
1674  /// > 1.
1675  FailureOr<Operation *> depthwiseConv() {
1676  if (!valid)
1677  return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
1678 
1679  int64_t nSize, wSize, cSize, kwSize;
1680  // kernel{kw, c}
1681  bindShapeDims(rhsShapedType, kwSize, cSize);
1682  // out{n, w, c}
1683  bindShapeDims(resShapedType, nSize, wSize);
1684 
1685  vector::TransferWriteOp write;
1686  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1687 
1688  // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
1689  // When strideW == 1, we can batch the contiguous loads and avoid
1690  // unrolling
1691  int64_t wSizeStep = strideW == 1 ? wSize : 1;
1692 
1693  Type lhsEltType = lhsShapedType.getElementType();
1694  Type rhsEltType = rhsShapedType.getElementType();
1695  Type resEltType = resShapedType.getElementType();
1696  VectorType lhsType = VectorType::get(
1697  {nSize,
1698  // iw = ow * sw + kw * dw - 1
1699  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1700  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
1701  cSize},
1702  lhsEltType);
1703  VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
1704  VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
1705 
1706  // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
1707  // 0].
1708  Value lhs = rewriter.create<vector::TransferReadOp>(
1709  loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
1710  // Read rhs slice of size {kw, c} @ [0, 0].
1711  Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
1712  ValueRange{zero, zero});
1713  // Read res slice of size {n, w, c} @ [0, 0, 0].
1714  Value res = rewriter.create<vector::TransferReadOp>(
1715  loc, resType, resShaped, ValueRange{zero, zero, zero});
1716 
1717  //===------------------------------------------------------------------===//
1718  // Begin vector-only rewrite part
1719  //===------------------------------------------------------------------===//
1720  // Unroll along kw and read slices of lhs and rhs.
1721  SmallVector<Value> lhsVals, rhsVals, resVals;
1722  // Extract lhs slice of size {n, wSizeStep, c}
1723  // @ [0, sw * w + dw * kw, 0].
1724  for (int64_t kw = 0; kw < kwSize; ++kw) {
1725  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1726  lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
1727  loc, lhs,
1728  /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
1729  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
1730  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1731  }
1732  }
1733  // Extract rhs slice of size {c} @ [kw].
1734  for (int64_t kw = 0; kw < kwSize; ++kw) {
1735  rhsVals.push_back(rewriter.create<vector::ExtractOp>(
1736  loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
1737  }
1738  // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
1739  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1740  resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
1741  loc, res,
1742  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
1743  /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
1744  /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1745  }
1746 
1747  auto linearIndex = [&](int64_t kw, int64_t w) {
1748  return kw * (wSize / wSizeStep) + w;
1749  };
1750 
1751  // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
1752  for (int64_t kw = 0; kw < kwSize; ++kw) {
1753  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1754  resVals[w] = depthwiseConv1dSliceAsMulAcc(
1755  rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
1756  }
1757  }
1758 
1759  // It's possible we failed to create the FMA.
1760  if (!llvm::all_of(resVals, [](Value v) { return v; })) {
1761  // Manually revert (in reverse order) to avoid leaving a bad IR state.
1762  for (auto &collection : {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
1763  for (Value v : collection)
1764  rewriter.eraseOp(v.getDefiningOp());
1765  return rewriter.notifyMatchFailure(op, "failed to create FMA");
1766  }
1767 
1768  // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
1769  // This does not depend on kw.
1770  for (int64_t w = 0; w < wSize; w += wSizeStep) {
1771  res = rewriter.create<vector::InsertStridedSliceOp>(
1772  loc, resVals[w], res,
1773  /*offsets=*/ArrayRef<int64_t>{0, w, 0},
1774  /*strides=*/ArrayRef<int64_t>{1, 1, 1});
1775  }
1776  //===------------------------------------------------------------------===//
1777  // End vector-only rewrite part
1778  //===------------------------------------------------------------------===//
1779 
1780  // Write back res slice of size {n, w, c} @ [0, 0, 0].
1781  return rewriter
1782  .create<vector::TransferWriteOp>(loc, res, resShaped,
1783  ValueRange{zero, zero, zero})
1784  .getOperation();
1785  }
1786 
1787  // Take a value of element type T and widen it to the destination type.
1788  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
1789  if (val.getType() == ty)
1790  return val;
1791 
1792  const int64_t srcWidth =
1793      getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
1794  const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
1795 
1796  if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
1797  return rewriter.create<arith::ExtFOp>(loc, ty, val);
1798 
1799  if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
1800  return rewriter.create<arith::ExtSIOp>(loc, ty, val);
1801 
1802  return nullptr;
1803  }
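  // For instance (a sketch): promoting a vector<4xf16> value to
  // vector<4xf32> emits arith.extf, promoting vector<4xi8> to vector<4xi32>
  // emits arith.extsi, and a non-widening mismatch returns nullptr so the
  // caller can abandon the rewrite.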
1804 
1805  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
1806  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
1807  Value lhs, Value rhs, Value res) {
1808  auto rhsTy = rhs.getType().cast<ShapedType>();
1809  auto resTy = res.getType().cast<ShapedType>();
1810 
1811  // TODO(suderman): Change this to use a vector.ima intrinsic.
1812  lhs = promote(rewriter, loc, lhs, resTy);
1813 
1814  rhs = rewriter.create<vector::BroadcastOp>(
1815  loc, resTy.clone(rhsTy.getElementType()), rhs);
1816  rhs = promote(rewriter, loc, rhs, resTy);
1817 
1818  if (!lhs || !rhs)
1819  return nullptr;
1820 
1821  if (resTy.getElementType().isa<FloatType>())
1822  return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
1823 
1824  auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
1825  return rewriter.create<arith::AddIOp>(loc, mul, res);
1826  }
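  // Illustrative behavior (hypothetical types): for f32 elements this lowers
  // to a single vector.fma of the input slice with the broadcast filter; for
  // integer elements it falls back to arith.muli followed by arith.addi into
  // the accumulator.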
1827 
1828  /// Entry point that transposes into the common form:
1829  /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1830  FailureOr<Operation *> generateNwcConv() {
1831  AffineExpr n, w, f, kw, c;
1832  bindDims(ctx, n, w, f, kw, c);
1833  if (!iters({Par(), Par(), Par(), Red(), Red()}))
1834  return rewriter.notifyMatchFailure(
1835  op, "failed to match conv::Nwc 3-par 2-red");
1836 
1837  // No transposition needed.
1838  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
1839  /*rhsIndex*/ {kw, c, f},
1840  /*resIndex*/ {n, w, f}}))
1841  return conv(Conv1DOpOrder::Nwc);
1842  return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
1843  }
1844 
1845  /// Entry point that transposes into the common form:
1846  /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
1847  FailureOr<Operation *> generateNcwConv() {
1848  AffineExpr n, w, f, kw, c;
1849  bindDims(ctx, n, f, w, c, kw);
1850  if (!iters({Par(), Par(), Par(), Red(), Red()}))
1851  return rewriter.notifyMatchFailure(
1852  op, "failed to match conv::Ncw 3-par 2-red");
1853 
1854  if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
1855  /*rhsIndex*/ {f, c, kw},
1856  /*resIndex*/ {n, f, w}}))
1857  return conv(Conv1DOpOrder::Ncw);
1858 
1859  return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
1860  }
1861 
1862  /// Entry point that transposes into the common form:
1863  /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1864  FailureOr<Operation *> generateDilatedConv() {
1865  AffineExpr n, w, c, kw;
1866  bindDims(ctx, n, w, c, kw);
1867  if (!iters({Par(), Par(), Par(), Red()}))
1868  return rewriter.notifyMatchFailure(
1869  op, "failed to match depthwise::Nwc conv 3-par 1-red");
1870 
1871  // No transposition needed.
1872  if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
1873  /*rhsIndex*/ {kw, c},
1874  /*resIndex*/ {n, w, c}}))
1875  return depthwiseConv();
1876 
1877  return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
1878  }
1879 
1880 private:
1881  bool valid = false;
1882  int strideW, dilationW;
1883  Value lhsShaped, rhsShaped, resShaped;
1884  ShapedType lhsShapedType, rhsShapedType, resShapedType;
1885 };
1886 } // namespace
1887 
1888 /// Helper function to vectorize a LinalgOp with convolution semantics.
1889 // TODO: extend the generic vectorization to support windows and drop this.
1890 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
1891                                                     LinalgOp op) {
1892  // The ConvolutionOpInterface gives us guarantees of existence for
1893  // strides/dilations. However, we do not need to rely on those; we simply
1894  // use them if present, otherwise we use the defaults and let the generic
1895  // conv matcher in the Conv1DGenerator succeed or fail.
1896  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
1897  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
1898  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
1899  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
1900  Conv1DGenerator e(rewriter, op, stride, dilation);
1901  auto res = e.generateNwcConv();
1902  if (succeeded(res))
1903  return res;
1904  res = e.generateNcwConv();
1905  if (succeeded(res))
1906  return res;
1907  return e.generateDilatedConv();
1908 }
1909 
1910 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
1911   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
1912 
1913   LogicalResult matchAndRewrite(LinalgOp op,
1914                                 PatternRewriter &rewriter) const override {
1915  FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
1916  if (failed(resultOrFail))
1917  return failure();
1918  Operation *newOp = *resultOrFail;
1919  if (newOp->getNumResults() == 0) {
1920  rewriter.eraseOp(op.getOperation());
1921  return success();
1922  }
1923  assert(newOp->getNumResults() == 1 && "expected single result");
1924  rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
1925  return success();
1926  }
1927 };
1928 
1929 void mlir::linalg::populateConvolutionVectorizationPatterns(
1930     RewritePatternSet &patterns, PatternBenefit benefit) {
1931  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
1932 }
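// Minimal usage sketch (not part of this file; assumes the greedy rewrite
// driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));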