//===- VectorUnroll.cpp - patterns to do vector unrolling ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices of the slice at `elementOffsets` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices);
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

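// Example for sliceTransferIndices: with indices (%i, %j), permutation map
// (d0, d1) -> (d1, d0), and elementOffsets [2, 4], result 0 (d1) adds 2 to
// index 1 and result 1 (d0) adds 4 to index 0, so the sliced transfer uses
// (%i + 4, %j + 2). Broadcast dimensions (constant-0 map results) keep their
// original index.
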
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("");
  LDBG("Get unroll shape for op " << op->getName().getStringRef());
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--did not pass the filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << " -> SKIP");
    return std::nullopt;
  }
  LDBG("--target shape: " << llvm::interleaved(*targetShape));

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}

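// For example, an op operating on vector<8x8xf32> with a native shape of
// [4, 4] yields an integral shape ratio of [2, 2]: the op is unrolled into
// four 4x4 slices.
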
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

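// The default traversal order is the identity [0, 1, ..., numLoops-1]; a
// user-provided traversalOrderCallback may return a permutation, e.g. [1, 0]
// to iterate the second tile dimension outermost.
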
namespace {

struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

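// For example, a transfer_read producing vector<4x4xf32> with native shape
// [2, 2] becomes four 2x2 transfer_reads at offset indices, composed back with
// vector.insert_strided_slice into a single vector<4x4xf32>.
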
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

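// DenseMapInfo-style hooks that let offset vectors (SmallVector<int64_t>) be
// used as keys of the accumulator caches below: reserved empty/tombstone keys,
// hashing, and equality.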
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled, since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

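// For example, an 8x8x8 matmul-style contraction with native shape [4, 4, 4]
// is unrolled into eight 4x4x4 contractions; slices that share the same
// destination offsets are chained through accCache along the reduction
// dimension.
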
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all the tile offsets, in multiples of targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

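// For example, a multi_reduction over vector<4x8xf32> reducing dim 1, with
// native shape [2, 4], produces four 2x4 partial reductions; tiles that map to
// the same destination offsets are accumulated through accCache.
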
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    // Bail-out if rank(source) != rank(target). The main limitation here is the
    // fact that `ExtractStridedSlice` requires the rank for the input and
    // output to match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          op, "expected input vector rank to match target shape rank");
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

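// For example, an arith.addf on two vector<8xf32> operands with native shape
// [4] becomes two arith.addf ops on vector<4xf32> slices, reassembled with
// vector.insert_strided_slice.
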
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

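// For example, a vector.reduction <add> over vector<8xf32> with native shape
// [4] becomes two partial vector.reduction ops whose scalar results are
// combined with the matching arith op (arith.addf here).
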
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

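// For example, a vector.transpose [1, 0] from vector<4x8xf32> to
// vector<8x4xf32> with native shape [4, 4] becomes two
// extract/transpose/insert sequences on 4x4 tiles.
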
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}
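
// Usage sketch (illustrative, not part of this file): a pass that unrolls
// contractions to a 4x4x4 native tile might look roughly like this, where
// `ctx` and `funcOp` are placeholders for the pass's context and anchor op.
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions options;
//   options.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         if (isa<vector::ContractionOp>(op))
//           return SmallVector<int64_t>{4, 4, 4};
//         return std::nullopt;
//       });
//   vector::populateVectorUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));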