// MLIR 18.0.0git — VectorUnroll.cpp (source listing extracted from the
// generated documentation page; "Go to the documentation of this file").
//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>
24 
25 #define DEBUG_TYPE "vector-unroll"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28 
29 using namespace mlir;
30 using namespace mlir::vector;
31 
32 /// Compute the indices of the slice `index` for a tranfer op.
34  ArrayRef<Value> indices,
35  AffineMap permutationMap,
36  Location loc,
37  OpBuilder &builder) {
38  MLIRContext *ctx = builder.getContext();
39  auto isBroadcast = [](AffineExpr expr) {
40  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
41  return constExpr.getValue() == 0;
42  return false;
43  };
44  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
45  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
46  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
47  if (isBroadcast(dim.value()))
48  continue;
49  unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
50  auto expr = getAffineDimExpr(0, builder.getContext()) +
51  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
52  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
53  slicedIndices[pos] =
54  builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
55  }
56  return slicedIndices;
57 }
58 
59 // Clones `op` into a new operations that takes `operands` and returns
60 // `resultTypes`.
62  Operation *op,
63  ArrayRef<Value> operands,
64  ArrayRef<Type> resultTypes) {
65  return builder.create(loc, op->getName().getIdentifier(), operands,
66  resultTypes, op->getAttrs());
67 }
68 
69 /// Return the target shape for unrolling for the given `op`. Return
70 /// std::nullopt if the op shouldn't be or cannot be unrolled.
71 static std::optional<SmallVector<int64_t>>
73  LDBG("");
74  LDBG("Get unroll shape for op " << op->getName().getStringRef());
75  if (options.filterConstraint && failed(options.filterConstraint(op))) {
76  LDBG("--no filter constraint -> BAIL");
77  return std::nullopt;
78  }
79  assert(options.nativeShape &&
80  "vector unrolling expects the native shape or native"
81  "shape call back function to be set");
82  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
83  if (!unrollableVectorOp) {
84  LDBG("--not an unrollable op -> BAIL");
85  return std::nullopt;
86  }
87  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
88  if (!maybeUnrollShape) {
89  LDBG("--could not get shape of op " << *op << " -> BAIL");
90  return std::nullopt;
91  }
92  LLVM_DEBUG(
93  llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
94  llvm::dbgs() << "\n";);
95 
96  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
97  if (!targetShape) {
98  LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
99  return std::nullopt;
100  }
101  LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
102  llvm::dbgs() << "\n";);
103 
104  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
105  if (!maybeShapeRatio) {
106  LDBG("--could not compute integral shape ratio -> BAIL");
107  return std::nullopt;
108  }
109  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
110  LDBG("--no unrolling needed -> SKIP");
111  return std::nullopt;
112  }
113  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
114  return targetShape;
115 }
116 
118 getUnrollOrder(unsigned numLoops, Operation *op,
120  SmallVector<int64_t> loopOrder =
121  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
122  if (options.traversalOrderCallback != nullptr) {
123  std::optional<SmallVector<int64_t>> order =
124  options.traversalOrderCallback(op);
125  if (order) {
126  loopOrder = std::move(*order);
127  }
128  }
129  return loopOrder;
130 }
131 
132 namespace {
133 
134 struct UnrollTransferReadPattern
135  : public OpRewritePattern<vector::TransferReadOp> {
136  UnrollTransferReadPattern(MLIRContext *context,
138  PatternBenefit benefit = 1)
139  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
140  options(options) {}
141 
142  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
143  PatternRewriter &rewriter) const override {
144  // TODO: support 0-d corner case.
145  if (readOp.getTransferRank() == 0)
146  return failure();
147  if (readOp.getMask())
148  return failure();
149  auto targetShape = getTargetShape(options, readOp);
150  if (!targetShape)
151  return failure();
152  auto sourceVectorType = readOp.getVectorType();
153  SmallVector<int64_t> strides(targetShape->size(), 1);
154  Location loc = readOp.getLoc();
155  ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
156 
157  // Prepare the result vector;
158  Value result = rewriter.create<arith::ConstantOp>(
159  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
160  auto targetType =
161  VectorType::get(*targetShape, sourceVectorType.getElementType());
162  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
163  readOp.getIndices().end());
164  SmallVector<int64_t> loopOrder =
165  getUnrollOrder(originalSize.size(), readOp, options);
166  for (SmallVector<int64_t> elementOffsets :
167  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
168  SmallVector<Value> indices =
169  sliceTransferIndices(elementOffsets, originalIndices,
170  readOp.getPermutationMap(), loc, rewriter);
171  auto slicedRead = rewriter.create<vector::TransferReadOp>(
172  loc, targetType, readOp.getSource(), indices,
173  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
174  readOp.getInBoundsAttr());
175 
176  result = rewriter.create<vector::InsertStridedSliceOp>(
177  loc, slicedRead, result, elementOffsets, strides);
178  }
179  rewriter.replaceOp(readOp, result);
180  return success();
181  }
182 
183 private:
185 };
186 
187 struct UnrollTransferWritePattern
188  : public OpRewritePattern<vector::TransferWriteOp> {
189  UnrollTransferWritePattern(MLIRContext *context,
191  PatternBenefit benefit = 1)
192  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
193  options(options) {}
194 
195  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
196  PatternRewriter &rewriter) const override {
197  // TODO: support 0-d corner case.
198  if (writeOp.getTransferRank() == 0)
199  return failure();
200 
201  if (writeOp.getMask())
202  return failure();
203  auto targetShape = getTargetShape(options, writeOp);
204  if (!targetShape)
205  return failure();
206  auto sourceVectorType = writeOp.getVectorType();
207  SmallVector<int64_t> strides(targetShape->size(), 1);
208  Location loc = writeOp.getLoc();
209  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
210  SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
211  writeOp.getIndices().end());
212  SmallVector<int64_t> loopOrder =
213  getUnrollOrder(originalSize.size(), writeOp, options);
214  Value resultTensor;
215  for (SmallVector<int64_t> elementOffsets :
216  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
217  Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
218  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
219  SmallVector<Value> indices =
220  sliceTransferIndices(elementOffsets, originalIndices,
221  writeOp.getPermutationMap(), loc, rewriter);
222  Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
223  loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
224  indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
225  // For the tensor case update the destination for the next transfer write.
226  if (!slicedWrite->getResults().empty())
227  resultTensor = slicedWrite->getResult(0);
228  }
229  if (resultTensor)
230  rewriter.replaceOp(writeOp, resultTensor);
231  else
232  rewriter.eraseOp(writeOp);
233  return success();
234  }
235 
236 private:
238 };
239 
240 struct OffsetMapInfo {
241  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
242 
243  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
244 
245  static unsigned getHashValue(const SmallVector<int64_t> &v) {
246  return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
247  }
248 
249  static bool isEqual(const SmallVector<int64_t> &lhs,
250  const SmallVector<int64_t> &rhs) {
251  return lhs == rhs;
252  }
253 };
254 
255 struct UnrollContractionPattern
256  : public OpRewritePattern<vector::ContractionOp> {
257  UnrollContractionPattern(MLIRContext *context,
259  PatternBenefit benefit = 1)
260  : OpRewritePattern<vector::ContractionOp>(context, benefit),
261  options(options) {}
262 
263  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
264  PatternRewriter &rewriter) const override {
265  auto targetShape = getTargetShape(options, contractOp);
266  if (!targetShape)
267  return failure();
268  auto dstVecType = cast<VectorType>(contractOp.getResultType());
269  SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
270 
271  Location loc = contractOp.getLoc();
272  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
273  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
274  llvm::MapVector<
276  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
277  accCache;
278 
280  contractOp.getIteratorTypes().size(), contractOp, options);
281 
282  for (SmallVector<int64_t> offsets :
283  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
284  SmallVector<Value> slicesOperands(contractOp.getNumOperands());
285 
286  // Helper to compute the new shape of each operand and extract the slice.
287  auto extractOperand = [&](unsigned index, Value operand,
288  AffineMap permutationMap,
289  ArrayRef<int64_t> operandOffets) {
291  permutationMap, ArrayRef<int64_t>(*targetShape));
292  SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
293  slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
294  loc, operand, operandOffets, operandShape, operandStrides);
295  };
296 
297  // Extract the new lhs operand.
298  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
299  SmallVector<int64_t> lhsOffets =
300  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
301  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
302 
303  // Extract the new rhs operand.
304  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
305  SmallVector<int64_t> rhsOffets =
306  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
307  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
308 
309  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
310  SmallVector<int64_t> accOffets =
311  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
312  // If a version of the accumulator has already been computed, use it
313  // otherwise extract the first version from the original operand.
314  auto accIt = accCache.find(accOffets);
315  if (accIt != accCache.end())
316  slicesOperands[2] = accIt->second;
317  else
318  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
319 
320  SmallVector<int64_t> dstShape =
321  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
322  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
324  rewriter, loc, contractOp, slicesOperands, targetType);
325 
326  SmallVector<int64_t> dstOffets =
327  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
328  // Save the accumulated value untill all the loops are unrolled since
329  // reduction loop keep updating the accumulator.
330  accCache[dstOffets] = newOp->getResult(0);
331  }
332  // Assemble back the accumulator into a single vector.
333  Value result = rewriter.create<arith::ConstantOp>(
334  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
335  for (const auto &it : accCache) {
336  SmallVector<int64_t> dstStrides(it.first.size(), 1);
337  result = rewriter.create<vector::InsertStridedSliceOp>(
338  loc, it.second, result, it.first, dstStrides);
339  }
340  rewriter.replaceOp(contractOp, result);
341  return success();
342  }
343 
344 private:
346 };
347 
348 struct UnrollMultiReductionPattern
349  : public OpRewritePattern<vector::MultiDimReductionOp> {
350  UnrollMultiReductionPattern(MLIRContext *context,
352  PatternBenefit benefit = 1)
353  : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
354  options(options) {}
355 
356  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
357  PatternRewriter &rewriter) const override {
358  std::optional<SmallVector<int64_t>> targetShape =
359  getTargetShape(options, reductionOp);
360  if (!targetShape)
361  return failure();
362  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
363  llvm::MapVector<
365  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
366  accCache;
367  Location loc = reductionOp.getLoc();
368 
369  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
370  // of multiples of the targetShape.
371  for (SmallVector<int64_t> offsets :
372  StaticTileOffsetRange(originalSize, *targetShape)) {
373  SmallVector<Value> operands;
374  SmallVector<int64_t> operandStrides(offsets.size(), 1);
375  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
376  loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
377  operands.push_back(slicedOperand);
378  SmallVector<int64_t> dstShape;
379  SmallVector<int64_t> destOffset;
380  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
381  if (!reductionOp.isReducedDim(i)) {
382  destOffset.push_back(offsets[i]);
383  dstShape.push_back((*targetShape)[i]);
384  }
385  }
386  Value acc;
387  SmallVector<int64_t> accStrides(destOffset.size(), 1);
388  // If a version of the accumulator has already been computed, use it
389  // otherwise extract the first version from the original operand.
390  auto accIt = accCache.find(destOffset);
391  if (accIt != accCache.end())
392  acc = accIt->second;
393  else
394  acc = rewriter.create<vector::ExtractStridedSliceOp>(
395  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
396  operands.push_back(acc);
397  auto targetType = VectorType::get(
398  dstShape, reductionOp.getSourceVectorType().getElementType());
399  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
400  operands, targetType);
401  Value result = newOp->getResult(0);
402  accCache[destOffset] = result;
403  }
404  // Assemble back the accumulator into a single vector.
405  Value result = rewriter.create<arith::ConstantOp>(
406  loc, reductionOp.getDestType(),
407  rewriter.getZeroAttr(reductionOp.getDestType()));
408  for (const auto &it : accCache) {
409  SmallVector<int64_t> dstStrides(it.first.size(), 1);
410  result = rewriter.create<vector::InsertStridedSliceOp>(
411  loc, it.second, result, it.first, dstStrides);
412  }
413  rewriter.replaceOp(reductionOp, result);
414  return success();
415  }
416 
417 private:
419 };
420 
421 struct UnrollElementwisePattern : public RewritePattern {
422  UnrollElementwisePattern(MLIRContext *context,
424  PatternBenefit benefit = 1)
425  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
426  options(options) {}
427 
428  LogicalResult matchAndRewrite(Operation *op,
429  PatternRewriter &rewriter) const override {
431  return failure();
432  auto targetShape = getTargetShape(options, op);
433  if (!targetShape)
434  return failure();
435  auto dstVecType = cast<VectorType>(op->getResult(0).getType());
436  SmallVector<int64_t> originalSize =
437  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
438  Location loc = op->getLoc();
439  // Prepare the result vector.
440  Value result = rewriter.create<arith::ConstantOp>(
441  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
442  SmallVector<int64_t> strides(targetShape->size(), 1);
443  VectorType newVecType =
444  VectorType::get(*targetShape, dstVecType.getElementType());
445 
446  // Create the unrolled computation.
447  for (SmallVector<int64_t> offsets :
448  StaticTileOffsetRange(originalSize, *targetShape)) {
449  SmallVector<Value> extractOperands;
450  for (OpOperand &operand : op->getOpOperands()) {
451  auto vecType = dyn_cast<VectorType>(operand.get().getType());
452  if (!vecType) {
453  extractOperands.push_back(operand.get());
454  continue;
455  }
456  extractOperands.push_back(
457  rewriter.create<vector::ExtractStridedSliceOp>(
458  loc, operand.get(), offsets, *targetShape, strides));
459  }
461  rewriter, loc, op, extractOperands, newVecType);
462  result = rewriter.create<vector::InsertStridedSliceOp>(
463  loc, newOp->getResult(0), result, offsets, strides);
464  }
465  rewriter.replaceOp(op, result);
466  return success();
467  }
468 
469 private:
471 };
472 
473 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
474  UnrollReductionPattern(MLIRContext *context,
476  PatternBenefit benefit = 1)
477  : OpRewritePattern<vector::ReductionOp>(context, benefit),
478  options(options) {}
479 
480  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
481  PatternRewriter &rewriter) const override {
482  std::optional<SmallVector<int64_t>> targetShape =
483  getTargetShape(options, reductionOp);
484  if (!targetShape)
485  return failure();
486  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
487 
488  // Create unrolled vector reduction.
489  Location loc = reductionOp.getLoc();
490  Value accumulator = nullptr;
491  for (SmallVector<int64_t> offsets :
492  StaticTileOffsetRange(originalSize, *targetShape)) {
493  SmallVector<int64_t> strides(offsets.size(), 1);
494  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
495  loc, reductionOp.getVector(), offsets, *targetShape, strides);
497  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
498  Value result = newOp->getResult(0);
499 
500  if (!accumulator) {
501  // This is the first reduction.
502  accumulator = result;
503  } else {
504  // On subsequent reduction, combine with the accumulator.
505  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
506  accumulator, result);
507  }
508  }
509 
510  rewriter.replaceOp(reductionOp, accumulator);
511  return success();
512  }
513 
514 private:
516 };
517 
518 struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
519  UnrollTransposePattern(MLIRContext *context,
521  PatternBenefit benefit = 1)
522  : OpRewritePattern<vector::TransposeOp>(context, benefit),
523  options(options) {}
524 
525  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
526  PatternRewriter &rewriter) const override {
527  if (transposeOp.getResultVectorType().getRank() == 0)
528  return failure();
529  auto targetShape = getTargetShape(options, transposeOp);
530  if (!targetShape)
531  return failure();
532  auto originalVectorType = transposeOp.getResultVectorType();
533  SmallVector<int64_t> strides(targetShape->size(), 1);
534  Location loc = transposeOp.getLoc();
535  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
536 
537  // Prepare the result vector;
538  Value result = rewriter.create<arith::ConstantOp>(
539  loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
540  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
541 
542  // Unroll the computation.
543  for (SmallVector<int64_t> elementOffsets :
544  StaticTileOffsetRange(originalSize, *targetShape)) {
545  SmallVector<int64_t> permutedOffsets(elementOffsets.size());
546  SmallVector<int64_t> permutedShape(elementOffsets.size());
547  // Compute the source offsets and shape.
548  for (auto indices : llvm::enumerate(permutation)) {
549  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
550  permutedShape[indices.value()] = (*targetShape)[indices.index()];
551  }
552  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
553  loc, transposeOp.getVector(), permutedOffsets, permutedShape,
554  strides);
555  Value transposedSlice =
556  rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
557  result = rewriter.create<vector::InsertStridedSliceOp>(
558  loc, transposedSlice, result, elementOffsets, strides);
559  }
560  rewriter.replaceOp(transposeOp, result);
561  return success();
562  }
563 
564 private:
566 };
567 
568 struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
569  UnrollGatherPattern(MLIRContext *context,
571  PatternBenefit benefit = 1)
572  : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
573  }
574 
575  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
576  PatternRewriter &rewriter) const override {
577  VectorType sourceVectorType = gatherOp.getVectorType();
578  if (sourceVectorType.getRank() == 0)
579  return failure();
580  auto targetShape = getTargetShape(options, gatherOp);
581  if (!targetShape)
582  return failure();
583  SmallVector<int64_t> strides(targetShape->size(), 1);
584  Location loc = gatherOp.getLoc();
585  ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
586 
587  // Prepare the result vector;
588  Value result = rewriter.create<arith::ConstantOp>(
589  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
590  auto targetType =
591  VectorType::get(*targetShape, sourceVectorType.getElementType());
592 
593  SmallVector<int64_t> loopOrder =
594  getUnrollOrder(originalSize.size(), gatherOp, options);
595  for (SmallVector<int64_t> elementOffsets :
596  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
597  // To get the unrolled gather, extract the same slice based on the
598  // decomposed shape from each of the index, mask, and pass-through
599  // vectors.
600  Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
601  loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
602  Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
603  loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
604  Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
605  loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
606  auto slicedGather = rewriter.create<vector::GatherOp>(
607  loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
608  indexSubVec, maskSubVec, passThruSubVec);
609 
610  result = rewriter.create<vector::InsertStridedSliceOp>(
611  loc, slicedGather, result, elementOffsets, strides);
612  }
613  rewriter.replaceOp(gatherOp, result);
614  return success();
615  }
616 
617 private:
619 };
620 
621 } // namespace
622 
625  PatternBenefit benefit) {
626  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
627  UnrollContractionPattern, UnrollElementwisePattern,
628  UnrollReductionPattern, UnrollMultiReductionPattern,
629  UnrollTransposePattern, UnrollGatherPattern>(
630  patterns.getContext(), options, benefit);
631 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
#define DBGS()
#define LDBG(X)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:387
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:263
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1344
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:648
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:608
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:584
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Options that control the vector unrolling.