// Source: mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (MLIR 22.0.0git),
// recovered from a Doxygen listing; stale listing line numbers removed.
1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/DebugLog.h"
21 #include "llvm/Support/InterleavedRange.h"
22 #include <optional>
23 
24 #define DEBUG_TYPE "vector-unroll"
25 
26 using namespace mlir;
27 using namespace mlir::vector;
28 
29 /// Compute the indices of the slice `index` for a transfer op.
31  ArrayRef<Value> indices,
32  AffineMap permutationMap,
33  Location loc,
34  OpBuilder &builder) {
35  MLIRContext *ctx = builder.getContext();
36  auto isBroadcast = [](AffineExpr expr) {
37  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
38  return constExpr.getValue() == 0;
39  return false;
40  };
41  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
42  SmallVector<Value> slicedIndices(indices);
43  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
44  if (isBroadcast(dim.value()))
45  continue;
46  unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
47  auto expr = getAffineDimExpr(0, builder.getContext()) +
48  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
49  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
50  slicedIndices[pos] =
51  affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
52  }
53  return slicedIndices;
54 }
55 
56 // Compute the new indices by adding `offsets` to `originalIndices`.
57 // If m < n (m = offsets.size(), n = originalIndices.size()),
58 // then only the trailing m values in `originalIndices` are updated.
60  Location loc,
61  OperandRange originalIndices,
62  ArrayRef<int64_t> offsets) {
63  assert(offsets.size() <= originalIndices.size() &&
64  "Offsets should not exceed the number of original indices");
65  SmallVector<Value> indices(originalIndices);
66 
67  auto start = indices.size() - offsets.size();
68  for (auto [i, offset] : llvm::enumerate(offsets)) {
69  if (offset != 0) {
70  indices[start + i] = arith::AddIOp::create(
71  rewriter, loc, originalIndices[start + i],
72  arith::ConstantIndexOp::create(rewriter, loc, offset));
73  }
74  }
75  return indices;
76 }
77 
78 // Clones `op` into a new operations that takes `operands` and returns
79 // `resultTypes`.
81  Operation *op,
82  ArrayRef<Value> operands,
83  ArrayRef<Type> resultTypes) {
84  return builder.create(loc, op->getName().getIdentifier(), operands,
85  resultTypes, op->getAttrs());
86 }
87 
88 /// Return the target shape for unrolling for the given `op`. Return
89 /// std::nullopt if the op shouldn't be or cannot be unrolled.
90 static std::optional<SmallVector<int64_t>>
92  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
93  if (options.filterConstraint && failed(options.filterConstraint(op))) {
94  LDBG() << "--no filter constraint -> BAIL";
95  return std::nullopt;
96  }
97  assert(options.nativeShape &&
98  "vector unrolling expects the native shape or native"
99  "shape call back function to be set");
100  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
101  if (!unrollableVectorOp) {
102  LDBG() << "--not an unrollable op -> BAIL";
103  return std::nullopt;
104  }
105  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
106  if (!maybeUnrollShape) {
107  LDBG() << "--could not get shape of op " << *op << " -> BAIL";
108  return std::nullopt;
109  }
110  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
111 
112  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
113  if (!targetShape) {
114  LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
115  return std::nullopt;
116  }
117  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
118 
119  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
120  if (!maybeShapeRatio) {
121  LDBG() << "--could not compute integral shape ratio -> BAIL";
122  return std::nullopt;
123  }
124  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
125  LDBG() << "--no unrolling needed -> SKIP";
126  return std::nullopt;
127  }
128  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
129  return targetShape;
130 }
131 
133 getUnrollOrder(unsigned numLoops, Operation *op,
135  SmallVector<int64_t> loopOrder =
136  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
137  if (options.traversalOrderCallback != nullptr) {
138  std::optional<SmallVector<int64_t>> order =
139  options.traversalOrderCallback(op);
140  if (order) {
141  loopOrder = std::move(*order);
142  }
143  }
144  return loopOrder;
145 }
146 
147 namespace {
148 
149 struct UnrollTransferReadPattern
150  : public OpRewritePattern<vector::TransferReadOp> {
151  UnrollTransferReadPattern(MLIRContext *context,
153  PatternBenefit benefit = 1)
154  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
155  options(options) {}
156 
157  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
158  PatternRewriter &rewriter) const override {
159  // TODO: support 0-d corner case.
160  if (readOp.getTransferRank() == 0)
161  return failure();
162  if (readOp.getMask())
163  return failure();
164  auto targetShape = getTargetShape(options, readOp);
165  if (!targetShape)
166  return failure();
167  auto sourceVectorType = readOp.getVectorType();
168  SmallVector<int64_t> strides(targetShape->size(), 1);
169  Location loc = readOp.getLoc();
170  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
171 
172  // Prepare the result vector;
173  Value result =
174  arith::ConstantOp::create(rewriter, loc, sourceVectorType,
175  rewriter.getZeroAttr(sourceVectorType));
176  auto targetType =
177  VectorType::get(*targetShape, sourceVectorType.getElementType());
178  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
179  readOp.getIndices().end());
180  SmallVector<int64_t> loopOrder =
181  getUnrollOrder(originalSize.size(), readOp, options);
182  for (SmallVector<int64_t> elementOffsets :
183  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
184  SmallVector<Value> indices =
185  sliceTransferIndices(elementOffsets, originalIndices,
186  readOp.getPermutationMap(), loc, rewriter);
187  auto slicedRead = vector::TransferReadOp::create(
188  rewriter, loc, targetType, readOp.getBase(), indices,
189  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
190  readOp.getInBoundsAttr());
191 
192  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
193  loc, slicedRead, result, elementOffsets, strides);
194  }
195  rewriter.replaceOp(readOp, result);
196  return success();
197  }
198 
199 private:
201 };
202 
203 struct UnrollTransferWritePattern
204  : public OpRewritePattern<vector::TransferWriteOp> {
205  UnrollTransferWritePattern(MLIRContext *context,
207  PatternBenefit benefit = 1)
208  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
209  options(options) {}
210 
211  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
212  PatternRewriter &rewriter) const override {
213  // TODO: support 0-d corner case.
214  if (writeOp.getTransferRank() == 0)
215  return failure();
216 
217  if (writeOp.getMask())
218  return failure();
219  auto targetShape = getTargetShape(options, writeOp);
220  if (!targetShape)
221  return failure();
222  auto sourceVectorType = writeOp.getVectorType();
223  SmallVector<int64_t> strides(targetShape->size(), 1);
224  Location loc = writeOp.getLoc();
225  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
226  // Bail-out if rank(source) != rank(target). The main limitation here is the
227  // fact that `ExtractStridedSlice` requires the rank for the input and
228  // output to match. If needed, we can relax this later.
229  if (originalSize.size() != targetShape->size())
230  return rewriter.notifyMatchFailure(
231  writeOp,
232  "expected source input vector rank to match target shape rank");
233 
234  SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
235  writeOp.getIndices().end());
236  SmallVector<int64_t> loopOrder =
237  getUnrollOrder(originalSize.size(), writeOp, options);
238  Value resultTensor;
239  for (SmallVector<int64_t> elementOffsets :
240  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
241  Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
242  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
243  SmallVector<Value> indices =
244  sliceTransferIndices(elementOffsets, originalIndices,
245  writeOp.getPermutationMap(), loc, rewriter);
246  Operation *slicedWrite = vector::TransferWriteOp::create(
247  rewriter, loc, slicedVector,
248  resultTensor ? resultTensor : writeOp.getBase(), indices,
249  writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
250  // For the tensor case update the destination for the next transfer write.
251  if (!slicedWrite->getResults().empty())
252  resultTensor = slicedWrite->getResult(0);
253  }
254  if (resultTensor)
255  rewriter.replaceOp(writeOp, resultTensor);
256  else
257  rewriter.eraseOp(writeOp);
258  return success();
259  }
260 
261 private:
263 };
264 
265 struct OffsetMapInfo {
266  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
267 
268  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
269 
270  static unsigned getHashValue(const SmallVector<int64_t> &v) {
271  return static_cast<unsigned>(llvm::hash_combine_range(v));
272  }
273 
274  static bool isEqual(const SmallVector<int64_t> &lhs,
275  const SmallVector<int64_t> &rhs) {
276  return lhs == rhs;
277  }
278 };
279 
280 struct UnrollContractionPattern
281  : public OpRewritePattern<vector::ContractionOp> {
282  UnrollContractionPattern(MLIRContext *context,
284  PatternBenefit benefit = 1)
285  : OpRewritePattern<vector::ContractionOp>(context, benefit),
286  options(options) {}
287 
288  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289  PatternRewriter &rewriter) const override {
290  auto targetShape = getTargetShape(options, contractOp);
291  if (!targetShape)
292  return failure();
293  auto dstVecType = cast<VectorType>(contractOp.getResultType());
294  SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
295 
296  Location loc = contractOp.getLoc();
297  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
298  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
299  llvm::MapVector<
301  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
302  accCache;
303 
305  contractOp.getIteratorTypes().size(), contractOp, options);
306 
307  for (SmallVector<int64_t> offsets :
308  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
309  SmallVector<Value> slicesOperands(contractOp.getNumOperands());
310 
311  // Helper to compute the new shape of each operand and extract the slice.
312  auto extractOperand = [&](unsigned index, Value operand,
313  AffineMap permutationMap,
314  ArrayRef<int64_t> operandOffets) {
316  permutationMap, ArrayRef<int64_t>(*targetShape));
317  SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
318  slicesOperands[index] =
319  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
320  loc, operand, operandOffets, operandShape, operandStrides);
321  };
322 
323  // Extract the new lhs operand.
324  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
325  SmallVector<int64_t> lhsOffets =
326  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
327  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
328 
329  // Extract the new rhs operand.
330  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
331  SmallVector<int64_t> rhsOffets =
332  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
333  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
334 
335  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
336  SmallVector<int64_t> accOffets =
337  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
338  // If a version of the accumulator has already been computed, use it
339  // otherwise extract the first version from the original operand.
340  auto *accIt = accCache.find(accOffets);
341  if (accIt != accCache.end())
342  slicesOperands[2] = accIt->second;
343  else
344  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
345 
346  SmallVector<int64_t> dstShape =
347  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
348  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
350  rewriter, loc, contractOp, slicesOperands, targetType);
351 
352  SmallVector<int64_t> dstOffets =
353  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
354  // Save the accumulated value untill all the loops are unrolled since
355  // reduction loop keep updating the accumulator.
356  accCache[dstOffets] = newOp->getResult(0);
357  }
358  // Assemble back the accumulator into a single vector.
359  Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
360  rewriter.getZeroAttr(dstVecType));
361  for (const auto &it : accCache) {
362  SmallVector<int64_t> dstStrides(it.first.size(), 1);
363  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
364  loc, it.second, result, it.first, dstStrides);
365  }
366  rewriter.replaceOp(contractOp, result);
367  return success();
368  }
369 
370 private:
372 };
373 
374 struct UnrollMultiReductionPattern
375  : public OpRewritePattern<vector::MultiDimReductionOp> {
376  UnrollMultiReductionPattern(MLIRContext *context,
378  PatternBenefit benefit = 1)
379  : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
380  options(options) {}
381 
382  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
383  PatternRewriter &rewriter) const override {
384  auto resultType = reductionOp->getResult(0).getType();
385  if (resultType.isIntOrFloat()) {
386  return rewriter.notifyMatchFailure(reductionOp,
387  "Unrolling scalars is not supported");
388  }
389  std::optional<SmallVector<int64_t>> targetShape =
390  getTargetShape(options, reductionOp);
391  if (!targetShape)
392  return failure();
393  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
394  llvm::MapVector<
396  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
397  accCache;
398  Location loc = reductionOp.getLoc();
399 
400  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
401  // of multiples of the targetShape.
402  for (SmallVector<int64_t> offsets :
403  StaticTileOffsetRange(originalSize, *targetShape)) {
404  SmallVector<Value> operands;
405  SmallVector<int64_t> operandStrides(offsets.size(), 1);
406  Value slicedOperand =
407  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
408  loc, reductionOp.getSource(), offsets, *targetShape,
409  operandStrides);
410  operands.push_back(slicedOperand);
411  SmallVector<int64_t> dstShape;
412  SmallVector<int64_t> destOffset;
413  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
414  if (!reductionOp.isReducedDim(i)) {
415  destOffset.push_back(offsets[i]);
416  dstShape.push_back((*targetShape)[i]);
417  }
418  }
419  Value acc;
420  SmallVector<int64_t> accStrides(destOffset.size(), 1);
421  // If a version of the accumulator has already been computed, use it
422  // otherwise extract the first version from the original operand.
423  auto *accIt = accCache.find(destOffset);
424  if (accIt != accCache.end())
425  acc = accIt->second;
426  else
427  acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
428  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
429  operands.push_back(acc);
430  auto targetType = VectorType::get(
431  dstShape, reductionOp.getSourceVectorType().getElementType());
432  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
433  operands, targetType);
434  Value result = newOp->getResult(0);
435  accCache[destOffset] = result;
436  }
437  // Assemble back the accumulator into a single vector.
438  Value result = arith::ConstantOp::create(
439  rewriter, loc, reductionOp.getDestType(),
440  rewriter.getZeroAttr(reductionOp.getDestType()));
441  for (const auto &it : accCache) {
442  SmallVector<int64_t> dstStrides(it.first.size(), 1);
443  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
444  loc, it.second, result, it.first, dstStrides);
445  }
446  rewriter.replaceOp(reductionOp, result);
447  return success();
448  }
449 
450 private:
452 };
453 
454 struct UnrollElementwisePattern : public RewritePattern {
455  UnrollElementwisePattern(MLIRContext *context,
457  PatternBenefit benefit = 1)
458  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
459  options(options) {}
460 
461  LogicalResult matchAndRewrite(Operation *op,
462  PatternRewriter &rewriter) const override {
464  return failure();
465  auto targetShape = getTargetShape(options, op);
466  if (!targetShape)
467  return failure();
468  int64_t targetShapeRank = targetShape->size();
469  auto dstVecType = cast<VectorType>(op->getResult(0).getType());
470  SmallVector<int64_t> originalSize =
471  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
472  int64_t originalShapeRank = originalSize.size();
473 
474  Location loc = op->getLoc();
475 
476  // Handle rank mismatch by adding leading unit dimensions to targetShape
477  SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
478  int64_t rankDiff = originalShapeRank - targetShapeRank;
479  std::fill(adjustedTargetShape.begin(),
480  adjustedTargetShape.begin() + rankDiff, 1);
481  std::copy(targetShape->begin(), targetShape->end(),
482  adjustedTargetShape.begin() + rankDiff);
483 
484  int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
485  // Prepare the result vector.
486  Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
487  rewriter.getZeroAttr(dstVecType));
488  SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
489  VectorType unrolledVecType =
490  VectorType::get(*targetShape, dstVecType.getElementType());
491 
492  // Create the unrolled computation.
493  for (SmallVector<int64_t> offsets :
494  StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
495  SmallVector<Value> extractOperands;
496  for (OpOperand &operand : op->getOpOperands()) {
497  auto vecType = dyn_cast<VectorType>(operand.get().getType());
498  if (!vecType) {
499  extractOperands.push_back(operand.get());
500  continue;
501  }
502  Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
503  loc, operand.get(), offsets, adjustedTargetShape, strides);
504 
505  // Reshape to remove leading unit dims if needed
506  if (adjustedTargetShapeRank > targetShapeRank) {
507  extracted = rewriter.createOrFold<vector::ShapeCastOp>(
508  loc, VectorType::get(*targetShape, vecType.getElementType()),
509  extracted);
510  }
511  extractOperands.push_back(extracted);
512  }
513 
515  rewriter, loc, op, extractOperands, unrolledVecType);
516 
517  Value computeResult = newOp->getResult(0);
518 
519  // Use strides sized to targetShape for proper insertion
520  SmallVector<int64_t> insertStrides =
521  (adjustedTargetShapeRank > targetShapeRank)
522  ? SmallVector<int64_t>(targetShapeRank, 1)
523  : strides;
524 
525  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
526  loc, computeResult, result, offsets, insertStrides);
527  }
528  rewriter.replaceOp(op, result);
529  return success();
530  }
531 
532 private:
534 };
535 
536 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
537  UnrollReductionPattern(MLIRContext *context,
539  PatternBenefit benefit = 1)
540  : OpRewritePattern<vector::ReductionOp>(context, benefit),
541  options(options) {}
542 
543  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
544  PatternRewriter &rewriter) const override {
545  std::optional<SmallVector<int64_t>> targetShape =
546  getTargetShape(options, reductionOp);
547  if (!targetShape)
548  return failure();
549  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
550 
551  // Create unrolled vector reduction.
552  Location loc = reductionOp.getLoc();
553  Value accumulator = nullptr;
554  for (SmallVector<int64_t> offsets :
555  StaticTileOffsetRange(originalSize, *targetShape)) {
556  SmallVector<int64_t> strides(offsets.size(), 1);
557  Value slicedOperand =
558  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
559  loc, reductionOp.getVector(), offsets, *targetShape, strides);
561  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
562  Value result = newOp->getResult(0);
563 
564  if (!accumulator) {
565  // This is the first reduction.
566  accumulator = result;
567  } else {
568  // On subsequent reduction, combine with the accumulator.
569  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
570  accumulator, result);
571  }
572  }
573 
574  rewriter.replaceOp(reductionOp, accumulator);
575  return success();
576  }
577 
578 private:
580 };
581 
582 struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
583  UnrollTransposePattern(MLIRContext *context,
585  PatternBenefit benefit = 1)
586  : OpRewritePattern<vector::TransposeOp>(context, benefit),
587  options(options) {}
588 
589  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
590  PatternRewriter &rewriter) const override {
591  if (transposeOp.getResultVectorType().getRank() == 0)
592  return failure();
593  auto targetShape = getTargetShape(options, transposeOp);
594  if (!targetShape)
595  return failure();
596  auto originalVectorType = transposeOp.getResultVectorType();
597  SmallVector<int64_t> strides(targetShape->size(), 1);
598  Location loc = transposeOp.getLoc();
599  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
600 
601  // Prepare the result vector;
602  Value result =
603  arith::ConstantOp::create(rewriter, loc, originalVectorType,
604  rewriter.getZeroAttr(originalVectorType));
605  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
606 
607  // Unroll the computation.
608  for (SmallVector<int64_t> elementOffsets :
609  StaticTileOffsetRange(originalSize, *targetShape)) {
610  SmallVector<int64_t> permutedOffsets(elementOffsets.size());
611  SmallVector<int64_t> permutedShape(elementOffsets.size());
612  // Compute the source offsets and shape.
613  for (auto indices : llvm::enumerate(permutation)) {
614  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
615  permutedShape[indices.value()] = (*targetShape)[indices.index()];
616  }
617  Value slicedOperand =
618  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
619  loc, transposeOp.getVector(), permutedOffsets, permutedShape,
620  strides);
621  Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
622  loc, slicedOperand, permutation);
623  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
624  loc, transposedSlice, result, elementOffsets, strides);
625  }
626  rewriter.replaceOp(transposeOp, result);
627  return success();
628  }
629 
630 private:
632 };
633 
634 struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
635  UnrollGatherPattern(MLIRContext *context,
637  PatternBenefit benefit = 1)
638  : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
639  }
640 
641  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
642  PatternRewriter &rewriter) const override {
643  VectorType sourceVectorType = gatherOp.getVectorType();
644  if (sourceVectorType.getRank() == 0)
645  return failure();
646  auto targetShape = getTargetShape(options, gatherOp);
647  if (!targetShape)
648  return failure();
649  SmallVector<int64_t> strides(targetShape->size(), 1);
650  Location loc = gatherOp.getLoc();
651  ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
652 
653  // Prepare the result vector;
654  Value result =
655  arith::ConstantOp::create(rewriter, loc, sourceVectorType,
656  rewriter.getZeroAttr(sourceVectorType));
657  auto targetType =
658  VectorType::get(*targetShape, sourceVectorType.getElementType());
659 
660  SmallVector<int64_t> loopOrder =
661  getUnrollOrder(originalSize.size(), gatherOp, options);
662  for (SmallVector<int64_t> elementOffsets :
663  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
664  // To get the unrolled gather, extract the same slice based on the
665  // decomposed shape from each of the index, mask, and pass-through
666  // vectors.
667  Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
668  loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
669  Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
670  loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
671  Value passThruSubVec =
672  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
673  loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
674  strides);
675  auto slicedGather = vector::GatherOp::create(
676  rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
677  indexSubVec, maskSubVec, passThruSubVec);
678 
679  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
680  loc, slicedGather, result, elementOffsets, strides);
681  }
682  rewriter.replaceOp(gatherOp, result);
683  return success();
684  }
685 
686 private:
688 };
689 
690 struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
691  UnrollLoadPattern(MLIRContext *context,
693  PatternBenefit benefit = 1)
694  : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
695 
696  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
697  PatternRewriter &rewriter) const override {
698  VectorType vecType = loadOp.getVectorType();
699 
700  auto targetShape = getTargetShape(options, loadOp);
701  if (!targetShape)
702  return failure();
703 
704  Location loc = loadOp.getLoc();
705  ArrayRef<int64_t> originalShape = vecType.getShape();
706  SmallVector<int64_t> strides(targetShape->size(), 1);
707 
708  Value result = arith::ConstantOp::create(rewriter, loc, vecType,
709  rewriter.getZeroAttr(vecType));
710 
711  SmallVector<int64_t> loopOrder =
712  getUnrollOrder(originalShape.size(), loadOp, options);
713 
714  auto targetVecType =
715  VectorType::get(*targetShape, vecType.getElementType());
716 
717  for (SmallVector<int64_t> offsets :
718  StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
719  SmallVector<Value> indices =
720  sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
721  Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
722  loadOp.getBase(), indices);
723  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
724  loc, slicedLoad, result, offsets, strides);
725  }
726  rewriter.replaceOp(loadOp, result);
727  return success();
728  }
729 
730 private:
732 };
733 
734 struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
735  UnrollStorePattern(MLIRContext *context,
737  PatternBenefit benefit = 1)
738  : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
739 
740  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
741  PatternRewriter &rewriter) const override {
742  VectorType vecType = storeOp.getVectorType();
743 
744  auto targetShape = getTargetShape(options, storeOp);
745  if (!targetShape)
746  return failure();
747 
748  Location loc = storeOp.getLoc();
749  ArrayRef<int64_t> originalShape = vecType.getShape();
750  SmallVector<int64_t> strides(targetShape->size(), 1);
751 
752  Value base = storeOp.getBase();
753  Value vector = storeOp.getValueToStore();
754 
755  SmallVector<int64_t> loopOrder =
756  getUnrollOrder(originalShape.size(), storeOp, options);
757 
758  for (SmallVector<int64_t> offsets :
759  StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
760  SmallVector<Value> indices =
761  sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
762  Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
763  loc, vector, offsets, *targetShape, strides);
764  vector::StoreOp::create(rewriter, loc, slice, base, indices);
765  }
766  rewriter.eraseOp(storeOp);
767  return success();
768  }
769 
770 private:
772 };
773 
774 struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
775  UnrollBroadcastPattern(MLIRContext *context,
777  PatternBenefit benefit = 1)
778  : OpRewritePattern<vector::BroadcastOp>(context, benefit),
779  options(options) {}
780 
781  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
782  PatternRewriter &rewriter) const override {
783  auto targetShape = getTargetShape(options, broadcastOp);
784  if (!targetShape)
785  return failure();
786 
787  Location loc = broadcastOp.getLoc();
788  VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
789  VectorType resType = broadcastOp.getResultVectorType();
790  VectorType targetType =
791  resType.cloneWith(*targetShape, resType.getElementType());
792  Value result = arith::ConstantOp::create(rewriter, loc, resType,
793  rewriter.getZeroAttr(resType));
794 
795  SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
796  SmallVector<int64_t> strides(originalShape.size(), 1);
797 
798  for (SmallVector<int64_t> offsets :
799  StaticTileOffsetRange(originalShape, *targetShape)) {
800  Value newSrc;
801  if (!srcType) {
802  // Scalar to vector broadcast.
803  newSrc = broadcastOp.getSource();
804  } else {
805  // Vector to vector broadcast.
806  int64_t rank = srcType.getRank();
807  SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
808  SmallVector<int64_t> srcShape(targetShape->end() - rank,
809  targetShape->end());
810  SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
811  // adjust the offset and shape for src if the corresponding dim is 1.
812  for (int64_t i = 0; i < rank; ++i) {
813  if (srcType.getDimSize(i) == 1) {
814  srcOffsets[i] = 0;
815  srcShape[i] = 1;
816  }
817  }
818  newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
819  loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
820  }
821 
822  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
823  newSrc, targetType);
824 
825  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
826  loc, newOp->getResult(0), result, offsets, strides);
827  }
828 
829  rewriter.replaceOp(broadcastOp, result);
830  return success();
831  }
832 
833 private:
835 };
836 
837 /// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
838 /// outermost dimension of the operand. For example:
839 ///
840 /// ```
841 /// %0:4 = vector.to_elements %v : vector<2x2xf32>
842 ///
843 /// ==>
844 ///
845 /// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
846 /// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
847 /// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
848 /// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
849 /// ```
850 ///
851 /// When this pattern is applied until a fixed-point is reached,
852 /// this will produce a sequence of 1-d from_elements
853 /// ops.
854 struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
855  UnrollToElements(MLIRContext *context,
857  PatternBenefit benefit = 1)
858  : OpRewritePattern<vector::ToElementsOp>(context, benefit),
859  options(options) {}
860 
861  LogicalResult matchAndRewrite(vector::ToElementsOp op,
862  PatternRewriter &rewriter) const override {
863 
864  TypedValue<VectorType> source = op.getSource();
865  FailureOr<SmallVector<Value>> result =
866  vector::unrollVectorValue(source, rewriter);
867  if (failed(result)) {
868  return failure();
869  }
870  SmallVector<Value> vectors = *result;
871 
872  SmallVector<Value> results;
873  for (Value vector : vectors) {
874  auto subElements =
875  vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
876  llvm::append_range(results, subElements.getResults());
877  }
878  rewriter.replaceOp(op, results);
879  return success();
880  }
881 
882 private:
884 };
885 
886 /// This pattern unrolls `vector.step` operations according to the provided
887 /// target unroll shape. It decomposes a large step vector into smaller step
888 /// vectors (segments) and assembles the result by inserting each computed
889 /// segment into the appropriate offset of the original vector.
890 ///
891 /// The pattern does not support scalable vectors and will fail to match them.
892 ///
893 /// For each segment, it adds the base step vector and the segment's offset,
894 /// then inserts the result into the output vector at the corresponding
895 /// position.
896 ///
897 /// Example:
898 /// Given a step operation:
899 /// %0 = vector.step : vector<8xindex>
900 ///
901 /// and a target unroll shape of <4>, the pattern produces:
902 ///
903 /// %base = vector.step : vector<4xindex>
904 /// %zero = arith.constant dense<0> : vector<8xindex>
905 /// %result0 = vector.insert_strided_slice %base, %zero
906 /// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
907 /// %offset = arith.constant dense<4> : vector<4xindex>
908 /// %segment1 = arith.addi %base, %offset : vector<4xindex>
909 /// %result1 = vector.insert_strided_slice %segment1, %result0
910 /// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
911 ///
912 struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
913  UnrollStepPattern(MLIRContext *context,
915  PatternBenefit benefit = 1)
916  : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
917 
918  LogicalResult matchAndRewrite(vector::StepOp stepOp,
919  PatternRewriter &rewriter) const override {
920  std::optional<SmallVector<int64_t>> targetShape =
921  getTargetShape(options, stepOp);
922  if (!targetShape)
923  return failure();
924 
925  VectorType vecType = stepOp.getType();
926  if (vecType.isScalable()) {
927  // Scalable vectors are not supported by this pattern.
928  return failure();
929  }
930  int64_t originalSize = vecType.getShape()[0];
931  Location loc = stepOp.getLoc();
932  SmallVector<int64_t> strides(1, 1);
933 
934  Value result = arith::ConstantOp::create(rewriter, loc, vecType,
935  rewriter.getZeroAttr(vecType));
936 
937  auto targetVecType =
938  VectorType::get(*targetShape, vecType.getElementType());
939  Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
940  for (const SmallVector<int64_t> &offsets :
941  StaticTileOffsetRange({originalSize}, *targetShape)) {
942  Value bcastOffset = arith::ConstantOp::create(
943  rewriter, loc, targetVecType,
945  targetVecType,
946  IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
947  Value tileStep =
948  arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
949 
950  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
951  loc, tileStep, result, offsets, strides);
952  }
953  rewriter.replaceOp(stepOp, result);
954  return success();
955  }
956 
957 private:
959 };
960 
961 /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
962 /// outermost dimension. For example:
963 /// ```
964 /// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
965 ///
966 /// ==>
967 ///
968 /// %0 = ub.poison : vector<2x3xf32>
969 /// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
970 /// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
971 /// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
972 /// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
973 /// ```
974 ///
975 /// When this pattern is applied until a fixed-point is reached,
976 /// this will produce a sequence of 1-d from_elements
977 /// ops.
978 struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
979  UnrollFromElements(MLIRContext *context,
981  PatternBenefit benefit = 1)
982  : OpRewritePattern<vector::FromElementsOp>(context, benefit),
983  options(options) {}
984 
985  LogicalResult matchAndRewrite(vector::FromElementsOp op,
986  PatternRewriter &rewriter) const override {
987  ValueRange allElements = op.getElements();
988 
989  auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
990  VectorType subTy, int64_t index) {
991  size_t subTyNumElements = subTy.getNumElements();
992  assert((index + 1) * subTyNumElements <= allElements.size() &&
993  "out of bounds");
994  ValueRange subElements =
995  allElements.slice(index * subTyNumElements, subTyNumElements);
996  return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
997  };
998 
999  return unrollVectorOp(op, rewriter, unrollFromElementsFn);
1000  }
1001 
1002 private:
1004 };
1005 
1006 } // namespace
1007 
1008 void mlir::vector::populateVectorUnrollPatterns(
1010  PatternBenefit benefit) {
1011  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1012  UnrollContractionPattern, UnrollElementwisePattern,
1013  UnrollReductionPattern, UnrollMultiReductionPattern,
1014  UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1015  UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016  UnrollToElements, UnrollStepPattern>(patterns.getContext(),
1017  options, benefit);
1018 }
1019 
1020 void mlir::vector::populateVectorToElementsUnrollPatterns(
1022  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
1023  benefit);
1024 }
1025 
1026 void mlir::vector::populateVectorFromElementsUnrollPatterns(
1028  patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
1029  benefit);
1030 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice `index` for a transfer op.
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
MLIRContext * getContext() const
Definition: Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:526
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
This class represents an operand of an operation.
Definition: Value.h:257
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1395
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Options that control the vector unrolling.