MLIR  17.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1 //===- VectorUnroll.cpp - patterns to do vector unrolling -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>
23 
24 #define DEBUG_TYPE "vector-unrolling"
25 
26 using namespace mlir;
27 using namespace mlir::vector;
28 
29 /// During unrolling from `originalShape` to `targetShape` return the offset for
30 /// the slice `index`.
32  int64_t index,
33  ArrayRef<int64_t> targetShape) {
34  return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
35 }
36 
37 /// A functor that accomplishes the same thing as `getVectorOffset` but
38 /// allows for reordering the traversal of the dimensions. The order of
39 /// traversal is given in "for loop order" (outer to inner).
40 namespace {
41 class DecomposeShapeIterator {
42 private:
44  SmallVector<int64_t> loopOrder;
45  SmallVector<int64_t> sliceStrides;
46  int64_t maxIndexVal{1};
47 
48 public:
49  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
50  ArrayRef<int64_t> targetShape,
51  ArrayRef<int64_t> loopOrder)
52  : vectorShape(targetShape.begin(), targetShape.end()),
53  loopOrder(loopOrder.begin(), loopOrder.end()),
54  sliceStrides(originalShape.size()) {
55  assert(originalShape.size() >= targetShape.size());
56  assert(loopOrder.size() == originalShape.size());
57 
58  // Compute the count for each dimension.
59  auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
60  assert(maybeShapeRatio && "Shape does not evenly divide");
61  // Pad `sliceDimCounts` with leading 1s so that all sizes match.
62  SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
63  maxIndexVal = computeMaxLinearIndex(sliceDimCounts);
64 
65  // Reversing "loop order" gives dimensions from fastest varying to slowest
66  // varying (smallest stride to largest stride).
67  int64_t accum = 1;
68  for (auto idx : llvm::reverse(loopOrder)) {
69  sliceStrides[idx] = accum;
70  accum *= sliceDimCounts[idx];
71  }
72  }
73 
74  // Turn the linear index into a d-tuple based on units of vectors of size
75  // `vectorShape`. The linear index is assumed to represent traversal of the
76  // dimensions based on `order`.
77  SmallVector<int64_t> delinearize(int64_t index) const {
78  // Traverse in for loop order (largest stride to smallest stride).
79  SmallVector<int64_t> vectorOffsets(sliceStrides.size());
80  for (auto idx : loopOrder) {
81  vectorOffsets[idx] = index / sliceStrides[idx];
82  index %= sliceStrides[idx];
83  }
84  return vectorOffsets;
85  }
86 
87  int64_t maxIndex() const { return maxIndexVal; }
88 
89  /// Return the offset within d-tuple based on the ordering given by
90  /// `loopOrder`.
91  SmallVector<int64_t> getVectorOffset(int64_t index) const {
92  SmallVector<int64_t> vectorOffsets = delinearize(index);
93  SmallVector<int64_t> elementOffsets =
94  computeElementwiseMul(vectorShape, vectorOffsets);
95  return elementOffsets;
96  }
97 };
98 } // namespace
99 
100 /// Compute the indices of the slice `index` for a tranfer op.
102  ArrayRef<Value> indices,
103  AffineMap permutationMap,
104  Location loc,
105  OpBuilder &builder) {
106  MLIRContext *ctx = builder.getContext();
107  auto isBroadcast = [](AffineExpr expr) {
108  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
109  return constExpr.getValue() == 0;
110  return false;
111  };
112  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
113  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
114  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
115  if (isBroadcast(dim.value()))
116  continue;
117  unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
118  auto expr = getAffineDimExpr(0, builder.getContext()) +
119  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
120  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
121  slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
122  }
123  return slicedIndices;
124 }
125 
126 // Clones `op` into a new operations that takes `operands` and returns
127 // `resultTypes`.
129  Operation *op,
130  ArrayRef<Value> operands,
131  ArrayRef<Type> resultTypes) {
132  return builder.create(loc, op->getName().getIdentifier(), operands,
133  resultTypes, op->getAttrs());
134 }
135 
136 /// Return the target shape for unrolling for the given `op`. Return
137 /// std::nullopt if the op shouldn't be or cannot be unrolled.
138 static std::optional<SmallVector<int64_t>>
140  if (options.filterConstraint && failed(options.filterConstraint(op)))
141  return std::nullopt;
142  assert(options.nativeShape &&
143  "vector unrolling expects the native shape or native"
144  "shape call back function to be set");
145  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
146  if (!unrollableVectorOp)
147  return std::nullopt;
148  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
149  if (!maybeUnrollShape)
150  return std::nullopt;
151  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
152  if (!targetShape)
153  return std::nullopt;
154  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
155  if (!maybeShapeRatio ||
156  llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
157  return std::nullopt;
158  return targetShape;
159 }
160 
162 getUnrollOrder(unsigned numLoops, Operation *op,
164  SmallVector<int64_t> loopOrder =
165  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
166  if (options.traversalOrderCallback != nullptr) {
167  std::optional<SmallVector<int64_t>> order =
168  options.traversalOrderCallback(op);
169  if (order) {
170  loopOrder = std::move(*order);
171  }
172  }
173  return loopOrder;
174 }
175 
176 namespace {
177 
178 struct UnrollTransferReadPattern
179  : public OpRewritePattern<vector::TransferReadOp> {
180  UnrollTransferReadPattern(MLIRContext *context,
182  PatternBenefit benefit = 1)
183  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
184  options(options) {}
185 
186  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
187  PatternRewriter &rewriter) const override {
188  // TODO: support 0-d corner case.
189  if (readOp.getTransferRank() == 0)
190  return failure();
191  if (readOp.getMask())
192  return failure();
193  auto targetShape = getTargetShape(options, readOp);
194  if (!targetShape)
195  return failure();
196  auto sourceVectorType = readOp.getVectorType();
197  SmallVector<int64_t> strides(targetShape->size(), 1);
198  Location loc = readOp.getLoc();
199  ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
200 
201  // Prepare the result vector;
202  Value result = rewriter.create<arith::ConstantOp>(
203  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
204  auto targetType =
205  VectorType::get(*targetShape, sourceVectorType.getElementType());
206  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
207  readOp.getIndices().end());
208 
209  SmallVector<int64_t> loopOrder =
210  getUnrollOrder(originalSize.size(), readOp, options);
211  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
212  loopOrder);
213  for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
214  SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
215  SmallVector<Value> indices =
216  sliceTransferIndices(elementOffsets, originalIndices,
217  readOp.getPermutationMap(), loc, rewriter);
218  auto slicedRead = rewriter.create<vector::TransferReadOp>(
219  loc, targetType, readOp.getSource(), indices,
220  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
221  readOp.getInBoundsAttr());
222 
223  result = rewriter.create<vector::InsertStridedSliceOp>(
224  loc, slicedRead, result, elementOffsets, strides);
225  }
226  rewriter.replaceOp(readOp, result);
227  return success();
228  }
229 
230 private:
232 };
233 
234 struct UnrollTransferWritePattern
235  : public OpRewritePattern<vector::TransferWriteOp> {
236  UnrollTransferWritePattern(MLIRContext *context,
238  PatternBenefit benefit = 1)
239  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
240  options(options) {}
241 
242  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
243  PatternRewriter &rewriter) const override {
244  // TODO: support 0-d corner case.
245  if (writeOp.getTransferRank() == 0)
246  return failure();
247 
248  if (writeOp.getMask())
249  return failure();
250  auto targetShape = getTargetShape(options, writeOp);
251  if (!targetShape)
252  return failure();
253  auto sourceVectorType = writeOp.getVectorType();
254  SmallVector<int64_t> strides(targetShape->size(), 1);
255  Location loc = writeOp.getLoc();
256  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
257  SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
258  writeOp.getIndices().end());
259 
260  SmallVector<int64_t> loopOrder =
261  getUnrollOrder(originalSize.size(), writeOp, options);
262  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
263  loopOrder);
264  Value resultTensor;
265  for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
266  SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
267  Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
268  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
269  SmallVector<Value> indices =
270  sliceTransferIndices(elementOffsets, originalIndices,
271  writeOp.getPermutationMap(), loc, rewriter);
272  Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
273  loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
274  indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
275  // For the tensor case update the destination for the next transfer write.
276  if (!slicedWrite->getResults().empty())
277  resultTensor = slicedWrite->getResult(0);
278  }
279  if (resultTensor)
280  rewriter.replaceOp(writeOp, resultTensor);
281  else
282  rewriter.eraseOp(writeOp);
283  return success();
284  }
285 
286 private:
288 };
289 
290 struct OffsetMapInfo {
291  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
292 
293  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
294 
295  static unsigned getHashValue(const SmallVector<int64_t> &v) {
296  return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
297  }
298 
299  static bool isEqual(const SmallVector<int64_t> &lhs,
300  const SmallVector<int64_t> &rhs) {
301  return lhs == rhs;
302  }
303 };
304 
305 struct UnrollContractionPattern
306  : public OpRewritePattern<vector::ContractionOp> {
307  UnrollContractionPattern(MLIRContext *context,
309  PatternBenefit benefit = 1)
310  : OpRewritePattern<vector::ContractionOp>(context, benefit),
311  options(options) {}
312 
313  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
314  PatternRewriter &rewriter) const override {
315  auto targetShape = getTargetShape(options, contractOp);
316  if (!targetShape)
317  return failure();
318  auto dstVecType = contractOp.getResultType().cast<VectorType>();
319  SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
320 
321  Location loc = contractOp.getLoc();
322  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
324  llvm::MapVector<
326  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
327  accCache;
328 
330  contractOp.getIteratorTypes().size(), contractOp, options);
331  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
332  loopOrder);
333  const int64_t sliceCount = indexToOffsets.maxIndex();
334  for (int64_t i = 0; i < sliceCount; i++) {
335  SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
336  SmallVector<Value> slicesOperands(contractOp.getNumOperands());
337 
338  // Helper to compute the new shape of each operand and extract the slice.
339  auto extractOperand = [&](unsigned index, Value operand,
340  AffineMap permutationMap,
341  ArrayRef<int64_t> operandOffets) {
343  permutationMap, ArrayRef<int64_t>(*targetShape));
344  SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
345  slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
346  loc, operand, operandOffets, operandShape, operandStrides);
347  };
348 
349  // Extract the new lhs operand.
350  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
351  SmallVector<int64_t> lhsOffets =
352  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
353  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
354  // If there is a mask associated to lhs, extract it as well.
355  if (slicesOperands.size() > 3)
356  extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
357  lhsOffets);
358 
359  // Extract the new rhs operand.
360  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
361  SmallVector<int64_t> rhsOffets =
362  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
363  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
364  // If there is a mask associated to rhs, extract it as well.
365  if (slicesOperands.size() > 4)
366  extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
367  rhsOffets);
368 
369  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
370  SmallVector<int64_t> accOffets =
371  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
372  // If a version of the accumulator has already been computed, use it
373  // otherwise extract the first version from the original operand.
374  auto accIt = accCache.find(accOffets);
375  if (accIt != accCache.end())
376  slicesOperands[2] = accIt->second;
377  else
378  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
379 
380  SmallVector<int64_t> dstShape =
381  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
382  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
384  rewriter, loc, contractOp, slicesOperands, targetType);
385 
386  SmallVector<int64_t> dstOffets =
387  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
388  // Save the accumulated value untill all the loops are unrolled since
389  // reduction loop keep updating the accumulator.
390  accCache[dstOffets] = newOp->getResult(0);
391  }
392  // Assemble back the accumulator into a single vector.
393  Value result = rewriter.create<arith::ConstantOp>(
394  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
395  for (const auto &it : accCache) {
396  SmallVector<int64_t> dstStrides(it.first.size(), 1);
397  result = rewriter.create<vector::InsertStridedSliceOp>(
398  loc, it.second, result, it.first, dstStrides);
399  }
400  rewriter.replaceOp(contractOp, result);
401  return success();
402  }
403 
404 private:
406 };
407 
408 struct UnrollMultiReductionPattern
409  : public OpRewritePattern<vector::MultiDimReductionOp> {
410  UnrollMultiReductionPattern(MLIRContext *context,
412  PatternBenefit benefit = 1)
413  : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
414  options(options) {}
415 
416  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
417  PatternRewriter &rewriter) const override {
418  std::optional<SmallVector<int64_t>> targetShape =
419  getTargetShape(options, reductionOp);
420  if (!targetShape)
421  return failure();
422  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
423  SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
424  llvm::MapVector<
426  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
427  accCache;
428  // Compute shape ratio of 'shape' and 'sizes'.
429  int64_t sliceCount = computeMaxLinearIndex(ratio);
430  Location loc = reductionOp.getLoc();
431 
432  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
433  // of multiples of the targetShape.
434  auto ratioStrides = computeStrides(ratio);
435  for (int64_t i = 0; i < sliceCount; i++) {
436  SmallVector<int64_t> offsets =
437  getVectorOffset(ratioStrides, i, *targetShape);
438 
439  SmallVector<Value> operands;
440  SmallVector<int64_t> operandStrides(offsets.size(), 1);
441  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
442  loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
443  operands.push_back(slicedOperand);
444  SmallVector<int64_t> dstShape;
445  SmallVector<int64_t> destOffset;
446  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
447  if (!reductionOp.isReducedDim(i)) {
448  destOffset.push_back(offsets[i]);
449  dstShape.push_back((*targetShape)[i]);
450  }
451  }
452  Value acc;
453  SmallVector<int64_t> accStrides(destOffset.size(), 1);
454  // If a version of the accumulator has already been computed, use it
455  // otherwise extract the first version from the original operand.
456  auto accIt = accCache.find(destOffset);
457  if (accIt != accCache.end())
458  acc = accIt->second;
459  else
460  acc = rewriter.create<vector::ExtractStridedSliceOp>(
461  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
462  operands.push_back(acc);
463  auto targetType = VectorType::get(
464  dstShape, reductionOp.getSourceVectorType().getElementType());
465  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
466  operands, targetType);
467  Value result = newOp->getResult(0);
468  accCache[destOffset] = result;
469  }
470  // Assemble back the accumulator into a single vector.
471  Value result = rewriter.create<arith::ConstantOp>(
472  loc, reductionOp.getDestType(),
473  rewriter.getZeroAttr(reductionOp.getDestType()));
474  for (const auto &it : accCache) {
475  SmallVector<int64_t> dstStrides(it.first.size(), 1);
476  result = rewriter.create<vector::InsertStridedSliceOp>(
477  loc, it.second, result, it.first, dstStrides);
478  }
479  rewriter.replaceOp(reductionOp, result);
480  return success();
481  }
482 
483 private:
485 };
486 
487 struct UnrollElementwisePattern : public RewritePattern {
488  UnrollElementwisePattern(MLIRContext *context,
490  PatternBenefit benefit = 1)
491  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
492  options(options) {}
493 
494  LogicalResult matchAndRewrite(Operation *op,
495  PatternRewriter &rewriter) const override {
497  return failure();
498  auto targetShape = getTargetShape(options, op);
499  if (!targetShape)
500  return failure();
501  auto dstVecType = op->getResult(0).getType().cast<VectorType>();
502  SmallVector<int64_t> originalSize =
503  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
504  SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
505  int64_t sliceCount = computeMaxLinearIndex(ratio);
506  Location loc = op->getLoc();
507  // Prepare the result vector.
508  Value result = rewriter.create<arith::ConstantOp>(
509  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
510  SmallVector<int64_t> strides(targetShape->size(), 1);
511  VectorType newVecType =
512  VectorType::get(*targetShape, dstVecType.getElementType());
513 
514  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
515  // of multiples of the targetShape.
516  auto ratioStrides = computeStrides(ratio);
517  for (int64_t i = 0; i < sliceCount; i++) {
518  SmallVector<int64_t> offsets =
519  getVectorOffset(ratioStrides, i, *targetShape);
520  SmallVector<Value> extractOperands;
521  for (OpOperand &operand : op->getOpOperands()) {
522  auto vecType = operand.get().getType().template dyn_cast<VectorType>();
523  if (!vecType) {
524  extractOperands.push_back(operand.get());
525  continue;
526  }
527  extractOperands.push_back(
528  rewriter.create<vector::ExtractStridedSliceOp>(
529  loc, operand.get(), offsets, *targetShape, strides));
530  }
532  rewriter, loc, op, extractOperands, newVecType);
533  result = rewriter.create<vector::InsertStridedSliceOp>(
534  loc, newOp->getResult(0), result, offsets, strides);
535  }
536  rewriter.replaceOp(op, result);
537  return success();
538  }
539 
540 private:
542 };
543 
544 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
545  UnrollReductionPattern(MLIRContext *context,
547  PatternBenefit benefit = 1)
548  : OpRewritePattern<vector::ReductionOp>(context, benefit),
549  options(options) {}
550 
551  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
552  PatternRewriter &rewriter) const override {
553  std::optional<SmallVector<int64_t>> targetShape =
554  getTargetShape(options, reductionOp);
555  if (!targetShape)
556  return failure();
557  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
558  auto ratio = *computeShapeRatio(originalSize, *targetShape);
559  int64_t sliceCount = ratio[0];
560 
561  // Create unrolled vector reduction.
562  Location loc = reductionOp.getLoc();
563  Value accumulator = nullptr;
564 
565  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
566  // of multiples of the targetShape.
567  auto ratioStrides = computeStrides(ratio);
568  for (int64_t i = 0; i < sliceCount; ++i) {
569  SmallVector<int64_t> offsets =
570  getVectorOffset(ratioStrides, i, *targetShape);
571  SmallVector<int64_t> strides(offsets.size(), 1);
572  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
573  loc, reductionOp.getVector(), offsets, *targetShape, strides);
575  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
576  Value result = newOp->getResult(0);
577 
578  if (!accumulator) {
579  // This is the first reduction.
580  accumulator = result;
581  } else {
582  // On subsequent reduction, combine with the accumulator.
583  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
584  accumulator, result);
585  }
586  }
587 
588  rewriter.replaceOp(reductionOp, accumulator);
589  return success();
590  }
591 
592 private:
594 };
595 
596 struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
597  UnrollTransposePattern(MLIRContext *context,
599  PatternBenefit benefit = 1)
600  : OpRewritePattern<vector::TransposeOp>(context, benefit),
601  options(options) {}
602 
603  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
604  PatternRewriter &rewriter) const override {
605  if (transposeOp.getResultType().getRank() == 0)
606  return failure();
607  auto targetShape = getTargetShape(options, transposeOp);
608  if (!targetShape)
609  return failure();
610  auto originalVectorType = transposeOp.getResultType();
611  SmallVector<int64_t> strides(targetShape->size(), 1);
612  Location loc = transposeOp.getLoc();
613  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
614  SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
615  int64_t sliceCount = computeMaxLinearIndex(ratio);
616  // Prepare the result vector;
617  Value result = rewriter.create<arith::ConstantOp>(
618  loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
619  SmallVector<int64_t> permutation;
620  transposeOp.getTransp(permutation);
621 
622  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
623  // of multiples of the targetShape.
624  auto ratioStrides = computeStrides(ratio);
625  for (int64_t i = 0; i < sliceCount; i++) {
626  SmallVector<int64_t> elementOffsets =
627  getVectorOffset(ratioStrides, i, *targetShape);
628  SmallVector<int64_t> permutedOffsets(elementOffsets.size());
629  SmallVector<int64_t> permutedShape(elementOffsets.size());
630  // Compute the source offsets and shape.
631  for (auto &indices : llvm::enumerate(permutation)) {
632  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
633  permutedShape[indices.value()] = (*targetShape)[indices.index()];
634  }
635  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
636  loc, transposeOp.getVector(), permutedOffsets, permutedShape,
637  strides);
638  Value transposedSlice =
639  rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
640  result = rewriter.create<vector::InsertStridedSliceOp>(
641  loc, transposedSlice, result, elementOffsets, strides);
642  }
643  rewriter.replaceOp(transposeOp, result);
644  return success();
645  }
646 
647 private:
649 };
650 
651 } // namespace
652 
655  PatternBenefit benefit) {
656  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
657  UnrollContractionPattern, UnrollElementwisePattern,
658  UnrollReductionPattern, UnrollMultiReductionPattern,
659  UnrollTransposePattern>(patterns.getContext(), options, benefit);
660 }
static llvm::ManagedStatic< PassManagerOptions > options
static ArrayRef< int64_t > vectorShape(Type type)
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static SmallVector< int64_t > getVectorOffset(ArrayRef< int64_t > ratioStrides, int64_t index, ArrayRef< int64_t > targetShape)
During unrolling from originalShape to targetShape return the offset for the slice index.
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:43
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:332
MLIRContext * getContext() const
Definition: Builders.h:55
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:306
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class helps build Operations.
Definition: Builders.h:199
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class represents an operand of an operation.
Definition: Value.h:255
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:368
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:198
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:400
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:94
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:344
result_range getResults()
Definition: Operation.h:376
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:365
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
U cast() const
Definition: Types.h:321
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1137
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip of v1 and v2 multiplied elementwise.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:568
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Given a set of sizes, compute and return the strides (i.e.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:527
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Compute and return the multi-dimensional integral ratio of subShape to the trailing dimensions of sha...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:502
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
Options that control the vector unrolling.