MLIR  22.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
22 
23 #define DEBUG_TYPE "vector-unroll"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 
28 /// Compute the indices of the slice `index` for a transfer op.
30  ArrayRef<Value> indices,
31  AffineMap permutationMap,
32  Location loc,
33  OpBuilder &builder) {
34  MLIRContext *ctx = builder.getContext();
35  auto isBroadcast = [](AffineExpr expr) {
36  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
37  return constExpr.getValue() == 0;
38  return false;
39  };
40  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
41  SmallVector<Value> slicedIndices(indices);
42  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
43  if (isBroadcast(dim.value()))
44  continue;
45  unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
46  auto expr = getAffineDimExpr(0, builder.getContext()) +
47  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
48  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
49  slicedIndices[pos] =
50  affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
51  }
52  return slicedIndices;
53 }
54 
55 // Compute the new indices by adding `offsets` to `originalIndices`.
56 // If m < n (m = offsets.size(), n = originalIndices.size()),
57 // then only the trailing m values in `originalIndices` are updated.
59  Location loc,
60  OperandRange originalIndices,
61  ArrayRef<int64_t> offsets) {
62  assert(offsets.size() <= originalIndices.size() &&
63  "Offsets should not exceed the number of original indices");
64  SmallVector<Value> indices(originalIndices);
65 
66  auto start = indices.size() - offsets.size();
67  for (auto [i, offset] : llvm::enumerate(offsets)) {
68  if (offset != 0) {
69  indices[start + i] = arith::AddIOp::create(
70  rewriter, loc, originalIndices[start + i],
71  arith::ConstantIndexOp::create(rewriter, loc, offset));
72  }
73  }
74  return indices;
75 }
76 
77 // Clones `op` into a new operations that takes `operands` and returns
78 // `resultTypes`.
80  Operation *op,
81  ArrayRef<Value> operands,
82  ArrayRef<Type> resultTypes) {
83  return builder.create(loc, op->getName().getIdentifier(), operands,
84  resultTypes, op->getAttrs());
85 }
86 
87 /// Return the target shape for unrolling for the given `op`. Return
88 /// std::nullopt if the op shouldn't be or cannot be unrolled.
89 static std::optional<SmallVector<int64_t>>
91  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
92  if (options.filterConstraint && failed(options.filterConstraint(op))) {
93  LDBG() << "--no filter constraint -> BAIL";
94  return std::nullopt;
95  }
96  assert(options.nativeShape &&
97  "vector unrolling expects the native shape or native"
98  "shape call back function to be set");
99  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
100  if (!unrollableVectorOp) {
101  LDBG() << "--not an unrollable op -> BAIL";
102  return std::nullopt;
103  }
104  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
105  if (!maybeUnrollShape) {
106  LDBG() << "--could not get shape of op " << *op << " -> BAIL";
107  return std::nullopt;
108  }
109  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
110 
111  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
112  if (!targetShape) {
113  LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
114  return std::nullopt;
115  }
116  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
117 
118  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
119  if (!maybeShapeRatio) {
120  LDBG() << "--could not compute integral shape ratio -> BAIL";
121  return std::nullopt;
122  }
123  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
124  LDBG() << "--no unrolling needed -> SKIP";
125  return std::nullopt;
126  }
127  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
128  return targetShape;
129 }
130 
132 getUnrollOrder(unsigned numLoops, Operation *op,
134  SmallVector<int64_t> loopOrder =
135  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
136  if (options.traversalOrderCallback != nullptr) {
137  std::optional<SmallVector<int64_t>> order =
138  options.traversalOrderCallback(op);
139  if (order) {
140  loopOrder = std::move(*order);
141  }
142  }
143  return loopOrder;
144 }
145 
146 namespace {
147 
148 struct UnrollTransferReadPattern
149  : public OpRewritePattern<vector::TransferReadOp> {
150  UnrollTransferReadPattern(MLIRContext *context,
152  PatternBenefit benefit = 1)
153  : OpRewritePattern<vector::TransferReadOp>(context, benefit),
154  options(options) {}
155 
156  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
157  PatternRewriter &rewriter) const override {
158  // TODO: support 0-d corner case.
159  if (readOp.getTransferRank() == 0)
160  return failure();
161  if (readOp.getMask())
162  return failure();
163  auto targetShape = getTargetShape(options, readOp);
164  if (!targetShape)
165  return failure();
166  auto sourceVectorType = readOp.getVectorType();
167  SmallVector<int64_t> strides(targetShape->size(), 1);
168  Location loc = readOp.getLoc();
169  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
170 
171  // Prepare the result vector;
172  Value result =
173  arith::ConstantOp::create(rewriter, loc, sourceVectorType,
174  rewriter.getZeroAttr(sourceVectorType));
175  auto targetType =
176  VectorType::get(*targetShape, sourceVectorType.getElementType());
177  SmallVector<Value> originalIndices(readOp.getIndices().begin(),
178  readOp.getIndices().end());
179  SmallVector<int64_t> loopOrder =
180  getUnrollOrder(originalSize.size(), readOp, options);
181  for (SmallVector<int64_t> elementOffsets :
182  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
183  SmallVector<Value> indices =
184  sliceTransferIndices(elementOffsets, originalIndices,
185  readOp.getPermutationMap(), loc, rewriter);
186  auto slicedRead = vector::TransferReadOp::create(
187  rewriter, loc, targetType, readOp.getBase(), indices,
188  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
189  readOp.getInBoundsAttr());
190 
191  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
192  loc, slicedRead, result, elementOffsets, strides);
193  }
194  rewriter.replaceOp(readOp, result);
195  return success();
196  }
197 
198 private:
200 };
201 
202 struct UnrollTransferWritePattern
203  : public OpRewritePattern<vector::TransferWriteOp> {
204  UnrollTransferWritePattern(MLIRContext *context,
206  PatternBenefit benefit = 1)
207  : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
208  options(options) {}
209 
210  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
211  PatternRewriter &rewriter) const override {
212  // TODO: support 0-d corner case.
213  if (writeOp.getTransferRank() == 0)
214  return failure();
215 
216  if (writeOp.getMask())
217  return failure();
218  auto targetShape = getTargetShape(options, writeOp);
219  if (!targetShape)
220  return failure();
221  auto sourceVectorType = writeOp.getVectorType();
222  SmallVector<int64_t> strides(targetShape->size(), 1);
223  Location loc = writeOp.getLoc();
224  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
225  // Bail-out if rank(source) != rank(target). The main limitation here is the
226  // fact that `ExtractStridedSlice` requires the rank for the input and
227  // output to match. If needed, we can relax this later.
228  if (originalSize.size() != targetShape->size())
229  return rewriter.notifyMatchFailure(
230  writeOp,
231  "expected source input vector rank to match target shape rank");
232 
233  SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
234  writeOp.getIndices().end());
235  SmallVector<int64_t> loopOrder =
236  getUnrollOrder(originalSize.size(), writeOp, options);
237  Value resultTensor;
238  for (SmallVector<int64_t> elementOffsets :
239  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
240  Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
241  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
242  SmallVector<Value> indices =
243  sliceTransferIndices(elementOffsets, originalIndices,
244  writeOp.getPermutationMap(), loc, rewriter);
245  Operation *slicedWrite = vector::TransferWriteOp::create(
246  rewriter, loc, slicedVector,
247  resultTensor ? resultTensor : writeOp.getBase(), indices,
248  writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
249  // For the tensor case update the destination for the next transfer write.
250  if (!slicedWrite->getResults().empty())
251  resultTensor = slicedWrite->getResult(0);
252  }
253  if (resultTensor)
254  rewriter.replaceOp(writeOp, resultTensor);
255  else
256  rewriter.eraseOp(writeOp);
257  return success();
258  }
259 
260 private:
262 };
263 
264 struct OffsetMapInfo {
265  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
266 
267  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
268 
269  static unsigned getHashValue(const SmallVector<int64_t> &v) {
270  return static_cast<unsigned>(llvm::hash_combine_range(v));
271  }
272 
273  static bool isEqual(const SmallVector<int64_t> &lhs,
274  const SmallVector<int64_t> &rhs) {
275  return lhs == rhs;
276  }
277 };
278 
279 struct UnrollContractionPattern
280  : public OpRewritePattern<vector::ContractionOp> {
281  UnrollContractionPattern(MLIRContext *context,
283  PatternBenefit benefit = 1)
284  : OpRewritePattern<vector::ContractionOp>(context, benefit),
285  options(options) {}
286 
287  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
288  PatternRewriter &rewriter) const override {
289  auto targetShape = getTargetShape(options, contractOp);
290  if (!targetShape)
291  return failure();
292  auto dstVecType = cast<VectorType>(contractOp.getResultType());
293  SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
294 
295  Location loc = contractOp.getLoc();
296  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
297  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
298  llvm::MapVector<
300  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
301  accCache;
302 
304  contractOp.getIteratorTypes().size(), contractOp, options);
305 
306  for (SmallVector<int64_t> offsets :
307  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
308  SmallVector<Value> slicesOperands(contractOp.getNumOperands());
309 
310  // Helper to compute the new shape of each operand and extract the slice.
311  auto extractOperand = [&](unsigned index, Value operand,
312  AffineMap permutationMap,
313  ArrayRef<int64_t> operandOffets) {
315  permutationMap, ArrayRef<int64_t>(*targetShape));
316  SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
317  slicesOperands[index] =
318  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
319  loc, operand, operandOffets, operandShape, operandStrides);
320  };
321 
322  // Extract the new lhs operand.
323  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
324  SmallVector<int64_t> lhsOffets =
325  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
326  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
327 
328  // Extract the new rhs operand.
329  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
330  SmallVector<int64_t> rhsOffets =
331  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
332  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
333 
334  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
335  SmallVector<int64_t> accOffets =
336  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
337  // If a version of the accumulator has already been computed, use it
338  // otherwise extract the first version from the original operand.
339  auto *accIt = accCache.find(accOffets);
340  if (accIt != accCache.end())
341  slicesOperands[2] = accIt->second;
342  else
343  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
344 
345  SmallVector<int64_t> dstShape =
346  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
347  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
349  rewriter, loc, contractOp, slicesOperands, targetType);
350 
351  SmallVector<int64_t> dstOffets =
352  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
353  // Save the accumulated value untill all the loops are unrolled since
354  // reduction loop keep updating the accumulator.
355  accCache[dstOffets] = newOp->getResult(0);
356  }
357  // Assemble back the accumulator into a single vector.
358  Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
359  rewriter.getZeroAttr(dstVecType));
360  for (const auto &it : accCache) {
361  SmallVector<int64_t> dstStrides(it.first.size(), 1);
362  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
363  loc, it.second, result, it.first, dstStrides);
364  }
365  rewriter.replaceOp(contractOp, result);
366  return success();
367  }
368 
369 private:
371 };
372 
373 struct UnrollMultiReductionPattern
374  : public OpRewritePattern<vector::MultiDimReductionOp> {
375  UnrollMultiReductionPattern(MLIRContext *context,
377  PatternBenefit benefit = 1)
378  : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
379  options(options) {}
380 
381  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
382  PatternRewriter &rewriter) const override {
383  auto resultType = reductionOp->getResult(0).getType();
384  if (resultType.isIntOrFloat()) {
385  return rewriter.notifyMatchFailure(reductionOp,
386  "Unrolling scalars is not supported");
387  }
388  std::optional<SmallVector<int64_t>> targetShape =
389  getTargetShape(options, reductionOp);
390  if (!targetShape)
391  return failure();
392  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
393  llvm::MapVector<
395  llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
396  accCache;
397  Location loc = reductionOp.getLoc();
398 
399  // Stride of the ratios, this gives us the offsets of sliceCount in a basis
400  // of multiples of the targetShape.
401  for (SmallVector<int64_t> offsets :
402  StaticTileOffsetRange(originalSize, *targetShape)) {
403  SmallVector<Value> operands;
404  SmallVector<int64_t> operandStrides(offsets.size(), 1);
405  Value slicedOperand =
406  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
407  loc, reductionOp.getSource(), offsets, *targetShape,
408  operandStrides);
409  operands.push_back(slicedOperand);
410  SmallVector<int64_t> dstShape;
411  SmallVector<int64_t> destOffset;
412  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
413  if (!reductionOp.isReducedDim(i)) {
414  destOffset.push_back(offsets[i]);
415  dstShape.push_back((*targetShape)[i]);
416  }
417  }
418  Value acc;
419  SmallVector<int64_t> accStrides(destOffset.size(), 1);
420  // If a version of the accumulator has already been computed, use it
421  // otherwise extract the first version from the original operand.
422  auto *accIt = accCache.find(destOffset);
423  if (accIt != accCache.end())
424  acc = accIt->second;
425  else
426  acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
427  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
428  operands.push_back(acc);
429  auto targetType = VectorType::get(
430  dstShape, reductionOp.getSourceVectorType().getElementType());
431  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
432  operands, targetType);
433  Value result = newOp->getResult(0);
434  accCache[destOffset] = result;
435  }
436  // Assemble back the accumulator into a single vector.
437  Value result = arith::ConstantOp::create(
438  rewriter, loc, reductionOp.getDestType(),
439  rewriter.getZeroAttr(reductionOp.getDestType()));
440  for (const auto &it : accCache) {
441  SmallVector<int64_t> dstStrides(it.first.size(), 1);
442  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
443  loc, it.second, result, it.first, dstStrides);
444  }
445  rewriter.replaceOp(reductionOp, result);
446  return success();
447  }
448 
449 private:
451 };
452 
453 struct UnrollElementwisePattern : public RewritePattern {
454  UnrollElementwisePattern(MLIRContext *context,
456  PatternBenefit benefit = 1)
457  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
458  options(options) {}
459 
460  LogicalResult matchAndRewrite(Operation *op,
461  PatternRewriter &rewriter) const override {
463  return failure();
464  auto targetShape = getTargetShape(options, op);
465  if (!targetShape)
466  return failure();
467  auto dstVecType = cast<VectorType>(op->getResult(0).getType());
468  SmallVector<int64_t> originalSize =
469  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
470  // Bail-out if rank(source) != rank(target). The main limitation here is the
471  // fact that `ExtractStridedSlice` requires the rank for the input and
472  // output to match. If needed, we can relax this later.
473  if (originalSize.size() != targetShape->size())
474  return rewriter.notifyMatchFailure(
475  op, "expected input vector rank to match target shape rank");
476  Location loc = op->getLoc();
477  // Prepare the result vector.
478  Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
479  rewriter.getZeroAttr(dstVecType));
480  SmallVector<int64_t> strides(targetShape->size(), 1);
481  VectorType newVecType =
482  VectorType::get(*targetShape, dstVecType.getElementType());
483 
484  // Create the unrolled computation.
485  for (SmallVector<int64_t> offsets :
486  StaticTileOffsetRange(originalSize, *targetShape)) {
487  SmallVector<Value> extractOperands;
488  for (OpOperand &operand : op->getOpOperands()) {
489  auto vecType = dyn_cast<VectorType>(operand.get().getType());
490  if (!vecType) {
491  extractOperands.push_back(operand.get());
492  continue;
493  }
494  extractOperands.push_back(
495  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
496  loc, operand.get(), offsets, *targetShape, strides));
497  }
499  rewriter, loc, op, extractOperands, newVecType);
500  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
501  loc, newOp->getResult(0), result, offsets, strides);
502  }
503  rewriter.replaceOp(op, result);
504  return success();
505  }
506 
507 private:
509 };
510 
511 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
512  UnrollReductionPattern(MLIRContext *context,
514  PatternBenefit benefit = 1)
515  : OpRewritePattern<vector::ReductionOp>(context, benefit),
516  options(options) {}
517 
518  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
519  PatternRewriter &rewriter) const override {
520  std::optional<SmallVector<int64_t>> targetShape =
521  getTargetShape(options, reductionOp);
522  if (!targetShape)
523  return failure();
524  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
525 
526  // Create unrolled vector reduction.
527  Location loc = reductionOp.getLoc();
528  Value accumulator = nullptr;
529  for (SmallVector<int64_t> offsets :
530  StaticTileOffsetRange(originalSize, *targetShape)) {
531  SmallVector<int64_t> strides(offsets.size(), 1);
532  Value slicedOperand =
533  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
534  loc, reductionOp.getVector(), offsets, *targetShape, strides);
536  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
537  Value result = newOp->getResult(0);
538 
539  if (!accumulator) {
540  // This is the first reduction.
541  accumulator = result;
542  } else {
543  // On subsequent reduction, combine with the accumulator.
544  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
545  accumulator, result);
546  }
547  }
548 
549  rewriter.replaceOp(reductionOp, accumulator);
550  return success();
551  }
552 
553 private:
555 };
556 
557 struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
558  UnrollTransposePattern(MLIRContext *context,
560  PatternBenefit benefit = 1)
561  : OpRewritePattern<vector::TransposeOp>(context, benefit),
562  options(options) {}
563 
564  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
565  PatternRewriter &rewriter) const override {
566  if (transposeOp.getResultVectorType().getRank() == 0)
567  return failure();
568  auto targetShape = getTargetShape(options, transposeOp);
569  if (!targetShape)
570  return failure();
571  auto originalVectorType = transposeOp.getResultVectorType();
572  SmallVector<int64_t> strides(targetShape->size(), 1);
573  Location loc = transposeOp.getLoc();
574  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
575 
576  // Prepare the result vector;
577  Value result =
578  arith::ConstantOp::create(rewriter, loc, originalVectorType,
579  rewriter.getZeroAttr(originalVectorType));
580  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
581 
582  // Unroll the computation.
583  for (SmallVector<int64_t> elementOffsets :
584  StaticTileOffsetRange(originalSize, *targetShape)) {
585  SmallVector<int64_t> permutedOffsets(elementOffsets.size());
586  SmallVector<int64_t> permutedShape(elementOffsets.size());
587  // Compute the source offsets and shape.
588  for (auto indices : llvm::enumerate(permutation)) {
589  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
590  permutedShape[indices.value()] = (*targetShape)[indices.index()];
591  }
592  Value slicedOperand =
593  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
594  loc, transposeOp.getVector(), permutedOffsets, permutedShape,
595  strides);
596  Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
597  loc, slicedOperand, permutation);
598  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
599  loc, transposedSlice, result, elementOffsets, strides);
600  }
601  rewriter.replaceOp(transposeOp, result);
602  return success();
603  }
604 
605 private:
607 };
608 
609 struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
610  UnrollGatherPattern(MLIRContext *context,
612  PatternBenefit benefit = 1)
613  : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
614  }
615 
616  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
617  PatternRewriter &rewriter) const override {
618  VectorType sourceVectorType = gatherOp.getVectorType();
619  if (sourceVectorType.getRank() == 0)
620  return failure();
621  auto targetShape = getTargetShape(options, gatherOp);
622  if (!targetShape)
623  return failure();
624  SmallVector<int64_t> strides(targetShape->size(), 1);
625  Location loc = gatherOp.getLoc();
626  ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
627 
628  // Prepare the result vector;
629  Value result =
630  arith::ConstantOp::create(rewriter, loc, sourceVectorType,
631  rewriter.getZeroAttr(sourceVectorType));
632  auto targetType =
633  VectorType::get(*targetShape, sourceVectorType.getElementType());
634 
635  SmallVector<int64_t> loopOrder =
636  getUnrollOrder(originalSize.size(), gatherOp, options);
637  for (SmallVector<int64_t> elementOffsets :
638  StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
639  // To get the unrolled gather, extract the same slice based on the
640  // decomposed shape from each of the index, mask, and pass-through
641  // vectors.
642  Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
643  loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
644  Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
645  loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
646  Value passThruSubVec =
647  rewriter.createOrFold<vector::ExtractStridedSliceOp>(
648  loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
649  strides);
650  auto slicedGather = vector::GatherOp::create(
651  rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
652  indexSubVec, maskSubVec, passThruSubVec);
653 
654  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
655  loc, slicedGather, result, elementOffsets, strides);
656  }
657  rewriter.replaceOp(gatherOp, result);
658  return success();
659  }
660 
661 private:
663 };
664 
665 struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
666  UnrollLoadPattern(MLIRContext *context,
668  PatternBenefit benefit = 1)
669  : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
670 
671  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
672  PatternRewriter &rewriter) const override {
673  VectorType vecType = loadOp.getVectorType();
674 
675  auto targetShape = getTargetShape(options, loadOp);
676  if (!targetShape)
677  return failure();
678 
679  Location loc = loadOp.getLoc();
680  ArrayRef<int64_t> originalShape = vecType.getShape();
681  SmallVector<int64_t> strides(targetShape->size(), 1);
682 
683  Value result = arith::ConstantOp::create(rewriter, loc, vecType,
684  rewriter.getZeroAttr(vecType));
685 
686  SmallVector<int64_t> loopOrder =
687  getUnrollOrder(originalShape.size(), loadOp, options);
688 
689  auto targetVecType =
690  VectorType::get(*targetShape, vecType.getElementType());
691 
692  for (SmallVector<int64_t> offsets :
693  StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
694  SmallVector<Value> indices =
695  sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
696  Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
697  loadOp.getBase(), indices);
698  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
699  loc, slicedLoad, result, offsets, strides);
700  }
701  rewriter.replaceOp(loadOp, result);
702  return success();
703  }
704 
705 private:
707 };
708 
709 struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
710  UnrollStorePattern(MLIRContext *context,
712  PatternBenefit benefit = 1)
713  : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
714 
715  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
716  PatternRewriter &rewriter) const override {
717  VectorType vecType = storeOp.getVectorType();
718 
719  auto targetShape = getTargetShape(options, storeOp);
720  if (!targetShape)
721  return failure();
722 
723  Location loc = storeOp.getLoc();
724  ArrayRef<int64_t> originalShape = vecType.getShape();
725  SmallVector<int64_t> strides(targetShape->size(), 1);
726 
727  Value base = storeOp.getBase();
728  Value vector = storeOp.getValueToStore();
729 
730  SmallVector<int64_t> loopOrder =
731  getUnrollOrder(originalShape.size(), storeOp, options);
732 
733  for (SmallVector<int64_t> offsets :
734  StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
735  SmallVector<Value> indices =
736  sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
737  Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
738  loc, vector, offsets, *targetShape, strides);
739  vector::StoreOp::create(rewriter, loc, slice, base, indices);
740  }
741  rewriter.eraseOp(storeOp);
742  return success();
743  }
744 
745 private:
747 };
748 
749 struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
750  UnrollBroadcastPattern(MLIRContext *context,
752  PatternBenefit benefit = 1)
753  : OpRewritePattern<vector::BroadcastOp>(context, benefit),
754  options(options) {}
755 
756  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
757  PatternRewriter &rewriter) const override {
758  auto targetShape = getTargetShape(options, broadcastOp);
759  if (!targetShape)
760  return failure();
761 
762  Location loc = broadcastOp.getLoc();
763  VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
764  VectorType resType = broadcastOp.getResultVectorType();
765  VectorType targetType =
766  resType.cloneWith(*targetShape, resType.getElementType());
767  Value result = arith::ConstantOp::create(rewriter, loc, resType,
768  rewriter.getZeroAttr(resType));
769 
770  SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
771  SmallVector<int64_t> strides(originalShape.size(), 1);
772 
773  for (SmallVector<int64_t> offsets :
774  StaticTileOffsetRange(originalShape, *targetShape)) {
775  Value newSrc;
776  if (!srcType) {
777  // Scalar to vector broadcast.
778  newSrc = broadcastOp.getSource();
779  } else {
780  // Vector to vector broadcast.
781  int64_t rank = srcType.getRank();
782  SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
783  SmallVector<int64_t> srcShape(targetShape->end() - rank,
784  targetShape->end());
785  SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
786  // adjust the offset and shape for src if the corresponding dim is 1.
787  for (int64_t i = 0; i < rank; ++i) {
788  if (srcType.getDimSize(i) == 1) {
789  srcOffsets[i] = 0;
790  srcShape[i] = 1;
791  }
792  }
793  newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
794  loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
795  }
796 
797  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
798  newSrc, targetType);
799 
800  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
801  loc, newOp->getResult(0), result, offsets, strides);
802  }
803 
804  rewriter.replaceOp(broadcastOp, result);
805  return success();
806  }
807 
808 private:
810 };
811 
812 } // namespace
813 
814 void mlir::vector::populateVectorUnrollPatterns(
816  PatternBenefit benefit) {
817  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
818  UnrollContractionPattern, UnrollElementwisePattern,
819  UnrollReductionPattern, UnrollMultiReductionPattern,
820  UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
821  UnrollStorePattern, UnrollBroadcastPattern>(
822  patterns.getContext(), options, benefit);
823 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents an operand of an operation.
Definition: Value.h:257
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Options that control the vector unrolling.