//===- VectorUnroll.cpp - patterns to do vector unrolling ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices of the slice at `elementOffsets` for a transfer op.
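/// For example (illustrative), with indices (%i, %j), permutation map
/// (d0, d1) -> (d0, d1), and elementOffsets = [2, 4], this produces roughly:
///
///   %i1 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i)
///   %j1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%j)
///
/// Broadcast dimensions (constant-0 results in the map) keep their original
/// index.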
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices);
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Compute the new indices by adding `offsets` to `originalIndices`.
// If m < n (m = offsets.size(), n = originalIndices.size()),
// then only the trailing m values in `originalIndices` are updated.
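// For example (illustrative), with originalIndices = (%i, %j, %k) and
// offsets = [0, 8], only the trailing two indices are considered and only
// the non-zero offset creates an op:
//
//   %c8 = arith.constant 8 : index
//   %k1 = arith.addi %k, %c8 : index
//
// yielding the new indices (%i, %j, %k1).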
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
                                                Location loc,
                                                OperandRange originalIndices,
                                                ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices);

  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset != 0) {
      indices[start + i] = arith::AddIOp::create(
          rewriter, loc, originalIndices[start + i],
          arith::ConstantIndexOp::create(rewriter, loc, offset));
    }
  }
  return indices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
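/// For example, an op with unroll shape [4, 8] and a native shape of [2, 4]
/// yields the integral ratio [2, 2] and is unrolled into four [2, 4] slices;
/// a native shape equal to [4, 8] yields the ratio [1, 1] and the op is
/// skipped because no unrolling is needed.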
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--filter constraint failed -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

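/// Unrolls `vector.transfer_read` ops (without masks) according to the
/// target unroll shape: each tile is read by its own `vector.transfer_read`
/// at indices shifted by the tile offsets, and the tiles are reassembled
/// into the full result with `vector.insert_strided_slice`. For example
/// (illustrative), with a native shape of <2x2>, a transfer_read producing
/// vector<4x4xf32> becomes four transfer_reads of vector<2x2xf32>.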
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

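/// Unrolls `vector.transfer_write` ops (without masks) according to the
/// target unroll shape: each tile of the source vector is extracted with
/// `vector.extract_strided_slice` and written by its own
/// `vector.transfer_write` at shifted indices. For writes into tensors, the
/// sliced writes are chained through their results so each write updates the
/// destination produced by the previous one.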
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    // Bail out if rank(source) != rank(target). The main limitation here is
    // the fact that `ExtractStridedSlice` requires the rank of the input and
    // output to match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

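/// DenseMapInfo so that tile-offset vectors can be used as keys in the
/// accumulator caches of the contraction and multi-reduction patterns below.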
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

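/// Unrolls `vector.contract` ops by tiling all iteration dimensions
/// (parallel and reduction) with the target shape and extracting the
/// lhs/rhs/acc slices through each operand's indexing map. Accumulator tiles
/// are cached across the reduction tiles so consecutive slices along a
/// reduction dimension are chained through their accumulator; e.g.
/// (illustrative) a 4x4x4 contraction with native shape <2x2x2> becomes
/// eight 2x2x2 contractions, chained pairwise through the accumulator cache.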
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loops keep updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

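/// Unrolls `vector.multi_reduction` ops by tiling the source vector with the
/// target shape. Tiles that share the same offsets along the non-reduced
/// dimensions reduce into the same accumulator slice, which is cached and
/// chained across the reduction tiles and finally reassembled into the
/// destination vector.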
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all the tile offsets, in multiples of targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

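/// Unrolls elementwise ops (ops with elementwise-mappable traits and one
/// vector result) by applying the op tile by tile. For example
/// (illustrative), with a native shape of <2>,
///
///   %0 = arith.addf %a, %b : vector<4xf32>
///
/// becomes roughly:
///
///   %a0 = vector.extract_strided_slice %a
///     {offsets = [0], sizes = [2], strides = [1]}
///     : vector<4xf32> to vector<2xf32>
///   %b0 = vector.extract_strided_slice %b ...
///   %s0 = arith.addf %a0, %b0 : vector<2xf32>
///
/// and similarly for offset [2], with the partial results combined via
/// vector.insert_strided_slice.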
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    // Bail out if rank(source) != rank(target). The main limitation here is
    // the fact that `ExtractStridedSlice` requires the rank of the input and
    // output to match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          op, "expected input vector rank to match target shape rank");
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

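/// Unrolls `vector.reduction` ops by reducing each tile separately and
/// combining the partial results with the matching arith reduction. For
/// example (illustrative), with a native shape of <4>,
///
///   %0 = vector.reduction <add>, %v : vector<8xf32> into f32
///
/// becomes two reductions over vector<4xf32> slices whose scalar results
/// are combined with arith.addf.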
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

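/// Unrolls `vector.transpose` ops by extracting each tile of the source at
/// permuted offsets and with a permuted shape, transposing the small tile,
/// and inserting it at the original offsets of the result vector.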
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

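/// Unrolls `vector.gather` ops by slicing the index, mask, and pass-through
/// vectors with the target shape, issuing one smaller gather per tile, and
/// reassembling the results with `vector.insert_strided_slice`.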
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

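/// Unrolls `vector.load` ops according to the target unroll shape: each tile
/// is loaded at indices shifted by the tile offsets and inserted into the
/// result vector. For example (illustrative), with a native shape of <2>,
/// a load of vector<4xf32> at index %i becomes roughly:
///
///   %l0 = vector.load %base[%i] : memref<?xf32>, vector<2xf32>
///   %c2 = arith.constant 2 : index
///   %i2 = arith.addi %i, %c2 : index
///   %l1 = vector.load %base[%i2] : memref<?xf32>, vector<2xf32>
///
/// combined via vector.insert_strided_slice.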
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();

    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();

    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

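/// Unrolls `vector.store` ops according to the target unroll shape: each
/// tile of the value is extracted with `vector.extract_strided_slice` and
/// stored at correspondingly shifted indices; the original store is then
/// erased.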
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();

    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();

    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

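/// Unrolls `vector.broadcast` ops by materializing each result tile with a
/// smaller broadcast. For vector sources, the source slice feeding a tile is
/// extracted first, with offset 0 and size 1 on any source dimension of size
/// 1 (i.e. a dimension being broadcast).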
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // Adjust the offset and shape for the source if the corresponding
        // dim is 1 (i.e. the source is broadcast along that dimension).
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
///
/// ==>
///
/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
/// ```
///
/// When this pattern is applied until a fixed point is reached, it produces
/// a sequence of 1-d `vector.to_elements` ops.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  UnrollToElements(MLIRContext *context,
                   const vector::UnrollVectorOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    TypedValue<VectorType> source = op.getSource();
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(source, rewriter);
    if (failed(result)) {
      return failure();
    }
    SmallVector<Value> vectors = *result;

    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.step` operations according to the provided
/// target unroll shape. It decomposes a large step vector into smaller step
/// vectors (segments) and assembles the result by inserting each computed
/// segment into the appropriate offset of the original vector.
///
/// The pattern does not support scalable vectors and will fail to match them.
///
/// For each segment, it adds the base step vector and the segment's offset,
/// then inserts the result into the output vector at the corresponding
/// position.
///
/// Example:
/// Given a step operation:
///   %0 = vector.step : vector<8xindex>
///
/// and a target unroll shape of <4>, the pattern produces:
///
///   %base = vector.step : vector<4xindex>
///   %zero = arith.constant dense<0> : vector<8xindex>
///   %result0 = vector.insert_strided_slice %base, %zero
///     {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
///   %offset = arith.constant dense<4> : vector<4xindex>
///   %segment1 = arith.addi %base, %offset : vector<4xindex>
///   %result1 = vector.insert_strided_slice %segment1, %result0
///     {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  UnrollStepPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();

    VectorType vecType = stepOp.getType();
    if (vecType.isScalable()) {
      // Scalable vectors are not supported by this pattern.
      return failure();
    }
    int64_t originalSize = vecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> strides(1, 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange({originalSize}, *targetShape)) {
      Value bcastOffset = arith::ConstantOp::create(
          rewriter, loc, targetVecType,
          DenseElementsAttr::get(
              targetVecType,
              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
      Value tileStep =
          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileStep, result, offsets, strides);
    }
    rewriter.replaceOp(stepOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
///
/// ==>
///
/// %0 = ub.poison : vector<2x3xf32>
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When this pattern is applied until a fixed point is reached, it produces
/// a sequence of 1-d `vector.from_elements` ops.
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
  UnrollFromElements(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();

    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t subTyNumElements = subTy.getNumElements();
      assert((index + 1) * subTyNumElements <= allElements.size() &&
             "out of bounds");
      ValueRange subElements =
          allElements.slice(index * subTyNumElements, subTyNumElements);
      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
    };

    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

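// A typical use of these patterns (illustrative sketch; assumes the
// UnrollVectorOptions::setNativeShape helper and the greedy pattern driver):
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions options;
//   options.setNativeShape(SmallVector<int64_t>{2, 2});
//   vector::populateVectorUnrollPatterns(patterns, options);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();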
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
                                                    options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
                                 benefit);
}

void mlir::vector::populateVectorFromElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
                                   benefit);
}