MLIR  19.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Debug.h"
22 #include <numeric>
23 #include <optional>
24 
25 #define DEBUG_TYPE "vector-unroll"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28 
29 using namespace mlir;
30 using namespace mlir::vector;
31 
32 /// Compute the indices of the slice `index` for a tranfer op.
34  ArrayRef<Value> indices,
35  AffineMap permutationMap,
36  Location loc,
37  OpBuilder &builder) {
38  MLIRContext *ctx = builder.getContext();
39  auto isBroadcast = [](AffineExpr expr) {
40  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
41  return constExpr.getValue() == 0;
42  return false;
43  };
44  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
45  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
46  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
47  if (isBroadcast(dim.value()))
48  continue;
49  unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
50  auto expr = getAffineDimExpr(0, builder.getContext()) +
51  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
52  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
53  slicedIndices[pos] =
54  builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
55  }
56  return slicedIndices;
57 }
58 
59 // Clones `op` into a new operations that takes `operands` and returns
60 // `resultTypes`.
62  Operation *op,
63  ArrayRef<Value> operands,
64  ArrayRef<Type> resultTypes) {
65  return builder.create(loc, op->getName().getIdentifier(), operands,
66  resultTypes, op->getAttrs());
67 }
68 
69 /// Return the target shape for unrolling for the given `op`. Return
70 /// std::nullopt if the op shouldn't be or cannot be unrolled.
71 static std::optional<SmallVector<int64_t>>
73  LDBG("");
74  LDBG("Get unroll shape for op " << op->getName().getStringRef());
75  if (options.filterConstraint && failed(options.filterConstraint(op))) {
76  LDBG("--no filter constraint -> BAIL");
77  return std::nullopt;
78  }
79  assert(options.nativeShape &&
80  "vector unrolling expects the native shape or native"
81  "shape call back function to be set");
82  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
83  if (!unrollableVectorOp) {
84  LDBG("--not an unrollable op -> BAIL");
85  return std::nullopt;
86  }
87  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
88  if (!maybeUnrollShape) {
89  LDBG("--could not get shape of op " << *op << " -> BAIL");
90  return std::nullopt;
91  }
92  LLVM_DEBUG(
93  llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
94  llvm::dbgs() << "\n";);
95 
96  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
97  if (!targetShape) {
98  LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
99  return std::nullopt;
100  }
101  LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
102  llvm::dbgs() << "\n";);
103 
104  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
105  if (!maybeShapeRatio) {
106  LDBG("--could not compute integral shape ratio -> BAIL");
107  return std::nullopt;
108  }
109  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
110  LDBG("--no unrolling needed -> SKIP");
111  return std::nullopt;
112  }
113  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
114  return targetShape;
115 }
116 
118 getUnrollOrder(unsigned numLoops, Operation *op,
120  SmallVector<int64_t> loopOrder =
121  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
122  if (options.traversalOrderCallback != nullptr) {
123  std::optional<SmallVector<int64_t>> order =
124  options.traversalOrderCallback(op);
125  if (order) {
126  loopOrder = std::move(*order);
127  }
128  }
129  return loopOrder;
130 }
131 
132 namespace {
133 
134 struct UnrollTransferReadPattern
135  : public OpRewritePattern<vector::TransferReadOp> {
136  UnrollTransferReadPattern(MLIRContext *context,
138  PatternBenefit benefit = 1)
139  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
140  options(options) {}
141 
142  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
143  PatternRewriter &rewriter) const override {
144  // TODO: support 0-d corner case.
145  if (readOp.getTransferRank() == 0)
146  return failure();
147  if (readOp.getMask())
148  return failure();
149  auto targetShape = getTargetShape(options, readOp);
150  if (!targetShape)
151  return failure();
152  auto sourceVectorType = readOp.getVectorType();
153  SmallVector<int64_t> strides(targetShape->size(), 1);
154  Location loc = readOp.getLoc();
155  ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
156 
157  // Prepare the result vector;
158  Value result = rewriter.create<arith::ConstantOp>(
159  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
160  auto targetType =
161  VectorType::get(*targetShape, sourceVectorType.getElementType());
162  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
163  readOp.getIndices().end());
164  SmallVector<int64_t> loopOrder =
165  getUnrollOrder(originalSize.size(), readOp, options);
166  for (SmallVector<int64_t> elementOffsets :
167  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
168  SmallVector<Value> indices =
169  sliceTransferIndices(elementOffsets, originalIndices,
170  readOp.getPermutationMap(), loc, rewriter);
171  auto slicedRead = rewriter.create<vector::TransferReadOp>(
172  loc, targetType, readOp.getSource(), indices,
173  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
174  readOp.getInBoundsAttr());
175 
176  result = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, slicedRead, result, elementOffsets, strides);
178  }
179  rewriter.replaceOp(readOp, result);
180  return success();
181  }
182 
183 private:
185 };
186 
187 struct UnrollTransferWritePattern
188  : public OpRewritePattern<vector::TransferWriteOp> {
189  UnrollTransferWritePattern(MLIRContext *context,
191  PatternBenefit benefit = 1)
192  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
193  options(options) {}
194 
195  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
196  PatternRewriter &rewriter) const override {
197  // TODO: support 0-d corner case.
198  if (writeOp.getTransferRank() == 0)
199  return failure();
200 
201  if (writeOp.getMask())
202  return failure();
203  auto targetShape = getTargetShape(options, writeOp);
204  if (!targetShape)
205  return failure();
206  auto sourceVectorType = writeOp.getVectorType();
207  SmallVector<int64_t> strides(targetShape->size(), 1);
208  Location loc = writeOp.getLoc();
209  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
210  SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
211  writeOp.getIndices().end());
212  SmallVector<int64_t> loopOrder =
213  getUnrollOrder(originalSize.size(), writeOp, options);
214  Value resultTensor;
215  for (SmallVector<int64_t> elementOffsets :
216  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
217  Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
218  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
219  SmallVector<Value> indices =
220  sliceTransferIndices(elementOffsets, originalIndices,
221  writeOp.getPermutationMap(), loc, rewriter);
222  Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
223  loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
224  indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
225  // For the tensor case update the destination for the next transfer write.
226  if (!slicedWrite->getResults().empty())
227  resultTensor = slicedWrite->getResult(0);
228  }
229  if (resultTensor)
230  rewriter.replaceOp(writeOp, resultTensor);
231  else
232  rewriter.eraseOp(writeOp);
233  return success();
234  }
235 
236 private:
238 };
239 
240 struct OffsetMapInfo {
241  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
242 
243  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
244 
245  static unsigned getHashValue(const SmallVector<int64_t> &v) {
246  return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
247  }
248 
249  static bool isEqual(const SmallVector<int64_t> &lhs,
250  const SmallVector<int64_t> &rhs) {
251  return lhs == rhs;
252  }
253 };
254 
255 struct UnrollContractionPattern
256  : public OpRewritePattern<vector::ContractionOp> {
257  UnrollContractionPattern(MLIRContext *context,
259  PatternBenefit benefit = 1)
260  : OpRewritePattern<vector::ContractionOp>(context, benefit),
261  options(options) {}
262 
263  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
264  PatternRewriter &rewriter) const override {
265  auto targetShape = getTargetShape(options, contractOp);
266  if (!targetShape)
267  return failure();
268  auto dstVecType = cast<VectorType>(contractOp.getResultType());
269  SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
270 
271  Location loc = contractOp.getLoc();
272  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
273  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
274  llvm::MapVector<
276  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
277  accCache;
278 
280  contractOp.getIteratorTypes().size(), contractOp, options);
281 
282  for (SmallVector<int64_t> offsets :
283  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
284  SmallVector<Value> slicesOperands(contractOp.getNumOperands());
285 
286  // Helper to compute the new shape of each operand and extract the slice.
287  auto extractOperand = [&](unsigned index, Value operand,
288  AffineMap permutationMap,
289  ArrayRef<int64_t> operandOffets) {
291  permutationMap, ArrayRef<int64_t>(*targetShape));
292  SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
293  slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
294  loc, operand, operandOffets, operandShape, operandStrides);
295  };
296 
297  // Extract the new lhs operand.
298  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
299  SmallVector<int64_t> lhsOffets =
300  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
301  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
302 
303  // Extract the new rhs operand.
304  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
305  SmallVector<int64_t> rhsOffets =
306  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
307  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
308 
309  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
310  SmallVector<int64_t> accOffets =
311  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
312  // If a version of the accumulator has already been computed, use it
313  // otherwise extract the first version from the original operand.
314  auto *accIt = accCache.find(accOffets);
315  if (accIt != accCache.end())
316  slicesOperands[2] = accIt->second;
317  else
318  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
319 
320  SmallVector<int64_t> dstShape =
321  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
322  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
324  rewriter, loc, contractOp, slicesOperands, targetType);
325 
326  SmallVector<int64_t> dstOffets =
327  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
328  // Save the accumulated value untill all the loops are unrolled since
329  // reduction loop keep updating the accumulator.
330  accCache[dstOffets] = newOp->getResult(0);
331  }
332  // Assemble back the accumulator into a single vector.
333  Value result = rewriter.create<arith::ConstantOp>(
334  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
335  for (const auto &it : accCache) {
336  SmallVector<int64_t> dstStrides(it.first.size(), 1);
337  result = rewriter.create<vector::InsertStridedSliceOp>(
338  loc, it.second, result, it.first, dstStrides);
339  }
340  rewriter.replaceOp(contractOp, result);
341  return success();
342  }
343 
344 private:
346 };
347 
348 struct UnrollMultiReductionPattern
349  : public OpRewritePattern<vector::MultiDimReductionOp> {
350  UnrollMultiReductionPattern(MLIRContext *context,
352  PatternBenefit benefit = 1)
353  : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
354  options(options) {}
355 
356  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
357  PatternRewriter &rewriter) const override {
358  std::optional<SmallVector<int64_t>> targetShape =
359  getTargetShape(options, reductionOp);
360  if (!targetShape)
361  return failure();
362  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
363  llvm::MapVector<
365  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
366  accCache;
367  Location loc = reductionOp.getLoc();
368 
369  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
370  // of multiples of the targetShape.
371  for (SmallVector<int64_t> offsets :
372  StaticTileOffsetRange(originalSize, *targetShape)) {
373  SmallVector<Value> operands;
374  SmallVector<int64_t> operandStrides(offsets.size(), 1);
375  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
376  loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
377  operands.push_back(slicedOperand);
378  SmallVector<int64_t> dstShape;
379  SmallVector<int64_t> destOffset;
380  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
381  if (!reductionOp.isReducedDim(i)) {
382  destOffset.push_back(offsets[i]);
383  dstShape.push_back((*targetShape)[i]);
384  }
385  }
386  Value acc;
387  SmallVector<int64_t> accStrides(destOffset.size(), 1);
388  // If a version of the accumulator has already been computed, use it
389  // otherwise extract the first version from the original operand.
390  auto *accIt = accCache.find(destOffset);
391  if (accIt != accCache.end())
392  acc = accIt->second;
393  else
394  acc = rewriter.create<vector::ExtractStridedSliceOp>(
395  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
396  operands.push_back(acc);
397  auto targetType = VectorType::get(
398  dstShape, reductionOp.getSourceVectorType().getElementType());
399  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
400  operands, targetType);
401  Value result = newOp->getResult(0);
402  accCache[destOffset] = result;
403  }
404  // Assemble back the accumulator into a single vector.
405  Value result = rewriter.create<arith::ConstantOp>(
406  loc, reductionOp.getDestType(),
407  rewriter.getZeroAttr(reductionOp.getDestType()));
408  for (const auto &it : accCache) {
409  SmallVector<int64_t> dstStrides(it.first.size(), 1);
410  result = rewriter.create<vector::InsertStridedSliceOp>(
411  loc, it.second, result, it.first, dstStrides);
412  }
413  rewriter.replaceOp(reductionOp, result);
414  return success();
415  }
416 
417 private:
419 };
420 
421 struct UnrollElementwisePattern : public RewritePattern {
422  UnrollElementwisePattern(MLIRContext *context,
424  PatternBenefit benefit = 1)
425  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
426  options(options) {}
427 
428  LogicalResult matchAndRewrite(Operation *op,
429  PatternRewriter &rewriter) const override {
431  return failure();
432  auto targetShape = getTargetShape(options, op);
433  if (!targetShape)
434  return failure();
435  auto dstVecType = cast<VectorType>(op->getResult(0).getType());
436  SmallVector<int64_t> originalSize =
437  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
438  Location loc = op->getLoc();
439  // Prepare the result vector.
440  Value result = rewriter.create<arith::ConstantOp>(
441  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
442  SmallVector<int64_t> strides(targetShape->size(), 1);
443  VectorType newVecType =
444  VectorType::get(*targetShape, dstVecType.getElementType());
445 
446  // Create the unrolled computation.
447  for (SmallVector<int64_t> offsets :
448  StaticTileOffsetRange(originalSize, *targetShape)) {
449  SmallVector<Value> extractOperands;
450  for (OpOperand &operand : op->getOpOperands()) {
451  auto vecType = dyn_cast<VectorType>(operand.get().getType());
452  if (!vecType) {
453  extractOperands.push_back(operand.get());
454  continue;
455  }
456  extractOperands.push_back(
457  rewriter.create<vector::ExtractStridedSliceOp>(
458  loc, operand.get(), offsets, *targetShape, strides));
459  }
461  rewriter, loc, op, extractOperands, newVecType);
462  result = rewriter.create<vector::InsertStridedSliceOp>(
463  loc, newOp->getResult(0), result, offsets, strides);
464  }
465  rewriter.replaceOp(op, result);
466  return success();
467  }
468 
469 private:
471 };
472 
473 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
474  UnrollReductionPattern(MLIRContext *context,
476  PatternBenefit benefit = 1)
477  : OpRewritePattern<vector::ReductionOp>(context, benefit),
478  options(options) {}
479 
480  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
481  PatternRewriter &rewriter) const override {
482  std::optional<SmallVector<int64_t>> targetShape =
483  getTargetShape(options, reductionOp);
484  if (!targetShape)
485  return failure();
486  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
487 
488  // Create unrolled vector reduction.
489  Location loc = reductionOp.getLoc();
490  Value accumulator = nullptr;
491  for (SmallVector<int64_t> offsets :
492  StaticTileOffsetRange(originalSize, *targetShape)) {
493  SmallVector<int64_t> strides(offsets.size(), 1);
494  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
495  loc, reductionOp.getVector(), offsets, *targetShape, strides);
497  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
498  Value result = newOp->getResult(0);
499 
500  if (!accumulator) {
501  // This is the first reduction.
502  accumulator = result;
503  } else {
504  // On subsequent reduction, combine with the accumulator.
505  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
506  accumulator, result);
507  }
508  }
509 
510  rewriter.replaceOp(reductionOp, accumulator);
511  return success();
512  }
513 
514 private:
516 };
517 
518 struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
519  UnrollTransposePattern(MLIRContext *context,
521  PatternBenefit benefit = 1)
522  : OpRewritePattern<vector::TransposeOp>(context, benefit),
523  options(options) {}
524 
525  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
526  PatternRewriter &rewriter) const override {
527  if (transposeOp.getResultVectorType().getRank() == 0)
528  return failure();
529  auto targetShape = getTargetShape(options, transposeOp);
530  if (!targetShape)
531  return failure();
532  auto originalVectorType = transposeOp.getResultVectorType();
533  SmallVector<int64_t> strides(targetShape->size(), 1);
534  Location loc = transposeOp.getLoc();
535  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
536 
537  // Prepare the result vector;
538  Value result = rewriter.create<arith::ConstantOp>(
539  loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
540  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
541 
542  // Unroll the computation.
543  for (SmallVector<int64_t> elementOffsets :
544  StaticTileOffsetRange(originalSize, *targetShape)) {
545  SmallVector<int64_t> permutedOffsets(elementOffsets.size());
546  SmallVector<int64_t> permutedShape(elementOffsets.size());
547  // Compute the source offsets and shape.
548  for (auto indices : llvm::enumerate(permutation)) {
549  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
550  permutedShape[indices.value()] = (*targetShape)[indices.index()];
551  }
552  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
553  loc, transposeOp.getVector(), permutedOffsets, permutedShape,
554  strides);
555  Value transposedSlice =
556  rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
557  result = rewriter.create<vector::InsertStridedSliceOp>(
558  loc, transposedSlice, result, elementOffsets, strides);
559  }
560  rewriter.replaceOp(transposeOp, result);
561  return success();
562  }
563 
564 private:
566 };
567 
568 struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
569  UnrollGatherPattern(MLIRContext *context,
571  PatternBenefit benefit = 1)
572  : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
573  }
574 
575  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
576  PatternRewriter &rewriter) const override {
577  VectorType sourceVectorType = gatherOp.getVectorType();
578  if (sourceVectorType.getRank() == 0)
579  return failure();
580  auto targetShape = getTargetShape(options, gatherOp);
581  if (!targetShape)
582  return failure();
583  SmallVector<int64_t> strides(targetShape->size(), 1);
584  Location loc = gatherOp.getLoc();
585  ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
586 
587  // Prepare the result vector;
588  Value result = rewriter.create<arith::ConstantOp>(
589  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
590  auto targetType =
591  VectorType::get(*targetShape, sourceVectorType.getElementType());
592 
593  SmallVector<int64_t> loopOrder =
594  getUnrollOrder(originalSize.size(), gatherOp, options);
595  for (SmallVector<int64_t> elementOffsets :
596  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
597  // To get the unrolled gather, extract the same slice based on the
598  // decomposed shape from each of the index, mask, and pass-through
599  // vectors.
600  Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
601  loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
602  Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
603  loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
604  Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
605  loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
606  auto slicedGather = rewriter.create<vector::GatherOp>(
607  loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
608  indexSubVec, maskSubVec, passThruSubVec);
609 
610  result = rewriter.create<vector::InsertStridedSliceOp>(
611  loc, slicedGather, result, elementOffsets, strides);
612  }
613  rewriter.replaceOp(gatherOp, result);
614  return success();
615  }
616 
617 private:
619 };
620 
621 } // namespace
622 
625  PatternBenefit benefit) {
626  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
627  UnrollContractionPattern, UnrollElementwisePattern,
628  UnrollReductionPattern, UnrollMultiReductionPattern,
629  UnrollTransposePattern, UnrollGatherPattern>(
630  patterns.getContext(), options, benefit);
631 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
#define DBGS()
#define LDBG(X)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. AffineMaps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1391
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:650
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:623
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Options that control the vector unrolling.