MLIR  22.0.0git
LowerVectorTransfer.cpp
Go to the documentation of this file.
1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 /// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
21 /// permutation based on the given indices.
22 static ArrayAttr
23 inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
24  const SmallVector<unsigned> &permutation) {
25  SmallVector<bool> newInBoundsValues(permutation.size());
26  size_t index = 0;
27  for (unsigned pos : permutation)
28  newInBoundsValues[pos] =
29  cast<BoolAttr>(attr.getValue()[index++]).getValue();
30  return builder.getBoolArrayAttr(newInBoundsValues);
31 }
32 
/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
35 static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
36  int64_t addedRank) {
37  auto originalVecType = cast<VectorType>(vec.getType());
38  SmallVector<int64_t> newShape(addedRank, 1);
39  newShape.append(originalVecType.getShape().begin(),
40  originalVecType.getShape().end());
41 
42  SmallVector<bool> newScalableDims(addedRank, false);
43  newScalableDims.append(originalVecType.getScalableDims().begin(),
44  originalVecType.getScalableDims().end());
45  VectorType newVecType = VectorType::get(
46  newShape, originalVecType.getElementType(), newScalableDims);
47  return vector::BroadcastOp::create(builder, loc, newVecType, vec);
48 }
49 
/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
52 static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
53  int64_t addedRank) {
54  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
55  SmallVector<int64_t> permutation;
56  for (int64_t i = addedRank,
57  e = cast<VectorType>(broadcasted.getType()).getRank();
58  i < e; ++i)
59  permutation.push_back(i);
60  for (int64_t i = 0; i < addedRank; ++i)
61  permutation.push_back(i);
62  return vector::TransposeOp::create(builder, loc, broadcasted, permutation);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // populateVectorTransferPermutationMapLoweringPatterns
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 /// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
72 /// vector.transpose op.
73 /// Ex:
74 /// vector.transfer_read ...
75 /// permutation_map: (d0, d1, d2) -> (0, d1)
76 /// into:
77 /// %v = vector.transfer_read ...
78 /// permutation_map: (d0, d1, d2) -> (d1, 0)
79 /// vector.transpose %v, [1, 0]
80 ///
81 /// vector.transfer_read ...
82 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
83 /// into:
84 /// %v = vector.transfer_read ...
85 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
86 /// vector.transpose %v, [0, 1, 3, 2, 4]
87 /// Note that an alternative is to transform it to linalg.transpose +
88 /// vector.transfer_read to do the transpose in memory instead.
89 struct TransferReadPermutationLowering
90  : public MaskableOpRewritePattern<vector::TransferReadOp> {
91  using MaskableOpRewritePattern::MaskableOpRewritePattern;
92 
93  FailureOr<mlir::Value>
94  matchAndRewriteMaskableOp(vector::TransferReadOp op,
95  MaskingOpInterface maskOp,
96  PatternRewriter &rewriter) const override {
97  // TODO: support 0-d corner case.
98  if (op.getTransferRank() == 0)
99  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
100  // TODO: Support transfer_read inside MaskOp case.
101  if (maskOp)
102  return rewriter.notifyMatchFailure(op, "Masked case not supported");
103 
104  SmallVector<unsigned> permutation;
105  AffineMap map = op.getPermutationMap();
106  if (map.getNumResults() == 0)
107  return rewriter.notifyMatchFailure(op, "0 result permutation map");
108  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
109  return rewriter.notifyMatchFailure(
110  op, "map is not permutable to minor identity, apply another pattern");
111  }
112  AffineMap permutationMap =
113  map.getPermutationMap(permutation, op.getContext());
114  if (permutationMap.isIdentity())
115  return rewriter.notifyMatchFailure(op, "map is not identity");
116 
117  permutationMap = map.getPermutationMap(permutation, op.getContext());
118  // Caluclate the map of the new read by applying the inverse permutation.
119  permutationMap = inversePermutation(permutationMap);
120  AffineMap newMap = permutationMap.compose(map);
121  // Apply the reverse transpose to deduce the type of the transfer_read.
122  ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
123  SmallVector<int64_t> newVectorShape(originalShape.size());
124  ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
125  SmallVector<bool> newScalableDims(originalShape.size());
126  for (const auto &pos : llvm::enumerate(permutation)) {
127  newVectorShape[pos.value()] = originalShape[pos.index()];
128  newScalableDims[pos.value()] = originalScalableDims[pos.index()];
129  }
130 
131  // Transpose in_bounds attribute.
132  ArrayAttr newInBoundsAttr =
133  inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
134 
135  // Generate new transfer_read operation.
136  VectorType newReadType = VectorType::get(
137  newVectorShape, op.getVectorType().getElementType(), newScalableDims);
138  Value newRead = vector::TransferReadOp::create(
139  rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
140  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
141  newInBoundsAttr);
142 
143  // Transpose result of transfer_read.
144  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
145  return vector::TransposeOp::create(rewriter, op.getLoc(), newRead,
146  transposePerm)
147  .getResult();
148  }
149 };
150 
151 /// Lower transfer_write op with permutation into a transfer_write with a
152 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
153 /// Ex:
154 /// vector.transfer_write %v ...
155 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
156 /// into:
157 /// %tmp = vector.transpose %v, [2, 0, 1]
158 /// vector.transfer_write %tmp ...
159 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
160 ///
161 /// vector.transfer_write %v ...
162 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
163 /// into:
164 /// %tmp = vector.transpose %v, [1, 0]
165 /// %v = vector.transfer_write %tmp ...
166 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167 struct TransferWritePermutationLowering
168  : public MaskableOpRewritePattern<vector::TransferWriteOp> {
169  using MaskableOpRewritePattern::MaskableOpRewritePattern;
170 
171  FailureOr<mlir::Value>
172  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
173  MaskingOpInterface maskOp,
174  PatternRewriter &rewriter) const override {
175  // TODO: support 0-d corner case.
176  if (op.getTransferRank() == 0)
177  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
178  // TODO: Support transfer_write inside MaskOp case.
179  if (maskOp)
180  return rewriter.notifyMatchFailure(op, "Masked case not supported");
181 
182  SmallVector<unsigned> permutation;
183  AffineMap map = op.getPermutationMap();
184  if (map.isMinorIdentity())
185  return rewriter.notifyMatchFailure(op, "map is already minor identity");
186 
187  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
188  return rewriter.notifyMatchFailure(
189  op, "map is not permutable to minor identity, apply another pattern");
190  }
191 
192  // Remove unused dims from the permutation map. E.g.:
193  // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
194  // comp = (d0, d1, d2) -> (d2, d0, d1)
195  auto comp = compressUnusedDims(map);
196  AffineMap permutationMap = inversePermutation(comp);
197  // Get positions of remaining result dims.
198  SmallVector<int64_t> indices;
199  llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
200  [](AffineExpr expr) {
201  return dyn_cast<AffineDimExpr>(expr).getPosition();
202  });
203 
204  // Transpose in_bounds attribute.
205  ArrayAttr newInBoundsAttr =
206  inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
207 
208  // Generate new transfer_write operation.
209  Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(),
210  op.getVector(), indices);
211  auto newMap = AffineMap::getMinorIdentityMap(
212  map.getNumDims(), map.getNumResults(), rewriter.getContext());
213  auto newWrite = vector::TransferWriteOp::create(
214  rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
215  AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
216  if (newWrite.hasPureTensorSemantics())
217  return newWrite.getResult();
218  // In the memref case there's no return value. Use empty value to signal
219  // success.
220  return Value();
221  }
222 };
223 
224 /// Convert a transfer.write op with a map which isn't the permutation of a
225 /// minor identity into a vector.broadcast + transfer_write with permutation of
226 /// minor identity map by adding unit dim on inner dimension. Ex:
227 /// ```
228 /// vector.transfer_write %v
229 /// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
230 /// vector<8x16xf32>
231 /// ```
232 /// into:
233 /// ```
234 /// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
235 /// vector.transfer_write %v1
236 /// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
237 /// vector<1x8x16xf32>
238 /// ```
239 struct TransferWriteNonPermutationLowering
240  : public MaskableOpRewritePattern<vector::TransferWriteOp> {
241  using MaskableOpRewritePattern::MaskableOpRewritePattern;
242 
243  FailureOr<mlir::Value>
244  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
245  MaskingOpInterface maskOp,
246  PatternRewriter &rewriter) const override {
247  // TODO: support 0-d corner case.
248  if (op.getTransferRank() == 0)
249  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
250  // TODO: Support transfer_write inside MaskOp case.
251  if (maskOp)
252  return rewriter.notifyMatchFailure(op, "Masked case not supported");
253 
254  SmallVector<unsigned> permutation;
255  AffineMap map = op.getPermutationMap();
256  if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
257  return rewriter.notifyMatchFailure(
258  op,
259  "map is already permutable to minor identity, apply another pattern");
260  }
261 
262  // Missing outer dimensions are allowed, find the most outer existing
263  // dimension then deduce the missing inner dimensions.
264  SmallVector<bool> foundDim(map.getNumDims(), false);
265  for (AffineExpr exp : map.getResults())
266  foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
268  bool foundFirstDim = false;
269  SmallVector<int64_t> missingInnerDim;
270  for (size_t i = 0; i < foundDim.size(); i++) {
271  if (foundDim[i]) {
272  foundFirstDim = true;
273  continue;
274  }
275  if (!foundFirstDim)
276  continue;
277  // Once we found one outer dimension existing in the map keep track of all
278  // the missing dimensions after that.
279  missingInnerDim.push_back(i);
280  exprs.push_back(rewriter.getAffineDimExpr(i));
281  }
282  // Vector: add unit dims at the beginning of the shape.
283  Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
284  missingInnerDim.size());
285  // Mask: add unit dims at the end of the shape.
286  Value newMask;
287  if (op.getMask())
288  newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
289  missingInnerDim.size());
290  exprs.append(map.getResults().begin(), map.getResults().end());
291  AffineMap newMap =
292  AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
293  // All the new dimensions added are inbound.
294  SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
295  for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
296  newInBoundsValues.push_back(op.isDimInBounds(i));
297  }
298  ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
299  auto newWrite = vector::TransferWriteOp::create(
300  rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
301  AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
302  if (newWrite.hasPureTensorSemantics())
303  return newWrite.getResult();
304  // In the memref case there's no return value. Use empty value to signal
305  // success.
306  return Value();
307  }
308 };
309 
310 /// Lower transfer_read op with broadcast in the leading dimensions into
311 /// transfer_read of lower rank + vector.broadcast.
312 /// Ex: vector.transfer_read ...
313 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
314 /// into:
315 /// %v = vector.transfer_read ...
316 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
317 /// vector.broadcast %v
318 struct TransferOpReduceRank
319  : public MaskableOpRewritePattern<vector::TransferReadOp> {
320  using MaskableOpRewritePattern::MaskableOpRewritePattern;
321 
322  FailureOr<mlir::Value>
323  matchAndRewriteMaskableOp(vector::TransferReadOp op,
324  MaskingOpInterface maskOp,
325  PatternRewriter &rewriter) const override {
326  // TODO: support 0-d corner case.
327  if (op.getTransferRank() == 0)
328  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
329  // TODO: support masked case.
330  if (maskOp)
331  return rewriter.notifyMatchFailure(op, "Masked case not supported");
332 
333  AffineMap map = op.getPermutationMap();
334  unsigned numLeadingBroadcast = 0;
335  for (auto expr : map.getResults()) {
336  auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
337  if (!dimExpr || dimExpr.getValue() != 0)
338  break;
339  numLeadingBroadcast++;
340  }
341  // If there are no leading zeros in the map there is nothing to do.
342  if (numLeadingBroadcast == 0)
343  return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");
344 
345  VectorType originalVecType = op.getVectorType();
346  unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
347  // Calculate new map, vector type and masks without the leading zeros.
348  AffineMap newMap = AffineMap::get(
349  map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
350  op.getContext());
351  // Only remove the leading zeros if the rest of the map is a minor identity
352  // with broadasting. Otherwise we first want to permute the map.
353  if (!newMap.isMinorIdentityWithBroadcasting()) {
354  return rewriter.notifyMatchFailure(
355  op, "map is not a minor identity with broadcasting");
356  }
357 
358  SmallVector<int64_t> newShape(
359  originalVecType.getShape().take_back(reducedShapeRank));
360  SmallVector<bool> newScalableDims(
361  originalVecType.getScalableDims().take_back(reducedShapeRank));
362 
363  VectorType newReadType = VectorType::get(
364  newShape, originalVecType.getElementType(), newScalableDims);
365  ArrayAttr newInBoundsAttr =
366  op.getInBounds()
367  ? rewriter.getArrayAttr(
368  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
369  : ArrayAttr();
370  Value newRead = vector::TransferReadOp::create(
371  rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
372  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
373  newInBoundsAttr);
374  return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
375  newRead)
376  .getVector();
377  }
378 };
379 
380 } // namespace
381 
  // Register all four permutation-map simplification patterns; they are
  // mutually exclusive on a given op, so relative order within the set does
  // not matter — each gets the same `benefit`.
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
389 
390 //===----------------------------------------------------------------------===//
391 // populateVectorTransferLoweringPatterns
392 //===----------------------------------------------------------------------===//
393 
394 namespace {
395 /// Progressive lowering of transfer_read. This pattern supports lowering of
396 /// `vector.transfer_read` to a combination of `vector.load` and
397 /// `vector.broadcast` if all of the following hold:
398 /// - Stride of most minor memref dimension must be 1.
399 /// - Out-of-bounds masking is not required.
400 /// - If the memref's element type is a vector type then it coincides with the
401 /// result type.
402 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
403 struct TransferReadToVectorLoadLowering
404  : public MaskableOpRewritePattern<vector::TransferReadOp> {
405  TransferReadToVectorLoadLowering(MLIRContext *context,
406  std::optional<unsigned> maxRank,
407  PatternBenefit benefit = 1)
408  : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
409  maxTransferRank(maxRank) {}
410 
411  FailureOr<mlir::Value>
412  matchAndRewriteMaskableOp(vector::TransferReadOp read,
413  MaskingOpInterface maskOp,
414  PatternRewriter &rewriter) const override {
415  if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
416  return rewriter.notifyMatchFailure(
417  read, "vector type is greater than max transfer rank");
418  }
419 
420  if (maskOp)
421  return rewriter.notifyMatchFailure(read, "Masked case not supported");
422  SmallVector<unsigned> broadcastedDims;
423  // Permutations are handled by VectorToSCF or
424  // populateVectorTransferPermutationMapLoweringPatterns.
425  // We let the 0-d corner case pass-through as it is supported.
426  if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
427  &broadcastedDims))
428  return rewriter.notifyMatchFailure(read, "not minor identity + bcast");
429 
430  auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
431  if (!memRefType)
432  return rewriter.notifyMatchFailure(read, "not a memref source");
433 
434  // Non-unit strides are handled by VectorToSCF.
435  if (!memRefType.isLastDimUnitStride())
436  return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
437 
438  // If there is broadcasting involved then we first load the unbroadcasted
439  // vector, and then broadcast it with `vector.broadcast`.
440  ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
441  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
442  for (unsigned i : broadcastedDims)
443  unbroadcastedVectorShape[i] = 1;
444  VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
445  unbroadcastedVectorShape, read.getVectorType().getElementType());
446 
447  // `vector.load` supports vector types as memref's elements only when the
448  // resulting vector type is the same as the element type.
449  auto memrefElTy = memRefType.getElementType();
450  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
451  return rewriter.notifyMatchFailure(read, "incompatible element type");
452 
453  // Otherwise, element types of the memref and the vector must match.
454  if (!isa<VectorType>(memrefElTy) &&
455  memrefElTy != read.getVectorType().getElementType())
456  return rewriter.notifyMatchFailure(read, "non-matching element type");
457 
458  // Out-of-bounds dims are handled by MaterializeTransferMask.
459  if (read.hasOutOfBoundsDim())
460  return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
461 
462  // Create vector load op.
463  Operation *res;
464  if (read.getMask()) {
465  if (read.getVectorType().getRank() != 1)
466  // vector.maskedload operates on 1-D vectors.
467  return rewriter.notifyMatchFailure(
468  read, "vector type is not rank 1, can't create masked load, needs "
469  "VectorToSCF");
470 
471  Value fill = vector::BroadcastOp::create(
472  rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
473  res = vector::MaskedLoadOp::create(
474  rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
475  read.getIndices(), read.getMask(), fill);
476  } else {
477  res = vector::LoadOp::create(rewriter, read.getLoc(),
478  unbroadcastedVectorType, read.getBase(),
479  read.getIndices());
480  }
481 
482  // Insert a broadcasting op if required.
483  if (!broadcastedDims.empty())
484  res = vector::BroadcastOp::create(
485  rewriter, read.getLoc(), read.getVectorType(), res->getResult(0));
486  return res->getResult(0);
487  }
488 
489  std::optional<unsigned> maxTransferRank;
490 };
491 
492 /// Progressive lowering of transfer_write. This pattern supports lowering of
493 /// `vector.transfer_write` to `vector.store` if all of the following hold:
494 /// - Stride of most minor memref dimension must be 1.
495 /// - Out-of-bounds masking is not required.
496 /// - If the memref's element type is a vector type then it coincides with the
497 /// type of the written value.
498 /// - The permutation map is the minor identity map (neither permutation nor
499 /// broadcasting is allowed).
500 struct TransferWriteToVectorStoreLowering
501  : public MaskableOpRewritePattern<vector::TransferWriteOp> {
502  TransferWriteToVectorStoreLowering(MLIRContext *context,
503  std::optional<unsigned> maxRank,
504  PatternBenefit benefit = 1)
505  : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
506  maxTransferRank(maxRank) {}
507 
508  FailureOr<mlir::Value>
509  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
510  MaskingOpInterface maskOp,
511  PatternRewriter &rewriter) const override {
512  if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
513  return rewriter.notifyMatchFailure(
514  write, "vector type is greater than max transfer rank");
515  }
516  if (maskOp)
517  return rewriter.notifyMatchFailure(write, "Masked case not supported");
518 
519  // Permutations are handled by VectorToSCF or
520  // populateVectorTransferPermutationMapLoweringPatterns.
521  if ( // pass-through for the 0-d corner case.
522  !write.getPermutationMap().isMinorIdentity())
523  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
524  diag << "permutation map is not minor identity: " << write;
525  });
526 
527  auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
528  if (!memRefType)
529  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
530  diag << "not a memref type: " << write;
531  });
532 
533  // Non-unit strides are handled by VectorToSCF.
534  if (!memRefType.isLastDimUnitStride())
535  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
536  diag << "most minor stride is not 1: " << write;
537  });
538 
539  // `vector.store` supports vector types as memref's elements only when the
540  // type of the vector value being written is the same as the element type.
541  auto memrefElTy = memRefType.getElementType();
542  if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
543  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
544  diag << "elemental type mismatch: " << write;
545  });
546 
547  // Otherwise, element types of the memref and the vector must match.
548  if (!isa<VectorType>(memrefElTy) &&
549  memrefElTy != write.getVectorType().getElementType())
550  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
551  diag << "elemental type mismatch: " << write;
552  });
553 
554  // Out-of-bounds dims are handled by MaterializeTransferMask.
555  if (write.hasOutOfBoundsDim())
556  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
557  diag << "out of bounds dim: " << write;
558  });
559  if (write.getMask()) {
560  if (write.getVectorType().getRank() != 1)
561  // vector.maskedstore operates on 1-D vectors.
562  return rewriter.notifyMatchFailure(
563  write.getLoc(), [=](Diagnostic &diag) {
564  diag << "vector type is not rank 1, can't create masked store, "
565  "needs VectorToSCF: "
566  << write;
567  });
568 
569  vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(),
570  write.getIndices(), write.getMask(),
571  write.getVector());
572  } else {
573  vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(),
574  write.getBase(), write.getIndices());
575  }
576  // There's no return value for StoreOps. Use Value() to signal success to
577  // matchAndRewrite.
578  return Value();
579  }
580 
581  std::optional<unsigned> maxTransferRank;
582 };
583 } // namespace
584 
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  // Register the transfer_read -> vector.load and transfer_write ->
  // vector.store lowerings; `maxTransferRank` caps the vector rank these
  // patterns will rewrite (nullopt = no cap).
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
}
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static std::optional< VectorShape > vectorShape(Type type)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:131
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:151
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
Definition: AffineMap.cpp:172
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:212
unsigned getNumResults() const
Definition: AffineMap.cpp:398
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:341
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:265
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
const FrozenRewritePatternSet & patterns
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:715
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:163