MLIR  19.0.0git
LowerVectorTransfer.cpp
Go to the documentation of this file.
1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
24 /// permutation based on the given indices.
25 static ArrayAttr
26 inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27  const SmallVector<unsigned> &permutation) {
28  SmallVector<bool> newInBoundsValues(permutation.size());
29  size_t index = 0;
30  for (unsigned pos : permutation)
31  newInBoundsValues[pos] =
32  cast<BoolAttr>(attr.getValue()[index++]).getValue();
33  return builder.getBoolArrayAttr(newInBoundsValues);
34 }
35 
36 /// Extend the rank of a vector Value by `addedRanks` by adding outer unit
37 /// dimensions.
38 static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
39  int64_t addedRank) {
40  auto originalVecType = cast<VectorType>(vec.getType());
41  SmallVector<int64_t> newShape(addedRank, 1);
42  newShape.append(originalVecType.getShape().begin(),
43  originalVecType.getShape().end());
44 
45  SmallVector<bool> newScalableDims(addedRank, false);
46  newScalableDims.append(originalVecType.getScalableDims().begin(),
47  originalVecType.getScalableDims().end());
48  VectorType newVecType = VectorType::get(
49  newShape, originalVecType.getElementType(), newScalableDims);
50  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
51 }
52 
53 /// Extend the rank of a vector Value by `addedRanks` by adding inner unit
54 /// dimensions.
55 static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
56  int64_t addedRank) {
57  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
58  SmallVector<int64_t> permutation;
59  for (int64_t i = addedRank,
60  e = cast<VectorType>(broadcasted.getType()).getRank();
61  i < e; ++i)
62  permutation.push_back(i);
63  for (int64_t i = 0; i < addedRank; ++i)
64  permutation.push_back(i);
65  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // populateVectorTransferPermutationMapLoweringPatterns
70 //===----------------------------------------------------------------------===//
71 
72 namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  // NOTE(review): the inherited-constructor using-declaration that normally
  // appears here seems to have been lost in extraction — confirm against the
  // original source.

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    // If the deduced permutation is the identity there is nothing to do.
    // NOTE(review): the failure message reads "map is not identity" but this
    // branch fires when the permutation *is* the identity — the message looks
    // inverted; confirm upstream.
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is not identity");

    // NOTE(review): this recomputes the same value assigned to
    // `permutationMap` above; the recomputation appears redundant.
    permutationMap = map.getPermutationMap(permutation, op.getContext());
    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    // Scatter both the static sizes and the scalability flags through the
    // permutation so the intermediate read type matches the permuted map.
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute (absent attribute is propagated as-is).
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read back to the originally requested
    // layout.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                     transposePerm);
    return success();
  }
};
150 
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  // NOTE(review): the inherited-constructor using-declaration that normally
  // appears here seems to have been lost in extraction.

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    // Presumably every result of the inverted map is an AffineDimExpr here, so
    // the dyn_cast cannot return null — TODO confirm.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose in_bounds attribute (absent attribute is propagated as-is).
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_write operation: transpose the written vector,
    // then write it with a minor identity map.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        op.getMask(), newInBoundsAttr);

    return success();
  }
};
217 
/// Convert a transfer.write op with a map which isn't the permutation of a
/// minor identity into a vector.broadcast + transfer_write with permutation of
/// minor identity map by adding unit dim on inner dimension. Ex:
/// ```
/// vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
/// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
/// vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  // NOTE(review): the inherited-constructor using-declaration that normally
  // appears here seems to have been lost in extraction.

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed, find the most outer existing
    // dimension then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    // NOTE(review): a declaration of `exprs` (likely
    // `SmallVector<AffineExpr> exprs;`) appears to have been lost in
    // extraction — `exprs` is used below without a visible declaration.
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we found one outer dimension existing in the map keep track of all
      // the missing dimensions after that.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    // Append the original map results after the dim exprs created for the
    // newly materialized dimensions.
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the new dimensions added are inbound.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        newMask, newInBoundsAttr);
    return success();
  }
};
294 
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
  // NOTE(review): the inherited-constructor using-declaration that normally
  // appears here seems to have been lost in extraction.

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    AffineMap map = op.getPermutationMap();
    // Count the leading constant-zero (broadcast) results of the map.
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                       newRead);
      return success();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    // NOTE(review): since reducedShapeRank == 0 already returned above, this
    // check appears unreachable — confirm whether it is defensive only.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    // Build the reduced-rank read; keep only the trailing in_bounds entries.
    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                     newRead);
    return success();
  }
};
379 
380 } // namespace
381 
// NOTE(review): the signature line of this function (presumably
// `void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(`)
// appears to have been lost in extraction — confirm against the original
// source.
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Register the permutation-map simplification patterns defined above.
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
389 
390 //===----------------------------------------------------------------------===//
391 // populateVectorTransferLoweringPatterns
392 //===----------------------------------------------------------------------===//
393 
394 namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  /// `maxRank` (if set) caps the vector rank this pattern will lower; reads of
  /// higher rank are left for other patterns (e.g. VectorToSCF).
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`. Broadcast dims
    // are loaded with unit extent.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op. Masked reads become vector.maskedload, with the
    // padding value splatted as the pass-through vector.
    Operation *loadOp;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  /// If set, reads of rank greater than this are not lowered by this pattern.
  std::optional<unsigned> maxTransferRank;
};
492 
/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  // NOTE(review): the inherited-constructor using-declaration that normally
  // appears here seems to have been lost in extraction.

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    // Matches any single-element vector (0-d or all unit dims), not only the
    // strictly 0-d case.
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    // Load the scalar directly and broadcast it back to the vector type.
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};
518 
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  // NOTE(review): the inherited-constructor using-declaration that normally
  // appears here seems to have been lost in extraction.

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    // Matches any single-element vector (0-d or all unit dims).
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    // Extract the lone scalar from the stored vector.
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    // Store the scalar directly into the memref.
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};
546 
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  /// `maxRank` (if set) caps the vector rank this pattern will lower; writes
  /// of higher rank are left for other patterns (e.g. VectorToSCF).
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    // Masked writes become vector.maskedstore (1-D only); unmasked writes
    // become plain vector.store.
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  /// If set, writes of rank greater than this are not lowered by this pattern.
  std::optional<unsigned> maxTransferRank;
};
632 } // namespace
633 
// NOTE(review): the signature line of this function (presumably
// `void mlir::vector::populateVectorTransferLoweringPatterns(`) appears to
// have been lost in extraction — confirm against the original source.
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  // Rank-capped transfer_read/transfer_write lowerings to vector.load/store.
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  // Single-element vector.load/store lowerings to scalar memref ops.
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the given indices.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static VectorShape vectorShape(Type type)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:132
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:152
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
Definition: AffineMap.cpp:160
unsigned getNumDims() const
Definition: AffineMap.cpp:378
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:200
unsigned getNumResults() const
Definition: AffineMap.cpp:386
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:540
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:329
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:277
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:753
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:683
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362