MLIR  18.0.0git
LowerVectorTransfer.cpp
Go to the documentation of this file.
1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
24 /// permutation based on the given indices.
25 static ArrayAttr
26 inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27  const SmallVector<unsigned> &permutation) {
28  SmallVector<bool> newInBoundsValues(permutation.size());
29  size_t index = 0;
30  for (unsigned pos : permutation)
31  newInBoundsValues[pos] =
32  cast<BoolAttr>(attr.getValue()[index++]).getValue();
33  return builder.getBoolArrayAttr(newInBoundsValues);
34 }
35 
36 /// Extend the rank of a vector Value by `addedRanks` by adding outer unit
37 /// dimensions.
38 static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
39  int64_t addedRank) {
40  auto originalVecType = cast<VectorType>(vec.getType());
41  SmallVector<int64_t> newShape(addedRank, 1);
42  newShape.append(originalVecType.getShape().begin(),
43  originalVecType.getShape().end());
44  VectorType newVecType =
45  VectorType::get(newShape, originalVecType.getElementType());
46  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
47 }
48 
49 /// Extend the rank of a vector Value by `addedRanks` by adding inner unit
50 /// dimensions.
51 static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
52  int64_t addedRank) {
53  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
54  SmallVector<int64_t> permutation;
55  for (int64_t i = addedRank,
56  e = broadcasted.getType().cast<VectorType>().getRank();
57  i < e; ++i)
58  permutation.push_back(i);
59  for (int64_t i = 0; i < addedRank; ++i)
60  permutation.push_back(i);
61  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // populateVectorTransferPermutationMapLoweringPatterns
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 /// Lower transfer_read op with permutation into a transfer_read with a
70 /// permutation map composed of leading zeros followed by a minor identiy +
71 /// vector.transpose op.
72 /// Ex:
73 /// vector.transfer_read ...
74 /// permutation_map: (d0, d1, d2) -> (0, d1)
75 /// into:
76 /// %v = vector.transfer_read ...
77 /// permutation_map: (d0, d1, d2) -> (d1, 0)
78 /// vector.transpose %v, [1, 0]
79 ///
80 /// vector.transfer_read ...
81 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
82 /// into:
83 /// %v = vector.transfer_read ...
84 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
85 /// vector.transpose %v, [0, 1, 3, 2, 4]
86 /// Note that an alternative is to transform it to linalg.transpose +
87 /// vector.transfer_read to do the transpose in memory instead.
88 struct TransferReadPermutationLowering
89  : public OpRewritePattern<vector::TransferReadOp> {
91 
92  LogicalResult matchAndRewrite(vector::TransferReadOp op,
93  PatternRewriter &rewriter) const override {
94  // TODO: support 0-d corner case.
95  if (op.getTransferRank() == 0)
96  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
97 
98  SmallVector<unsigned> permutation;
99  AffineMap map = op.getPermutationMap();
100  if (map.getNumResults() == 0)
101  return rewriter.notifyMatchFailure(op, "0 result permutation map");
102  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
103  return rewriter.notifyMatchFailure(
104  op, "map is not permutable to minor identity, apply another pattern");
105  }
106  AffineMap permutationMap =
107  map.getPermutationMap(permutation, op.getContext());
108  if (permutationMap.isIdentity())
109  return rewriter.notifyMatchFailure(op, "map is not identity");
110 
111  permutationMap = map.getPermutationMap(permutation, op.getContext());
112  // Caluclate the map of the new read by applying the inverse permutation.
113  permutationMap = inversePermutation(permutationMap);
114  AffineMap newMap = permutationMap.compose(map);
115  // Apply the reverse transpose to deduce the type of the transfer_read.
116  ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
117  SmallVector<int64_t> newVectorShape(originalShape.size());
118  ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
119  SmallVector<bool> newScalableDims(originalShape.size());
120  for (const auto &pos : llvm::enumerate(permutation)) {
121  newVectorShape[pos.value()] = originalShape[pos.index()];
122  newScalableDims[pos.value()] = originalScalableDims[pos.index()];
123  }
124 
125  // Transpose in_bounds attribute.
126  ArrayAttr newInBoundsAttr =
127  op.getInBounds() ? inverseTransposeInBoundsAttr(
128  rewriter, op.getInBounds().value(), permutation)
129  : ArrayAttr();
130 
131  // Generate new transfer_read operation.
132  VectorType newReadType = VectorType::get(
133  newVectorShape, op.getVectorType().getElementType(), newScalableDims);
134  Value newRead = rewriter.create<vector::TransferReadOp>(
135  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
136  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
137  newInBoundsAttr);
138 
139  // Transpose result of transfer_read.
140  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
141  rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
142  transposePerm);
143  return success();
144  }
145 };
146 
147 /// Lower transfer_write op with permutation into a transfer_write with a
148 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
149 /// Ex:
150 /// vector.transfer_write %v ...
151 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
152 /// into:
153 /// %tmp = vector.transpose %v, [2, 0, 1]
154 /// vector.transfer_write %tmp ...
155 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
156 ///
157 /// vector.transfer_write %v ...
158 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
159 /// into:
160 /// %tmp = vector.transpose %v, [1, 0]
161 /// %v = vector.transfer_write %tmp ...
162 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
163 struct TransferWritePermutationLowering
164  : public OpRewritePattern<vector::TransferWriteOp> {
166 
167  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
168  PatternRewriter &rewriter) const override {
169  // TODO: support 0-d corner case.
170  if (op.getTransferRank() == 0)
171  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
172 
173  SmallVector<unsigned> permutation;
174  AffineMap map = op.getPermutationMap();
175  if (map.isMinorIdentity())
176  return rewriter.notifyMatchFailure(op, "map is already minor identity");
177 
178  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
179  return rewriter.notifyMatchFailure(
180  op, "map is not permutable to minor identity, apply another pattern");
181  }
182 
183  // Remove unused dims from the permutation map. E.g.:
184  // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
185  // comp = (d0, d1, d2) -> (d2, d0, d1)
186  auto comp = compressUnusedDims(map);
187  AffineMap permutationMap = inversePermutation(comp);
188  // Get positions of remaining result dims.
189  SmallVector<int64_t> indices;
190  llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
191  [](AffineExpr expr) {
192  return dyn_cast<AffineDimExpr>(expr).getPosition();
193  });
194 
195  // Transpose in_bounds attribute.
196  ArrayAttr newInBoundsAttr =
197  op.getInBounds() ? inverseTransposeInBoundsAttr(
198  rewriter, op.getInBounds().value(), permutation)
199  : ArrayAttr();
200 
201  // Generate new transfer_write operation.
202  Value newVec = rewriter.create<vector::TransposeOp>(
203  op.getLoc(), op.getVector(), indices);
204  auto newMap = AffineMap::getMinorIdentityMap(
205  map.getNumDims(), map.getNumResults(), rewriter.getContext());
206  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
207  op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
208  op.getMask(), newInBoundsAttr);
209 
210  return success();
211  }
212 };
213 
214 /// Convert a transfer.write op with a map which isn't the permutation of a
215 /// minor identity into a vector.broadcast + transfer_write with permutation of
216 /// minor identity map by adding unit dim on inner dimension. Ex:
217 /// ```
218 /// vector.transfer_write %v
219 /// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
220 /// vector<8x16xf32>
221 /// ```
222 /// into:
223 /// ```
224 /// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
225 /// vector.transfer_write %v1
226 /// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
227 /// vector<1x8x16xf32>
228 /// ```
229 struct TransferWriteNonPermutationLowering
230  : public OpRewritePattern<vector::TransferWriteOp> {
232 
233  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
234  PatternRewriter &rewriter) const override {
235  // TODO: support 0-d corner case.
236  if (op.getTransferRank() == 0)
237  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
238 
239  SmallVector<unsigned> permutation;
240  AffineMap map = op.getPermutationMap();
241  if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
242  return rewriter.notifyMatchFailure(
243  op,
244  "map is already permutable to minor identity, apply another pattern");
245  }
246 
247  // Missing outer dimensions are allowed, find the most outer existing
248  // dimension then deduce the missing inner dimensions.
249  SmallVector<bool> foundDim(map.getNumDims(), false);
250  for (AffineExpr exp : map.getResults())
251  foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
253  bool foundFirstDim = false;
254  SmallVector<int64_t> missingInnerDim;
255  for (size_t i = 0; i < foundDim.size(); i++) {
256  if (foundDim[i]) {
257  foundFirstDim = true;
258  continue;
259  }
260  if (!foundFirstDim)
261  continue;
262  // Once we found one outer dimension existing in the map keep track of all
263  // the missing dimensions after that.
264  missingInnerDim.push_back(i);
265  exprs.push_back(rewriter.getAffineDimExpr(i));
266  }
267  // Vector: add unit dims at the beginning of the shape.
268  Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
269  missingInnerDim.size());
270  // Mask: add unit dims at the end of the shape.
271  Value newMask;
272  if (op.getMask())
273  newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
274  missingInnerDim.size());
275  exprs.append(map.getResults().begin(), map.getResults().end());
276  AffineMap newMap =
277  AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
278  // All the new dimensions added are inbound.
279  SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
280  for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
281  newInBoundsValues.push_back(op.isDimInBounds(i));
282  }
283  ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
284  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
285  op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
286  newMask, newInBoundsAttr);
287  return success();
288  }
289 };
290 
291 /// Lower transfer_read op with broadcast in the leading dimensions into
292 /// transfer_read of lower rank + vector.broadcast.
293 /// Ex: vector.transfer_read ...
294 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
295 /// into:
296 /// %v = vector.transfer_read ...
297 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
298 /// vector.broadcast %v
299 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
301 
302  LogicalResult matchAndRewrite(vector::TransferReadOp op,
303  PatternRewriter &rewriter) const override {
304  // TODO: support 0-d corner case.
305  if (op.getTransferRank() == 0)
306  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
307 
308  AffineMap map = op.getPermutationMap();
309  unsigned numLeadingBroadcast = 0;
310  for (auto expr : map.getResults()) {
311  auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
312  if (!dimExpr || dimExpr.getValue() != 0)
313  break;
314  numLeadingBroadcast++;
315  }
316  // If there are no leading zeros in the map there is nothing to do.
317  if (numLeadingBroadcast == 0)
318  return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");
319 
320  VectorType originalVecType = op.getVectorType();
321  unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
322  // Calculate new map, vector type and masks without the leading zeros.
323  AffineMap newMap = AffineMap::get(
324  map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
325  op.getContext());
326  // Only remove the leading zeros if the rest of the map is a minor identity
327  // with broadasting. Otherwise we first want to permute the map.
328  if (!newMap.isMinorIdentityWithBroadcasting()) {
329  return rewriter.notifyMatchFailure(
330  op, "map is not a minor identity with broadcasting");
331  }
332 
333  // TODO: support zero-dimension vectors natively. See:
334  // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
335  // In the meantime, lower these to a scalar load when they pop up.
336  if (reducedShapeRank == 0) {
337  Value newRead;
338  if (isa<TensorType>(op.getShapedType())) {
339  newRead = rewriter.create<tensor::ExtractOp>(
340  op.getLoc(), op.getSource(), op.getIndices());
341  } else {
342  newRead = rewriter.create<memref::LoadOp>(
343  op.getLoc(), originalVecType.getElementType(), op.getSource(),
344  op.getIndices());
345  }
346  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
347  newRead);
348  return success();
349  }
350 
351  SmallVector<int64_t> newShape(
352  originalVecType.getShape().take_back(reducedShapeRank));
353  SmallVector<bool> newScalableDims(
354  originalVecType.getScalableDims().take_back(reducedShapeRank));
355  // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
356  if (newShape.empty())
357  return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");
358 
359  VectorType newReadType = VectorType::get(
360  newShape, originalVecType.getElementType(), newScalableDims);
361  ArrayAttr newInBoundsAttr =
362  op.getInBounds()
363  ? rewriter.getArrayAttr(
364  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
365  : ArrayAttr();
366  Value newRead = rewriter.create<vector::TransferReadOp>(
367  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
368  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
369  newInBoundsAttr);
370  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
371  newRead);
372  return success();
373  }
374 };
375 
376 } // namespace
377 
379  RewritePatternSet &patterns, PatternBenefit benefit) {
380  patterns
381  .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
382  TransferOpReduceRank, TransferWriteNonPermutationLowering>(
383  patterns.getContext(), benefit);
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // populateVectorTransferLoweringPatterns
388 //===----------------------------------------------------------------------===//
389 
390 namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  /// `maxRank`, if set, bounds the vector rank this pattern will rewrite;
  /// higher-rank transfers are left for other patterns / VectorToSCF.
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    // Only memref sources can be lowered to vector.load.
    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    // Broadcast dims are loaded with extent 1.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      // Masked-off lanes take the transfer's padding value.
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  // Maximum vector rank this pattern rewrites; nullopt means unbounded.
  std::optional<unsigned> maxTransferRank;
};
488 
489 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
490 // TODO: we shouldn't cross the vector/scalar domains just for this
491 // but atm we lack the infra to avoid it. Possible solutions include:
492 // - go directly to LLVM + bitcast
493 // - introduce a bitcast op and likely a new pointer dialect
494 // - let memref.load/store additionally support the 0-d vector case
495 // There are still deeper data layout issues lingering even in this
496 // trivial case (for architectures for which this matters).
497 struct VectorLoadToMemrefLoadLowering
498  : public OpRewritePattern<vector::LoadOp> {
500 
501  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
502  PatternRewriter &rewriter) const override {
503  auto vecType = loadOp.getVectorType();
504  if (vecType.getNumElements() != 1)
505  return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
506 
507  auto memrefLoad = rewriter.create<memref::LoadOp>(
508  loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
509  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
510  memrefLoad);
511  return success();
512  }
513 };
514 
515 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
516 struct VectorStoreToMemrefStoreLowering
517  : public OpRewritePattern<vector::StoreOp> {
519 
520  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
521  PatternRewriter &rewriter) const override {
522  auto vecType = storeOp.getVectorType();
523  if (vecType.getNumElements() != 1)
524  return rewriter.notifyMatchFailure(storeOp, "not single element vector");
525 
526  Value extracted;
527  if (vecType.getRank() == 0) {
528  // TODO: Unifiy once ExtractOp supports 0-d vectors.
529  extracted = rewriter.create<vector::ExtractElementOp>(
530  storeOp.getLoc(), storeOp.getValueToStore());
531  } else {
532  SmallVector<int64_t> indices(vecType.getRank(), 0);
533  extracted = rewriter.create<vector::ExtractOp>(
534  storeOp.getLoc(), storeOp.getValueToStore(), indices);
535  }
536 
537  rewriter.replaceOpWithNewOp<memref::StoreOp>(
538  storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
539  return success();
540  }
541 };
542 
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  /// `maxRank`, if set, bounds the vector rank this pattern will rewrite;
  /// higher-rank transfers are left for other patterns / VectorToSCF.
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    // Only memref destinations can be lowered to vector.store.
    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  // Maximum vector rank this pattern rewrites; nullopt means unbounded.
  std::optional<unsigned> maxTransferRank;
};
628 } // namespace
629 
631  RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
632  PatternBenefit benefit) {
633  patterns.add<TransferReadToVectorLoadLowering,
634  TransferWriteToVectorStoreLowering>(patterns.getContext(),
635  maxTransferRank, benefit);
636  patterns
637  .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
638  patterns.getContext(), benefit);
639 }
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the given indices.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRank by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRank by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static ArrayRef< int64_t > vectorShape(Type type)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:132
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:152
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
Definition: AffineMap.cpp:160
unsigned getNumDims() const
Definition: AffineMap.cpp:374
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:387
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:200
unsigned getNumResults() const
Definition: AffineMap.cpp:382
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:536
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:323
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:353
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:277
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
U cast() const
Definition: Types.h:339
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e.g., converting it to a minor identity map).
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:749
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:679
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361