MLIR  16.0.0git
VectorUnrollDistribute.cpp
Go to the documentation of this file.
1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to do vector unrolling and vector distribution.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include <numeric>
22 
23 #define DEBUG_TYPE "vector-unrolling"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 
28 /// During unrolling from `originalShape` to `targetShape` return the offset for
29 /// the slice `index`.
31  ArrayRef<int64_t> targetShape,
32  int64_t index) {
33  SmallVector<int64_t, 4> dstSliceStrides =
34  computeStrides(originalShape, targetShape);
35  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
36  SmallVector<int64_t, 4> elementOffsets =
37  computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
38  return elementOffsets;
39 }
40 
41 /// A functor that accomplishes the same thing as `getVectorOffset` but allows
42 /// for reordering the traversal of the dimensions. The order of traversal is
43 /// given in "for loop order" (outer to inner).
44 namespace {
45 class DecomposeShapeIterator {
46 private:
48  SmallVector<int64_t> loopOrder;
49  SmallVector<int64_t> sliceStrides;
50  int64_t maxIndexVal{1};
51 
52 public:
53  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
54  ArrayRef<int64_t> targetShape,
55  ArrayRef<int64_t> loopOrder)
56  : vectorShape(targetShape.begin(), targetShape.end()),
57  loopOrder(loopOrder.begin(), loopOrder.end()),
58  sliceStrides(originalShape.size()) {
59  assert(originalShape.size() == targetShape.size());
60  assert(loopOrder.size() == targetShape.size());
61 
62  // Compute the count for each dimension.
63  SmallVector<int64_t> sliceDimCounts(originalShape.size());
64  for (unsigned r = 0; r < originalShape.size(); ++r) {
65  sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
66  maxIndexVal *= sliceDimCounts[r];
67  }
68 
69  // Reversing "loop order" gives dimensions from fastest varying to slowest
70  // varying (smallest stride to largest stride).
71  int64_t accum = 1;
72  for (auto idx : llvm::reverse(loopOrder)) {
73  sliceStrides[idx] = accum;
74  accum *= sliceDimCounts[idx];
75  }
76  }
77 
78  // Turn the linear index into a d-tuple based on units of vectors of size
79  // `vectorShape`. The linear index is assumed to represent traversal of the
80  // dimensions based on `order`.
81  SmallVector<int64_t> delinearize(int64_t index) const {
82  // Traverse in for loop order (largest stride to smallest stride).
83  SmallVector<int64_t> vectorOffsets(sliceStrides.size());
84  for (auto idx : loopOrder) {
85  vectorOffsets[idx] = index / sliceStrides[idx];
86  index %= sliceStrides[idx];
87  }
88  return vectorOffsets;
89  }
90 
91  int64_t maxIndex() const { return maxIndexVal; }
92 
93  /// Return the offset within d-tuple based on the ordering given by
94  /// `loopOrder`.
95  SmallVector<int64_t> getVectorOffset(int64_t index) const {
96  SmallVector<int64_t> vectorOffsets = delinearize(index);
97  SmallVector<int64_t> elementOffsets =
98  computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
99  return elementOffsets;
100  }
101 };
102 } // namespace
103 
104 /// Compute the indices of the slice `index` for a tranfer op.
106  ArrayRef<Value> indices,
107  AffineMap permutationMap,
108  Location loc,
109  OpBuilder &builder) {
110  MLIRContext *ctx = builder.getContext();
111  auto isBroadcast = [](AffineExpr expr) {
112  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
113  return constExpr.getValue() == 0;
114  return false;
115  };
116  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
117  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
118  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
119  if (isBroadcast(dim.value()))
120  continue;
121  unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
122  auto expr = getAffineDimExpr(0, builder.getContext()) +
123  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
124  auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
125  slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
126  }
127  return slicedIndices;
128 }
129 
130 // Clones `op` into a new operations that takes `operands` and returns
131 // `resultTypes`.
133  Operation *op,
134  ArrayRef<Value> operands,
135  ArrayRef<Type> resultTypes) {
136  return builder.create(loc, op->getName().getIdentifier(), operands,
137  resultTypes, op->getAttrs());
138 }
139 
140 /// Return the target shape for unrolling for the given `op`. Return llvm::None
141 /// if the op shouldn't be or cannot be unrolled.
142 static Optional<SmallVector<int64_t, 4>>
144  if (options.filterConstraint && failed(options.filterConstraint(op)))
145  return llvm::None;
146  assert(options.nativeShape &&
147  "vector unrolling expects the native shape or native"
148  "shape call back function to be set");
149  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
150  if (!unrollableVectorOp)
151  return llvm::None;
152  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
153  if (!maybeUnrollShape)
154  return llvm::None;
155  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
156  if (!targetShape)
157  return llvm::None;
158  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
159  if (!maybeShapeRatio ||
160  llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
161  return llvm::None;
162  return targetShape;
163 }
164 
166 getUnrollOrder(unsigned numLoops, Operation *op,
168  SmallVector<int64_t> loopOrder =
169  llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
170  if (options.traversalOrderCallback != nullptr) {
171  Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
172  if (order) {
173  loopOrder = std::move(*order);
174  }
175  }
176  return loopOrder;
177 }
178 
179 namespace {
180 
181 struct UnrollTransferReadPattern
182  : public OpRewritePattern<vector::TransferReadOp> {
183  UnrollTransferReadPattern(MLIRContext *context,
185  : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
186  options(options) {}
187  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
188  PatternRewriter &rewriter) const override {
189  // TODO: support 0-d corner case.
190  if (readOp.getTransferRank() == 0)
191  return failure();
192  if (readOp.getMask())
193  return failure();
194  auto targetShape = getTargetShape(options, readOp);
195  if (!targetShape)
196  return failure();
197  auto sourceVectorType = readOp.getVectorType();
198  SmallVector<int64_t, 4> strides(targetShape->size(), 1);
199  Location loc = readOp.getLoc();
200  ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
201 
202  // Prepare the result vector;
203  Value result = rewriter.create<arith::ConstantOp>(
204  loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
205  auto targetType =
206  VectorType::get(*targetShape, sourceVectorType.getElementType());
207  SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
208  readOp.getIndices().end());
209 
210  SmallVector<int64_t> loopOrder =
211  getUnrollOrder(originalSize.size(), readOp, options);
212  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
213  loopOrder);
214  for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
215  SmallVector<int64_t, 4> elementOffsets =
216  indexToOffsets.getVectorOffset(i);
217  SmallVector<Value, 4> indices =
218  sliceTransferIndices(elementOffsets, originalIndices,
219  readOp.getPermutationMap(), loc, rewriter);
220  auto slicedRead = rewriter.create<vector::TransferReadOp>(
221  loc, targetType, readOp.getSource(), indices,
222  readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
223  readOp.getInBoundsAttr());
224 
225  result = rewriter.create<vector::InsertStridedSliceOp>(
226  loc, slicedRead, result, elementOffsets, strides);
227  }
228  rewriter.replaceOp(readOp, result);
229  return success();
230  }
231 
232 private:
234 };
235 
236 struct UnrollTransferWritePattern
237  : public OpRewritePattern<vector::TransferWriteOp> {
238  UnrollTransferWritePattern(MLIRContext *context,
240  : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
241  options(options) {}
242  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
243  PatternRewriter &rewriter) const override {
244  // TODO: support 0-d corner case.
245  if (writeOp.getTransferRank() == 0)
246  return failure();
247 
248  if (writeOp.getMask())
249  return failure();
250  auto targetShape = getTargetShape(options, writeOp);
251  if (!targetShape)
252  return failure();
253  auto sourceVectorType = writeOp.getVectorType();
254  SmallVector<int64_t, 4> strides(targetShape->size(), 1);
255  Location loc = writeOp.getLoc();
256  ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
257  SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
258  writeOp.getIndices().end());
259 
260  SmallVector<int64_t> loopOrder =
261  getUnrollOrder(originalSize.size(), writeOp, options);
262  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
263  loopOrder);
264  Value resultTensor;
265  for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
266  SmallVector<int64_t, 4> elementOffsets =
267  indexToOffsets.getVectorOffset(i);
268  Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
269  loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
270  SmallVector<Value, 4> indices =
271  sliceTransferIndices(elementOffsets, originalIndices,
272  writeOp.getPermutationMap(), loc, rewriter);
273  Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
274  loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
275  indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
276  // For the tensor case update the destination for the next transfer write.
277  if (!slicedWrite->getResults().empty())
278  resultTensor = slicedWrite->getResult(0);
279  }
280  if (resultTensor)
281  rewriter.replaceOp(writeOp, resultTensor);
282  else
283  rewriter.eraseOp(writeOp);
284  return success();
285  }
286 
287 private:
289 };
290 
291 struct OffsetMapInfo {
292  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
293 
294  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
295 
296  static unsigned getHashValue(const SmallVector<int64_t> &v) {
297  return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
298  }
299 
300  static bool isEqual(const SmallVector<int64_t> &lhs,
301  const SmallVector<int64_t> &rhs) {
302  return lhs == rhs;
303  }
304 };
305 
306 struct UnrollContractionPattern
307  : public OpRewritePattern<vector::ContractionOp> {
308  UnrollContractionPattern(MLIRContext *context,
310  : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
311  options(options) {}
312 
313  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
314  PatternRewriter &rewriter) const override {
315  auto targetShape = getTargetShape(options, contractOp);
316  if (!targetShape)
317  return failure();
318  auto dstVecType = contractOp.getResultType().cast<VectorType>();
319  SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
320 
321  Location loc = contractOp.getLoc();
322  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323  AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
324  llvm::MapVector<
327  accCache;
328 
329  SmallVector<int64_t> loopOrder = getUnrollOrder(
330  contractOp.getIteratorTypes().size(), contractOp, options);
331  DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
332  loopOrder);
333  const int64_t sliceCount = indexToOffsets.maxIndex();
334  for (int64_t i = 0; i < sliceCount; i++) {
335  SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
336  SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
337 
338  // Helper to coompute the new shape of each operand and extract the slice.
339  auto extractOperand = [&](unsigned index, Value operand,
340  AffineMap permutationMap,
341  ArrayRef<int64_t> operandOffets) {
342  SmallVector<int64_t> operandShape = applyPermutationMap(
343  permutationMap, ArrayRef<int64_t>(*targetShape));
344  SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
345  slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
346  loc, operand, operandOffets, operandShape, operandStrides);
347  };
348 
349  // Extract the new lhs operand.
350  AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
351  SmallVector<int64_t> lhsOffets =
352  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
353  extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
354  // If there is a mask associated to lhs, extract it as well.
355  if (slicesOperands.size() > 3)
356  extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
357  lhsOffets);
358 
359  // Extract the new rhs operand.
360  AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
361  SmallVector<int64_t> rhsOffets =
362  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
363  extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
364  // If there is a mask associated to rhs, extract it as well.
365  if (slicesOperands.size() > 4)
366  extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
367  rhsOffets);
368 
369  AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
370  SmallVector<int64_t> accOffets =
371  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
372  // If a version of the accumulator has already been computed, use it
373  // otherwise extract the first version from the original operand.
374  auto accIt = accCache.find(accOffets);
375  if (accIt != accCache.end())
376  slicesOperands[2] = accIt->second;
377  else
378  extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
379 
380  SmallVector<int64_t> dstShape =
381  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
382  auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
384  rewriter, loc, contractOp, slicesOperands, targetType);
385 
386  SmallVector<int64_t> dstOffets =
387  applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
388  // Save the accumulated value untill all the loops are unrolled since
389  // reduction loop keep updating the accumulator.
390  accCache[dstOffets] = newOp->getResult(0);
391  }
392  // Assemble back the accumulator into a single vector.
393  Value result = rewriter.create<arith::ConstantOp>(
394  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
395  for (const auto &it : accCache) {
396  SmallVector<int64_t> dstStrides(it.first.size(), 1);
397  result = rewriter.create<vector::InsertStridedSliceOp>(
398  loc, it.second, result, it.first, dstStrides);
399  }
400  rewriter.replaceOp(contractOp, result);
401  return success();
402  }
403 
404 private:
406 };
407 
408 struct UnrollMultiReductionPattern
409  : public OpRewritePattern<vector::MultiDimReductionOp> {
410  UnrollMultiReductionPattern(MLIRContext *context,
412  : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
413  options(options) {}
414 
415  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
416  PatternRewriter &rewriter) const override {
417  Optional<SmallVector<int64_t, 4>> targetShape =
418  getTargetShape(options, reductionOp);
419  if (!targetShape)
420  return failure();
421  SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
422  SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
423  llvm::MapVector<
426  accCache;
427  // Compute shape ratio of 'shape' and 'sizes'.
428  int64_t sliceCount = computeMaxLinearIndex(ratio);
429  Location loc = reductionOp.getLoc();
430  for (int64_t i = 0; i < sliceCount; i++) {
431  SmallVector<int64_t, 4> offsets =
432  getVectorOffset(originalSize, *targetShape, i);
433 
434  SmallVector<Value> operands;
435  SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
436  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
437  loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
438  operands.push_back(slicedOperand);
439  SmallVector<int64_t> dstShape;
440  SmallVector<int64_t> destOffset;
441  for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
442  if (!reductionOp.isReducedDim(i)) {
443  destOffset.push_back(offsets[i]);
444  dstShape.push_back((*targetShape)[i]);
445  }
446  }
447  Value acc;
448  SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
449  // If a version of the accumulator has already been computed, use it
450  // otherwise extract the first version from the original operand.
451  auto accIt = accCache.find(destOffset);
452  if (accIt != accCache.end())
453  acc = accIt->second;
454  else
455  acc = rewriter.create<vector::ExtractStridedSliceOp>(
456  loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
457  operands.push_back(acc);
458  auto targetType = VectorType::get(
459  dstShape, reductionOp.getSourceVectorType().getElementType());
460  Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
461  operands, targetType);
462  Value result = newOp->getResult(0);
463  accCache[destOffset] = result;
464  }
465  // Assemble back the accumulator into a single vector.
466  Value result = rewriter.create<arith::ConstantOp>(
467  loc, reductionOp.getDestType(),
468  rewriter.getZeroAttr(reductionOp.getDestType()));
469  for (const auto &it : accCache) {
470  SmallVector<int64_t> dstStrides(it.first.size(), 1);
471  result = rewriter.create<vector::InsertStridedSliceOp>(
472  loc, it.second, result, it.first, dstStrides);
473  }
474  rewriter.replaceOp(reductionOp, result);
475  return success();
476  }
477 
478 private:
480 };
481 
482 struct UnrollElementwisePattern : public RewritePattern {
483  UnrollElementwisePattern(MLIRContext *context,
485  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
486  options(options) {}
487  LogicalResult matchAndRewrite(Operation *op,
488  PatternRewriter &rewriter) const override {
490  return failure();
491  auto targetShape = getTargetShape(options, op);
492  if (!targetShape)
493  return failure();
494  auto dstVecType = op->getResult(0).getType().cast<VectorType>();
495  SmallVector<int64_t, 4> originalSize =
496  *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
497  SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
498  int64_t sliceCount = computeMaxLinearIndex(ratio);
499  Location loc = op->getLoc();
500  // Prepare the result vector.
501  Value result = rewriter.create<arith::ConstantOp>(
502  loc, dstVecType, rewriter.getZeroAttr(dstVecType));
503  SmallVector<int64_t, 4> strides(targetShape->size(), 1);
504  VectorType newVecType =
505  VectorType::get(*targetShape, dstVecType.getElementType());
506  for (int64_t i = 0; i < sliceCount; i++) {
507  SmallVector<int64_t, 4> offsets =
508  getVectorOffset(originalSize, *targetShape, i);
509  SmallVector<Value, 4> extractOperands;
510  for (OpOperand &operand : op->getOpOperands()) {
511  auto vecType = operand.get().getType().template dyn_cast<VectorType>();
512  if (!vecType) {
513  extractOperands.push_back(operand.get());
514  continue;
515  }
516  extractOperands.push_back(
517  rewriter.create<vector::ExtractStridedSliceOp>(
518  loc, operand.get(), offsets, *targetShape, strides));
519  }
521  rewriter, loc, op, extractOperands, newVecType);
522  result = rewriter.create<vector::InsertStridedSliceOp>(
523  loc, newOp->getResult(0), result, offsets, strides);
524  }
525  rewriter.replaceOp(op, result);
526  return success();
527  }
528 
529 private:
531 };
532 
533 /// Canonicalize an extract_map using the result of a pointwise operation.
534 /// Transforms:
535 /// %v = arith.addf %a, %b : vector32xf32>
536 /// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
537 /// to:
538 /// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
539 /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
540 /// %dv = arith.addf %da, %db : vector<1xf32>
541 struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
543  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
544  PatternRewriter &rewriter) const override {
545  Operation *definedOp = extract.getVector().getDefiningOp();
546  if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
547  definedOp->getNumResults() != 1)
548  return failure();
549  Location loc = extract.getLoc();
550  SmallVector<Value, 4> extractOperands;
551  for (OpOperand &operand : definedOp->getOpOperands()) {
552  auto vecType = operand.get().getType().template dyn_cast<VectorType>();
553  if (!vecType) {
554  extractOperands.push_back(operand.get());
555  continue;
556  }
557  extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
558  loc,
559  VectorType::get(extract.getResultType().getShape(),
560  vecType.getElementType()),
561  operand.get(), extract.getIds()));
562  }
564  rewriter, loc, definedOp, extractOperands, extract.getResultType());
565  rewriter.replaceOp(extract, newOp->getResult(0));
566  return success();
567  }
568 };
569 
570 /// Canonicalize an extract_map using the result of a contract operation.
571 /// This propagate the extract_map to operands.
572 struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
574  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
575  PatternRewriter &rewriter) const override {
576  Operation *definedOp = extract.getVector().getDefiningOp();
577  auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
578  if (!contract)
579  return failure();
580  Location loc = contract.getLoc();
581  unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
582  AffineMap affineMap = contract.getIndexingMapsArray()[accIndex];
583  // Create a map of the dimensions distributed based on the acc affine map.
584  // Only parallel dimensions are being distributed, reduction dimensions are
585  // untouched.
587  for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
588  map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
589  SmallVector<Value, 4> extractOperands;
590  for (const auto &it : llvm::enumerate(contract.getIndexingMapsArray())) {
591  // For each operands calculate the new vector type after distribution.
592  Value operand = contract->getOperand(it.index());
593  auto vecType = operand.getType().cast<VectorType>();
594  SmallVector<int64_t> operandShape(vecType.getShape().begin(),
595  vecType.getShape().end());
596  for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
597  unsigned dim = it.value().getDimPosition(i);
598  auto distributedDim = map.find(dim);
599  // If the dimension is not in the map it means it is a reduction and
600  // doesn't get distributed.
601  if (distributedDim == map.end())
602  continue;
603  operandShape[i] = distributedDim->second;
604  }
605  VectorType newVecType =
606  VectorType::get(operandShape, vecType.getElementType());
607  extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
608  loc, newVecType, operand, extract.getIds()));
609  }
610  Operation *newOp =
611  cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
612  extract.getResult().getType());
613  rewriter.replaceOp(extract, newOp->getResult(0));
614  return success();
615  }
616 };
617 
618 /// Converts TransferRead op used by ExtractMap op into a smaller dimension
619 /// TransferRead.
620 /// Example:
621 /// ```
622 /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
623 /// memref<64x64x64xf32>, vector<64x4x32xf32>
624 /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
625 /// ```
626 /// to:
627 /// ```
628 /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
629 /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
630 /// memref<64x64x64xf32>, vector<2x4x1xf32>
631 /// ```
632 struct TransferReadExtractPattern
633  : public OpRewritePattern<vector::TransferReadOp> {
634  TransferReadExtractPattern(MLIRContext *context)
636  LogicalResult matchAndRewrite(vector::TransferReadOp read,
637  PatternRewriter &rewriter) const override {
638  // TODO: support 0-d corner case.
639  if (read.getTransferRank() == 0)
640  return failure();
641 
642  if (!read.getResult().hasOneUse())
643  return failure();
644  auto extract =
645  dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
646  if (!extract)
647  return failure();
648  if (read.getMask())
649  return failure();
650 
651  SmallVector<Value, 4> indices(read.getIndices().begin(),
652  read.getIndices().end());
653  AffineMap indexMap = extract.map().compose(read.getPermutationMap());
654  unsigned idCount = 0;
655  ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
656  for (auto it :
657  llvm::zip(indexMap.getResults(), extract.map().getResults())) {
658  AffineExpr d0, d1;
659  bindDims(read.getContext(), d0, d1);
660  auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
661  if (!indexExpr)
662  continue;
663  unsigned indexPos = indexExpr.getPosition();
664  unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
665  auto scale = getAffineConstantExpr(
666  extract.getResultType().getDimSize(vectorPos), read.getContext());
667  indices[indexPos] = makeComposedAffineApply(
668  rewriter, read.getLoc(), d0 + scale * d1,
669  {indices[indexPos], extract.getIds()[idCount++]});
670  }
671  Value newRead = lb.create<vector::TransferReadOp>(
672  extract.getType(), read.getSource(), indices,
673  read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
674  read.getInBoundsAttr());
675  Value dest = lb.create<arith::ConstantOp>(
676  read.getType(), rewriter.getZeroAttr(read.getType()));
677  newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
678  rewriter.replaceOp(read, newRead);
679  return success();
680  }
681 };
682 
683 struct TransferWriteInsertPattern
684  : public OpRewritePattern<vector::TransferWriteOp> {
685  TransferWriteInsertPattern(MLIRContext *context)
687  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
688  PatternRewriter &rewriter) const override {
689  // TODO: support 0-d corner case.
690  if (write.getTransferRank() == 0)
691  return failure();
692 
693  auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
694  if (!insert)
695  return failure();
696  if (write.getMask())
697  return failure();
698  SmallVector<Value, 4> indices(write.getIndices().begin(),
699  write.getIndices().end());
700  AffineMap indexMap = insert.map().compose(write.getPermutationMap());
701  unsigned idCount = 0;
702  Location loc = write.getLoc();
703  for (auto it :
704  llvm::zip(indexMap.getResults(), insert.map().getResults())) {
705  AffineExpr d0, d1;
706  bindDims(write.getContext(), d0, d1);
707  auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
708  if (!indexExpr)
709  continue;
710  unsigned indexPos = indexExpr.getPosition();
711  unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
712  auto scale = getAffineConstantExpr(
713  insert.getSourceVectorType().getDimSize(vectorPos),
714  write.getContext());
715  indices[indexPos] = makeComposedAffineApply(
716  rewriter, loc, d0 + scale * d1,
717  {indices[indexPos], insert.getIds()[idCount++]});
718  }
719  rewriter.create<vector::TransferWriteOp>(
720  loc, insert.getVector(), write.getSource(), indices,
721  write.getPermutationMapAttr(), write.getInBoundsAttr());
722  rewriter.eraseOp(write);
723  return success();
724  }
725 };
726 
727 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
728  UnrollReductionPattern(MLIRContext *context,
730  : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
731  options(options) {}
732 
733  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
734  PatternRewriter &rewriter) const override {
735  Optional<SmallVector<int64_t, 4>> targetShape =
736  getTargetShape(options, reductionOp);
737  if (!targetShape)
738  return failure();
739  SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
740  int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
741 
742  // Create unrolled vector reduction.
743  Location loc = reductionOp.getLoc();
744  Value accumulator = nullptr;
745  for (int64_t i = 0; i < ratio; ++i) {
746  SmallVector<int64_t> offsets =
747  getVectorOffset(originalSize, *targetShape, i);
748  SmallVector<int64_t> strides(offsets.size(), 1);
749  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
750  loc, reductionOp.getVector(), offsets, *targetShape, strides);
752  rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
753  Value result = newOp->getResult(0);
754 
755  if (!accumulator) {
756  // This is the first reduction.
757  accumulator = result;
758  } else {
759  // On subsequent reduction, combine with the accumulator.
760  accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
761  accumulator, result);
762  }
763  }
764 
765  rewriter.replaceOp(reductionOp, accumulator);
766  return success();
767  }
768 
769 private:
771 };
772 
773 struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
774  UnrollTranposePattern(MLIRContext *context,
776  : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
777  options(options) {}
778  LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
779  PatternRewriter &rewriter) const override {
780  if (tranposeOp.getResultType().getRank() == 0)
781  return failure();
782  auto targetShape = getTargetShape(options, tranposeOp);
783  if (!targetShape)
784  return failure();
785  auto originalVectorType = tranposeOp.getResultType();
786  SmallVector<int64_t, 4> strides(targetShape->size(), 1);
787  Location loc = tranposeOp.getLoc();
788  ArrayRef<int64_t> originalSize = originalVectorType.getShape();
789  SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
790  int64_t sliceCount = computeMaxLinearIndex(ratio);
791  // Prepare the result vector;
792  Value result = rewriter.create<arith::ConstantOp>(
793  loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
794  SmallVector<int64_t> permutation;
795  tranposeOp.getTransp(permutation);
796  for (int64_t i = 0; i < sliceCount; i++) {
797  SmallVector<int64_t, 4> elementOffsets =
798  getVectorOffset(originalSize, *targetShape, i);
799  SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
800  SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
801  // Compute the source offsets and shape.
802  for (auto &indices : llvm::enumerate(permutation)) {
803  permutedOffsets[indices.value()] = elementOffsets[indices.index()];
804  permutedShape[indices.value()] = (*targetShape)[indices.index()];
805  }
806  Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
807  loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
808  Value tranposedSlice =
809  rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
810  result = rewriter.create<vector::InsertStridedSliceOp>(
811  loc, tranposedSlice, result, elementOffsets, strides);
812  }
813  rewriter.replaceOp(tranposeOp, result);
814  return success();
815  }
816 
817 private:
819 };
820 
821 } // namespace
822 
824  RewritePatternSet &patterns, const UnrollVectorOptions &options) {
825  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
826  UnrollContractionPattern, UnrollElementwisePattern,
827  UnrollReductionPattern, UnrollMultiReductionPattern,
828  UnrollTranposePattern>(patterns.getContext(), options);
829 }
830 
832  RewritePatternSet &patterns) {
833  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
834  TransferReadExtractPattern, TransferWriteInsertPattern>(
835  patterns.getContext());
836 }
Include the generated interface declarations.
SmallVector< int64_t, 4 > computeStrides(ArrayRef< int64_t > shape, ArrayRef< int64_t > sizes)
Given the shape and sizes of a vector, returns the corresponding strides for each dimension...
Definition: VectorUtils.cpp:54
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Optional< SmallVector< int64_t, 4 > > shapeRatio(ArrayRef< int64_t > superShape, ArrayRef< int64_t > subShape)
Computes and returns the multi-dimensional ratio of superShape to subShape.
Definition: VectorUtils.cpp:77
MLIRContext * getContext() const
Definition: Builders.h:54
SmallVector< int64_t, 4 > computeElementOffsetsFromVectorSliceOffsets(ArrayRef< int64_t > sizes, ArrayRef< int64_t > vectorOffsets)
Given the target sizes of a vector, together with vector-space offsets, returns the element-space off...
Definition: VectorUtils.cpp:69
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:514
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1122
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:798
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
static ArrayRef< int64_t > vectorShape(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Options that control the vector unrolling.
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:300
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation. ...
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:548
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
UnrollTraversalOrderFnType traversalOrderCallback
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Base type for affine expression.
Definition: AffineExpr.h:68
SmallVector< int64_t, 4 > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
unsigned getNumResults() const
Definition: AffineMap.cpp:302
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis, 0 if empty.
Definition: VectorUtils.cpp:47
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:42
static Optional< SmallVector< int64_t, 4 > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:315
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
static llvm::ManagedStatic< PassManagerOptions > options
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to propagate insert_map/extract_map in the ssa chain.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options)
Collect a set of pattern to unroll vector operations to a smaller shapes.
Type getType() const
Return the type of this value.
Definition: Value.h:118
Do not split vector transfer operations.
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class represents an operand of an operation.
Definition: Value.h:251
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:328
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
static SmallVector< int64_t, 4 > getVectorOffset(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > targetShape, int64_t index)
During unrolling from originalShape to targetShape return the offset for the slice index...
This class helps build Operations.
Definition: Builders.h:192
MLIRContext * getContext() const
U cast() const
Definition: Types.h:278