//===- VectorUnroll.cpp - patterns to do vector unrolling -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

#define DEBUG_TYPE "vector-unrolling"

using namespace mlir;
using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape` return the offset for
/// the slice `index`.
static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
                                            int64_t index,
                                            ArrayRef<int64_t> targetShape) {
  return computeElementwiseMul(delinearize(index, ratioStrides), targetShape);
}
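// Worked example (illustrative): unrolling originalShape = [4, 6] to
// targetShape = [2, 3] gives ratio [2, 2] and ratioStrides = [2, 1]; slice
// index 3 delinearizes to (1, 1), which maps to element offsets [2, 3].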

/// A functor that accomplishes the same thing as `getVectorOffset` but
/// allows for reordering the traversal of the dimensions. The order of
/// traversal is given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() >= targetShape.size());
    assert(loopOrder.size() == originalShape.size());

    // Compute the number of slices in each dimension. computeShapeRatio pads
    // `targetShape` with leading 1s, so `sliceDimCounts` has the same rank as
    // `originalShape`.
    auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
    assert(maybeShapeRatio && "Shape does not evenly divide");
    SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
    maxIndexVal = computeMaxLinearIndex(sliceDimCounts);

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
    int64_t accum = 1;
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
  }

  // Turn the linear index into a d-tuple based on units of vectors of size
  // `vectorShape`. The linear index is assumed to represent traversal of the
  // dimensions based on `loopOrder`.
  SmallVector<int64_t> delinearize(int64_t index) const {
    // Traverse in for loop order (largest stride to smallest stride).
    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
    return vectorOffsets;
  }

  int64_t maxIndex() const { return maxIndexVal; }

  /// Return the offset within the d-tuple based on the ordering given by
  /// `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementwiseMul(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};
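// Worked example (illustrative): originalShape = [4, 4], targetShape = [2, 2],
// loopOrder = [1, 0] (dimension 1 outermost) visits the four slices at element
// offsets [0, 0], [2, 0], [0, 2], [2, 2], i.e. dimension 0 varies fastest.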
} // namespace

/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}
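// Worked example (illustrative): for indices (%i, %j), the identity
// permutation map (d0, d1) -> (d0, d1), and elementOffsets [2, 4], the sliced
// indices are (%i + 2, %j + 4); broadcast dimensions (constant-0 results in
// the map) keep their original index.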

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return std::nullopt;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return std::nullopt;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return std::nullopt;
  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return std::nullopt;
  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return std::nullopt;
  return targetShape;
}
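// Worked example (illustrative): an op with unroll shape [8, 8] and a
// nativeShape callback returning [4, 4] gets target shape [4, 4]; a callback
// returning [8, 8] yields a ratio of all 1s, so std::nullopt is returned and
// the op is left untouched.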

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}
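// Worked example (illustrative): with numLoops = 3 and no
// traversalOrderCallback the order is [0, 1, 2]; a callback returning
// [2, 0, 1] makes dimension 2 the outermost unrolled loop.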

namespace {

struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
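// Worked example (illustrative): with targetShape = [2, 2], a transfer_read
// producing vector<4x4xf32> becomes four vector<2x2xf32> transfer_reads at
// element offsets (0, 0), (0, 2), (2, 0), (2, 2), each inserted into the
// result with vector.insert_strided_slice.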

struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
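// Worked example (illustrative): for a transfer_write into a tensor, the
// unrolled writes are chained through their results:
//   %t1 = vector.transfer_write %slice0, %t0[...]
//   %t2 = vector.transfer_write %slice1, %t1[...]
// and the final tensor value replaces the original op. Writes into memrefs
// produce no results, so the original op is simply erased.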

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
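// Note: {-1} and {-2} are the DenseMap empty/tombstone sentinel keys; they
// never collide with real slice offsets, which are always non-negative.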

struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
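// Worked example (illustrative): a 4x4x4 (m, n, k) contraction unrolled to
// [2, 2, 2] produces 8 slices; the two slices reducing over k at the same
// (m, n) offset reuse the cached accumulator, leaving accCache with 4 entries
// that are reassembled into the final result.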

struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // The number of slices to unroll.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();

    // Strides of the ratio; these give the offsets of each slice in a basis
    // of multiples of `targetShape`.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t> offsets =
          getVectorOffset(ratioStrides, i, *targetShape);

      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
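// Worked example (illustrative): vector.multi_reduction over dimension 1 of a
// vector<4x8xf32> with targetShape [2, 4] yields four slices; the two slices
// sharing a dimension-0 offset accumulate into the same vector<2xf32> entry
// of accCache.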

struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Strides of the ratio; these give the offsets of each slice in a basis
    // of multiples of `targetShape`.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t> offsets =
          getVectorOffset(ratioStrides, i, *targetShape);
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
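// Worked example (illustrative): arith.addf on vector<4x4xf32> with
// targetShape [2, 2] becomes four vector<2x2xf32> addf ops on extracted
// slices, reassembled with vector.insert_strided_slice; scalar operands are
// passed through unchanged.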

struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    auto ratio = *computeShapeRatio(originalSize, *targetShape);
    int64_t sliceCount = ratio[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;

    // Strides of the ratio; these give the offsets of each slice in a basis
    // of multiples of `targetShape`.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(ratioStrides, i, *targetShape);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
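// Worked example (illustrative): vector.reduction <add> on vector<8xf32> with
// targetShape [4] becomes two vector<4xf32> reductions whose scalar results
// are combined with arith.addf via makeArithReduction.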

struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);

    // Strides of the ratio; these give the offsets of each slice in a basis
    // of multiples of `targetShape`.
    auto ratioStrides = computeStrides(ratio);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t> elementOffsets =
          getVectorOffset(ratioStrides, i, *targetShape);
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
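// Worked example (illustrative): for vector.transpose [1, 0] with result
// targetShape [2, 2], the slice inserted at result offsets (r, c) is extracted
// from the source at offsets (c, r) and transposed before insertion.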

struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
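// Worked example (illustrative): a gather producing vector<4x4xf32> with
// targetShape [2, 2] extracts matching 2x2 slices of the index, mask, and
// pass-through vectors, gathers each, and reassembles the four results.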

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}
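// Typical usage (a sketch; pass and rewrite-driver setup elided):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorUnrollPatterns(
//       patterns,
//       vector::UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));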