MLIR  19.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/Debug.h"
21 #include <numeric>
22 #include <optional>
23 
24 #define DEBUG_TYPE "vector-unroll"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
27 
28 using namespace mlir;
29 using namespace mlir::vector;
30 
31 /// Compute the indices of the slice `index` for a tranfer op.
33  ArrayRef<Value> indices,
34  AffineMap permutationMap,
35  Location loc,
36  OpBuilder &builder) {
37  MLIRContext *ctx = builder.getContext();
38  auto isBroadcast = [](AffineExpr expr) {
39  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
40  return constExpr.getValue() == 0;
41  return false;
42  };
43  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
44  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
45  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
46  if (isBroadcast(dim.value()))
47  continue;
48  unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
49  auto expr = getAffineDimExpr(0, builder.getContext()) +
50  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
51  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
52  slicedIndices[pos] =
53  builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
54  }
55  return slicedIndices;
56 }
57 
58 // Clones `op` into a new operations that takes `operands` and returns
59 // `resultTypes`.
61  Operation *op,
62  ArrayRef<Value> operands,
63  ArrayRef<Type> resultTypes) {
64  return builder.create(loc, op->getName().getIdentifier(), operands,
65  resultTypes, op->getAttrs());
66 }
67 
68 /// Return the target shape for unrolling for the given `op`. Return
69 /// std::nullopt if the op shouldn't be or cannot be unrolled.
70 static std::optional<SmallVector<int64_t>>
72  LDBG("");
73  LDBG("Get unroll shape for op " << op->getName().getStringRef());
74  if (options.filterConstraint && failed(options.filterConstraint(op))) {
75  LDBG("--no filter constraint -> BAIL");
76  return std::nullopt;
77  }
78  assert(options.nativeShape &&
79  "vector unrolling expects the native shape or native"
80  "shape call back function to be set");
81  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
82  if (!unrollableVectorOp) {
83  LDBG("--not an unrollable op -> BAIL");
84  return std::nullopt;
85  }
86  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
87  if (!maybeUnrollShape) {
88  LDBG("--could not get shape of op " << *op << " -> BAIL");
89  return std::nullopt;
90  }
91  LLVM_DEBUG(
92  llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
93  llvm::dbgs() << "\n";);
94 
95  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
96  if (!targetShape) {
97  LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
98  return std::nullopt;
99  }
100  LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
101  llvm::dbgs() << "\n";);
102 
103  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
104  if (!maybeShapeRatio) {
105  LDBG("--could not compute integral shape ratio -> BAIL");
106  return std::nullopt;
107  }
108  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
109  LDBG("--no unrolling needed -> SKIP");
110  return std::nullopt;
111  }
112  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
113  return targetShape;
114 }
115 
117 getUnrollOrder(unsigned numLoops, Operation *op,
119  SmallVector<int64_t> loopOrder =
120  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
121  if (options.traversalOrderCallback != nullptr) {
122  std::optional<SmallVector<int64_t>> order =
123  options.traversalOrderCallback(op);
124  if (order) {
125  loopOrder = std::move(*order);
126  }
127  }
128  return loopOrder;
129 }
130 
131 namespace {
132 
133 struct UnrollTransferReadPattern
134  : public OpRewritePattern<vector::TransferReadOp> {
135  UnrollTransferReadPattern(MLIRContext *context,
137  PatternBenefit benefit = 1)
138  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
139  options(options) {}
140 
141  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
142  PatternRewriter &rewriter) const override {
143  // TODO: support 0-d corner case.
144  if (readOp.getTransferRank() == 0)
145  return failure();
146  if (readOp.getMask())
147  return failure();
148  auto targetShape = getTargetShape(options, readOp);
149  if (!targetShape)
150  return failure();
151  auto sourceVectorType = readOp.getVectorType();
152  SmallVector<int64_t> strides(targetShape->size(), 1);
153  Location loc = readOp.getLoc();
154  ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
155 
156  // Prepare the result vector;
157  Value result = rewriter.create<arith::ConstantOp>(
158  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
159  auto targetType =
160  VectorType::get(*targetShape, sourceVectorType.getElementType());
161  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
162  readOp.getIndices().end());
163  SmallVector<int64_t> loopOrder =
164  getUnrollOrder(originalSize.size(), readOp, options);
165  for (SmallVector<int64_t> elementOffsets :
166  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
167  SmallVector<Value> indices =
168  sliceTransferIndices(elementOffsets, originalIndices,
169  readOp.getPermutationMap(), loc, rewriter);
170  auto slicedRead = rewriter.create<vector::TransferReadOp>(
171  loc, targetType, readOp.getSource(), indices,
172  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
173  readOp.getInBoundsAttr());
174 
175  result = rewriter.create<vector::InsertStridedSliceOp>(
176  loc, slicedRead, result, elementOffsets, strides);
177  }
178  rewriter.replaceOp(readOp, result);
179  return success();
180  }
181 
182 private:
184 };
185 
186 struct UnrollTransferWritePattern
187  : public OpRewritePattern<vector::TransferWriteOp> {
188  UnrollTransferWritePattern(MLIRContext *context,
190  PatternBenefit benefit = 1)
191  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
192  options(options) {}
193 
194  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
195  PatternRewriter &rewriter) const override {
196  // TODO: support 0-d corner case.
197  if (writeOp.getTransferRank() == 0)
198  return failure();
199 
200  if (writeOp.getMask())
201  return failure();
202  auto targetShape = getTargetShape(options, writeOp);
203  if (!targetShape)
204  return failure();
205  auto sourceVectorType = writeOp.getVectorType();
206  SmallVector<int64_t> strides(targetShape->size(), 1);
207  Location loc = writeOp.getLoc();
208  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
209  SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
210  writeOp.getIndices().end());
211  SmallVector<int64_t> loopOrder =
212  getUnrollOrder(originalSize.size(), writeOp, options);
213  Value resultTensor;
214  for (SmallVector<int64_t> elementOffsets :
215  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
216  Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
217  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
218  SmallVector<Value> indices =
219  sliceTransferIndices(elementOffsets, originalIndices,
220  writeOp.getPermutationMap(), loc, rewriter);
221  Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
222  loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
223  indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
224  // For the tensor case update the destination for the next transfer write.
225  if (!slicedWrite->getResults().empty())
226  resultTensor = slicedWrite->getResult(0);
227  }
228  if (resultTensor)
229  rewriter.replaceOp(writeOp, resultTensor);
230  else
231  rewriter.eraseOp(writeOp);
232  return success();
233  }
234 
235 private:
237 };
238 
239 struct OffsetMapInfo {
240  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
241 
242  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
243 
244  static unsigned getHashValue(const SmallVector<int64_t> &v) {
245  return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
246  }
247 
248  static bool isEqual(const SmallVector<int64_t> &lhs,
249  const SmallVector<int64_t> &rhs) {
250  return lhs == rhs;
251  }
252 };
253 
254 struct UnrollContractionPattern
255  : public OpRewritePattern<vector::ContractionOp> {
256  UnrollContractionPattern(MLIRContext *context,
258  PatternBenefit benefit = 1)
259  : OpRewritePattern<vector::ContractionOp>(context, benefit),
260  options(options) {}
261 
262  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
263  PatternRewriter &rewriter) const override {
264  auto targetShape = getTargetShape(options, contractOp);
265  if (!targetShape)
266  return failure();
267  auto dstVecType = cast<VectorType>(contractOp.getResultType());
268  SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
269 
270  Location loc = contractOp.getLoc();
271  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
272  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
273  llvm::MapVector<
275  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
276  accCache;
277 
279  contractOp.getIteratorTypes().size(), contractOp, options);
280 
281  for (SmallVector<int64_t> offsets :
282  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
283  SmallVector<Value> slicesOperands(contractOp.getNumOperands());
284 
285  // Helper to compute the new shape of each operand and extract the slice.
286  auto extractOperand = [&](unsigned index, Value operand,
287  AffineMap permutationMap,
288  ArrayRef<int64_t> operandOffets) {
290  permutationMap, ArrayRef<int64_t>(*targetShape));
291  SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
292  slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
293  loc, operand, operandOffets, operandShape, operandStrides);
294  };
295 
296  // Extract the new lhs operand.
297  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
298  SmallVector<int64_t> lhsOffets =
299  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
300  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
301 
302  // Extract the new rhs operand.
303  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
304  SmallVector<int64_t> rhsOffets =
305  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
306  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
307 
308  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
309  SmallVector<int64_t> accOffets =
310  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
311  // If a version of the accumulator has already been computed, use it
312  // otherwise extract the first version from the original operand.
313  auto *accIt = accCache.find(accOffets);
314  if (accIt != accCache.end())
315  slicesOperands[2] = accIt->second;
316  else
317  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
318 
319  SmallVector<int64_t> dstShape =
320  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
321  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
323  rewriter, loc, contractOp, slicesOperands, targetType);
324 
325  SmallVector<int64_t> dstOffets =
326  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
327  // Save the accumulated value untill all the loops are unrolled since
328  // reduction loop keep updating the accumulator.
329  accCache[dstOffets] = newOp->getResult(0);
330  }
331  // Assemble back the accumulator into a single vector.
332  Value result = rewriter.create<arith::ConstantOp>(
333  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
334  for (const auto &it : accCache) {
335  SmallVector<int64_t> dstStrides(it.first.size(), 1);
336  result = rewriter.create<vector::InsertStridedSliceOp>(
337  loc, it.second, result, it.first, dstStrides);
338  }
339  rewriter.replaceOp(contractOp, result);
340  return success();
341  }
342 
343 private:
345 };
346 
347 struct UnrollMultiReductionPattern
348  : public OpRewritePattern<vector::MultiDimReductionOp> {
349  UnrollMultiReductionPattern(MLIRContext *context,
351  PatternBenefit benefit = 1)
352  : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
353  options(options) {}
354 
355  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
356  PatternRewriter &rewriter) const override {
357  std::optional<SmallVector<int64_t>> targetShape =
358  getTargetShape(options, reductionOp);
359  if (!targetShape)
360  return failure();
361  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
362  llvm::MapVector<
364  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
365  accCache;
366  Location loc = reductionOp.getLoc();
367 
368  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
369  // of multiples of the targetShape.
370  for (SmallVector<int64_t> offsets :
371  StaticTileOffsetRange(originalSize, *targetShape)) {
372  SmallVector<Value> operands;
373  SmallVector<int64_t> operandStrides(offsets.size(), 1);
374  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
375  loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
376  operands.push_back(slicedOperand);
377  SmallVector<int64_t> dstShape;
378  SmallVector<int64_t> destOffset;
379  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
380  if (!reductionOp.isReducedDim(i)) {
381  destOffset.push_back(offsets[i]);
382  dstShape.push_back((*targetShape)[i]);
383  }
384  }
385  Value acc;
386  SmallVector<int64_t> accStrides(destOffset.size(), 1);
387  // If a version of the accumulator has already been computed, use it
388  // otherwise extract the first version from the original operand.
389  auto *accIt = accCache.find(destOffset);
390  if (accIt != accCache.end())
391  acc = accIt->second;
392  else
393  acc = rewriter.create<vector::ExtractStridedSliceOp>(
394  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
395  operands.push_back(acc);
396  auto targetType = VectorType::get(
397  dstShape, reductionOp.getSourceVectorType().getElementType());
398  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
399  operands, targetType);
400  Value result = newOp->getResult(0);
401  accCache[destOffset] = result;
402  }
403  // Assemble back the accumulator into a single vector.
404  Value result = rewriter.create<arith::ConstantOp>(
405  loc, reductionOp.getDestType(),
406  rewriter.getZeroAttr(reductionOp.getDestType()));
407  for (const auto &it : accCache) {
408  SmallVector<int64_t> dstStrides(it.first.size(), 1);
409  result = rewriter.create<vector::InsertStridedSliceOp>(
410  loc, it.second, result, it.first, dstStrides);
411  }
412  rewriter.replaceOp(reductionOp, result);
413  return success();
414  }
415 
416 private:
418 };
419 
420 struct UnrollElementwisePattern : public RewritePattern {
421  UnrollElementwisePattern(MLIRContext *context,
423  PatternBenefit benefit = 1)
424  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
425  options(options) {}
426 
427  LogicalResult matchAndRewrite(Operation *op,
428  PatternRewriter &rewriter) const override {
430  return failure();
431  auto targetShape = getTargetShape(options, op);
432  if (!targetShape)
433  return failure();
434  auto dstVecType = cast<VectorType>(op->getResult(0).getType());
435  SmallVector<int64_t> originalSize =
436  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
437  Location loc = op->getLoc();
438  // Prepare the result vector.
439  Value result = rewriter.create<arith::ConstantOp>(
440  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
441  SmallVector<int64_t> strides(targetShape->size(), 1);
442  VectorType newVecType =
443  VectorType::get(*targetShape, dstVecType.getElementType());
444 
445  // Create the unrolled computation.
446  for (SmallVector<int64_t> offsets :
447  StaticTileOffsetRange(originalSize, *targetShape)) {
448  SmallVector<Value> extractOperands;
449  for (OpOperand &operand : op->getOpOperands()) {
450  auto vecType = dyn_cast<VectorType>(operand.get().getType());
451  if (!vecType) {
452  extractOperands.push_back(operand.get());
453  continue;
454  }
455  extractOperands.push_back(
456  rewriter.create<vector::ExtractStridedSliceOp>(
457  loc, operand.get(), offsets, *targetShape, strides));
458  }
460  rewriter, loc, op, extractOperands, newVecType);
461  result = rewriter.create<vector::InsertStridedSliceOp>(
462  loc, newOp->getResult(0), result, offsets, strides);
463  }
464  rewriter.replaceOp(op, result);
465  return success();
466  }
467 
468 private:
470 };
471 
472 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
473  UnrollReductionPattern(MLIRContext *context,
475  PatternBenefit benefit = 1)
476  : OpRewritePattern<vector::ReductionOp>(context, benefit),
477  options(options) {}
478 
479  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
480  PatternRewriter &rewriter) const override {
481  std::optional<SmallVector<int64_t>> targetShape =
482  getTargetShape(options, reductionOp);
483  if (!targetShape)
484  return failure();
485  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
486 
487  // Create unrolled vector reduction.
488  Location loc = reductionOp.getLoc();
489  Value accumulator = nullptr;
490  for (SmallVector<int64_t> offsets :
491  StaticTileOffsetRange(originalSize, *targetShape)) {
492  SmallVector<int64_t> strides(offsets.size(), 1);
493  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
494  loc, reductionOp.getVector(), offsets, *targetShape, strides);
496  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
497  Value result = newOp->getResult(0);
498 
499  if (!accumulator) {
500  // This is the first reduction.
501  accumulator = result;
502  } else {
503  // On subsequent reduction, combine with the accumulator.
504  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
505  accumulator, result);
506  }
507  }
508 
509  rewriter.replaceOp(reductionOp, accumulator);
510  return success();
511  }
512 
513 private:
515 };
516 
517 struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
518  UnrollTransposePattern(MLIRContext *context,
520  PatternBenefit benefit = 1)
521  : OpRewritePattern<vector::TransposeOp>(context, benefit),
522  options(options) {}
523 
524  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
525  PatternRewriter &rewriter) const override {
526  if (transposeOp.getResultVectorType().getRank() == 0)
527  return failure();
528  auto targetShape = getTargetShape(options, transposeOp);
529  if (!targetShape)
530  return failure();
531  auto originalVectorType = transposeOp.getResultVectorType();
532  SmallVector<int64_t> strides(targetShape->size(), 1);
533  Location loc = transposeOp.getLoc();
534  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
535 
536  // Prepare the result vector;
537  Value result = rewriter.create<arith::ConstantOp>(
538  loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
539  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
540 
541  // Unroll the computation.
542  for (SmallVector<int64_t> elementOffsets :
543  StaticTileOffsetRange(originalSize, *targetShape)) {
544  SmallVector<int64_t> permutedOffsets(elementOffsets.size());
545  SmallVector<int64_t> permutedShape(elementOffsets.size());
546  // Compute the source offsets and shape.
547  for (auto indices : llvm::enumerate(permutation)) {
548  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
549  permutedShape[indices.value()] = (*targetShape)[indices.index()];
550  }
551  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
552  loc, transposeOp.getVector(), permutedOffsets, permutedShape,
553  strides);
554  Value transposedSlice =
555  rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
556  result = rewriter.create<vector::InsertStridedSliceOp>(
557  loc, transposedSlice, result, elementOffsets, strides);
558  }
559  rewriter.replaceOp(transposeOp, result);
560  return success();
561  }
562 
563 private:
565 };
566 
567 struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
568  UnrollGatherPattern(MLIRContext *context,
570  PatternBenefit benefit = 1)
571  : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
572  }
573 
574  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
575  PatternRewriter &rewriter) const override {
576  VectorType sourceVectorType = gatherOp.getVectorType();
577  if (sourceVectorType.getRank() == 0)
578  return failure();
579  auto targetShape = getTargetShape(options, gatherOp);
580  if (!targetShape)
581  return failure();
582  SmallVector<int64_t> strides(targetShape->size(), 1);
583  Location loc = gatherOp.getLoc();
584  ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
585 
586  // Prepare the result vector;
587  Value result = rewriter.create<arith::ConstantOp>(
588  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
589  auto targetType =
590  VectorType::get(*targetShape, sourceVectorType.getElementType());
591 
592  SmallVector<int64_t> loopOrder =
593  getUnrollOrder(originalSize.size(), gatherOp, options);
594  for (SmallVector<int64_t> elementOffsets :
595  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
596  // To get the unrolled gather, extract the same slice based on the
597  // decomposed shape from each of the index, mask, and pass-through
598  // vectors.
599  Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
600  loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
601  Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
602  loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
603  Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
604  loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
605  auto slicedGather = rewriter.create<vector::GatherOp>(
606  loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
607  indexSubVec, maskSubVec, passThruSubVec);
608 
609  result = rewriter.create<vector::InsertStridedSliceOp>(
610  loc, slicedGather, result, elementOffsets, strides);
611  }
612  rewriter.replaceOp(gatherOp, result);
613  return success();
614  }
615 
616 private:
618 };
619 
620 } // namespace
621 
624  PatternBenefit benefit) {
625  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
626  UnrollContractionPattern, UnrollElementwisePattern,
627  UnrollReductionPattern, UnrollMultiReductionPattern,
628  UnrollTransposePattern, UnrollGatherPattern>(
629  patterns.getContext(), options, benefit);
630 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
#define DBGS()
#define LDBG(X)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:394
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:649
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:630
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:606
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Options that control the vector unrolling.