MLIR  20.0.0git
LowerVectorTransfer.cpp
Go to the documentation of this file.
1 //===- LowerVectorTransfer.cpp - Lower vector.transfer ops ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
24 /// permutation based on the given indices.
25 static ArrayAttr
26 inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27  const SmallVector<unsigned> &permutation) {
28  SmallVector<bool> newInBoundsValues(permutation.size());
29  size_t index = 0;
30  for (unsigned pos : permutation)
31  newInBoundsValues[pos] =
32  cast<BoolAttr>(attr.getValue()[index++]).getValue();
33  return builder.getBoolArrayAttr(newInBoundsValues);
34 }
35 
36 /// Extend the rank of a vector Value by `addedRanks` by adding outer unit
37 /// dimensions.
38 static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
39  int64_t addedRank) {
40  auto originalVecType = cast<VectorType>(vec.getType());
41  SmallVector<int64_t> newShape(addedRank, 1);
42  newShape.append(originalVecType.getShape().begin(),
43  originalVecType.getShape().end());
44 
45  SmallVector<bool> newScalableDims(addedRank, false);
46  newScalableDims.append(originalVecType.getScalableDims().begin(),
47  originalVecType.getScalableDims().end());
48  VectorType newVecType = VectorType::get(
49  newShape, originalVecType.getElementType(), newScalableDims);
50  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
51 }
52 
53 /// Extend the rank of a vector Value by `addedRanks` by adding inner unit
54 /// dimensions.
55 static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
56  int64_t addedRank) {
57  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
58  SmallVector<int64_t> permutation;
59  for (int64_t i = addedRank,
60  e = cast<VectorType>(broadcasted.getType()).getRank();
61  i < e; ++i)
62  permutation.push_back(i);
63  for (int64_t i = 0; i < addedRank; ++i)
64  permutation.push_back(i);
65  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // populateVectorTransferPermutationMapLoweringPatterns
70 //===----------------------------------------------------------------------===//
71 
72 namespace {
73 /// Lower transfer_read op with permutation into a transfer_read with a
74 /// permutation map composed of leading zeros followed by a minor identiy +
75 /// vector.transpose op.
76 /// Ex:
77 /// vector.transfer_read ...
78 /// permutation_map: (d0, d1, d2) -> (0, d1)
79 /// into:
80 /// %v = vector.transfer_read ...
81 /// permutation_map: (d0, d1, d2) -> (d1, 0)
82 /// vector.transpose %v, [1, 0]
83 ///
84 /// vector.transfer_read ...
85 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
86 /// into:
87 /// %v = vector.transfer_read ...
88 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
89 /// vector.transpose %v, [0, 1, 3, 2, 4]
90 /// Note that an alternative is to transform it to linalg.transpose +
91 /// vector.transfer_read to do the transpose in memory instead.
92 struct TransferReadPermutationLowering
93  : public MaskableOpRewritePattern<vector::TransferReadOp> {
94  using MaskableOpRewritePattern::MaskableOpRewritePattern;
95 
96  FailureOr<mlir::Value>
97  matchAndRewriteMaskableOp(vector::TransferReadOp op,
98  MaskingOpInterface maskOp,
99  PatternRewriter &rewriter) const override {
100  // TODO: support 0-d corner case.
101  if (op.getTransferRank() == 0)
102  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
103  // TODO: Support transfer_read inside MaskOp case.
104  if (maskOp)
105  return rewriter.notifyMatchFailure(op, "Masked case not supported");
106 
107  SmallVector<unsigned> permutation;
108  AffineMap map = op.getPermutationMap();
109  if (map.getNumResults() == 0)
110  return rewriter.notifyMatchFailure(op, "0 result permutation map");
111  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
112  return rewriter.notifyMatchFailure(
113  op, "map is not permutable to minor identity, apply another pattern");
114  }
115  AffineMap permutationMap =
116  map.getPermutationMap(permutation, op.getContext());
117  if (permutationMap.isIdentity())
118  return rewriter.notifyMatchFailure(op, "map is not identity");
119 
120  permutationMap = map.getPermutationMap(permutation, op.getContext());
121  // Caluclate the map of the new read by applying the inverse permutation.
122  permutationMap = inversePermutation(permutationMap);
123  AffineMap newMap = permutationMap.compose(map);
124  // Apply the reverse transpose to deduce the type of the transfer_read.
125  ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
126  SmallVector<int64_t> newVectorShape(originalShape.size());
127  ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
128  SmallVector<bool> newScalableDims(originalShape.size());
129  for (const auto &pos : llvm::enumerate(permutation)) {
130  newVectorShape[pos.value()] = originalShape[pos.index()];
131  newScalableDims[pos.value()] = originalScalableDims[pos.index()];
132  }
133 
134  // Transpose in_bounds attribute.
135  ArrayAttr newInBoundsAttr =
136  inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
137 
138  // Generate new transfer_read operation.
139  VectorType newReadType = VectorType::get(
140  newVectorShape, op.getVectorType().getElementType(), newScalableDims);
141  Value newRead = rewriter.create<vector::TransferReadOp>(
142  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
143  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
144  newInBoundsAttr);
145 
146  // Transpose result of transfer_read.
147  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
148  return rewriter
149  .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
150  .getResult();
151  }
152 };
153 
154 /// Lower transfer_write op with permutation into a transfer_write with a
155 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
156 /// Ex:
157 /// vector.transfer_write %v ...
158 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
159 /// into:
160 /// %tmp = vector.transpose %v, [2, 0, 1]
161 /// vector.transfer_write %tmp ...
162 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
163 ///
164 /// vector.transfer_write %v ...
165 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
166 /// into:
167 /// %tmp = vector.transpose %v, [1, 0]
168 /// %v = vector.transfer_write %tmp ...
169 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
170 struct TransferWritePermutationLowering
171  : public MaskableOpRewritePattern<vector::TransferWriteOp> {
172  using MaskableOpRewritePattern::MaskableOpRewritePattern;
173 
174  FailureOr<mlir::Value>
175  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
176  MaskingOpInterface maskOp,
177  PatternRewriter &rewriter) const override {
178  // TODO: support 0-d corner case.
179  if (op.getTransferRank() == 0)
180  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
181  // TODO: Support transfer_write inside MaskOp case.
182  if (maskOp)
183  return rewriter.notifyMatchFailure(op, "Masked case not supported");
184 
185  SmallVector<unsigned> permutation;
186  AffineMap map = op.getPermutationMap();
187  if (map.isMinorIdentity())
188  return rewriter.notifyMatchFailure(op, "map is already minor identity");
189 
190  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
191  return rewriter.notifyMatchFailure(
192  op, "map is not permutable to minor identity, apply another pattern");
193  }
194 
195  // Remove unused dims from the permutation map. E.g.:
196  // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
197  // comp = (d0, d1, d2) -> (d2, d0, d1)
198  auto comp = compressUnusedDims(map);
199  AffineMap permutationMap = inversePermutation(comp);
200  // Get positions of remaining result dims.
201  SmallVector<int64_t> indices;
202  llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
203  [](AffineExpr expr) {
204  return dyn_cast<AffineDimExpr>(expr).getPosition();
205  });
206 
207  // Transpose in_bounds attribute.
208  ArrayAttr newInBoundsAttr =
209  inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
210 
211  // Generate new transfer_write operation.
212  Value newVec = rewriter.create<vector::TransposeOp>(
213  op.getLoc(), op.getVector(), indices);
214  auto newMap = AffineMap::getMinorIdentityMap(
215  map.getNumDims(), map.getNumResults(), rewriter.getContext());
216  auto newWrite = rewriter.create<vector::TransferWriteOp>(
217  op.getLoc(), newVec, op.getSource(), op.getIndices(),
218  AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
219  if (newWrite.hasPureTensorSemantics())
220  return newWrite.getResult();
221  // In the memref case there's no return value. Use empty value to signal
222  // success.
223  return Value();
224  }
225 };
226 
227 /// Convert a transfer.write op with a map which isn't the permutation of a
228 /// minor identity into a vector.broadcast + transfer_write with permutation of
229 /// minor identity map by adding unit dim on inner dimension. Ex:
230 /// ```
231 /// vector.transfer_write %v
232 /// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
233 /// vector<8x16xf32>
234 /// ```
235 /// into:
236 /// ```
237 /// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
238 /// vector.transfer_write %v1
239 /// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
240 /// vector<1x8x16xf32>
241 /// ```
242 struct TransferWriteNonPermutationLowering
243  : public MaskableOpRewritePattern<vector::TransferWriteOp> {
244  using MaskableOpRewritePattern::MaskableOpRewritePattern;
245 
246  FailureOr<mlir::Value>
247  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
248  MaskingOpInterface maskOp,
249  PatternRewriter &rewriter) const override {
250  // TODO: support 0-d corner case.
251  if (op.getTransferRank() == 0)
252  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
253  // TODO: Support transfer_write inside MaskOp case.
254  if (maskOp)
255  return rewriter.notifyMatchFailure(op, "Masked case not supported");
256 
257  SmallVector<unsigned> permutation;
258  AffineMap map = op.getPermutationMap();
259  if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
260  return rewriter.notifyMatchFailure(
261  op,
262  "map is already permutable to minor identity, apply another pattern");
263  }
264 
265  // Missing outer dimensions are allowed, find the most outer existing
266  // dimension then deduce the missing inner dimensions.
267  SmallVector<bool> foundDim(map.getNumDims(), false);
268  for (AffineExpr exp : map.getResults())
269  foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
271  bool foundFirstDim = false;
272  SmallVector<int64_t> missingInnerDim;
273  for (size_t i = 0; i < foundDim.size(); i++) {
274  if (foundDim[i]) {
275  foundFirstDim = true;
276  continue;
277  }
278  if (!foundFirstDim)
279  continue;
280  // Once we found one outer dimension existing in the map keep track of all
281  // the missing dimensions after that.
282  missingInnerDim.push_back(i);
283  exprs.push_back(rewriter.getAffineDimExpr(i));
284  }
285  // Vector: add unit dims at the beginning of the shape.
286  Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
287  missingInnerDim.size());
288  // Mask: add unit dims at the end of the shape.
289  Value newMask;
290  if (op.getMask())
291  newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
292  missingInnerDim.size());
293  exprs.append(map.getResults().begin(), map.getResults().end());
294  AffineMap newMap =
295  AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
296  // All the new dimensions added are inbound.
297  SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
298  for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
299  newInBoundsValues.push_back(op.isDimInBounds(i));
300  }
301  ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
302  auto newWrite = rewriter.create<vector::TransferWriteOp>(
303  op.getLoc(), newVec, op.getSource(), op.getIndices(),
304  AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
305  if (newWrite.hasPureTensorSemantics())
306  return newWrite.getResult();
307  // In the memref case there's no return value. Use empty value to signal
308  // success.
309  return Value();
310  }
311 };
312 
313 /// Lower transfer_read op with broadcast in the leading dimensions into
314 /// transfer_read of lower rank + vector.broadcast.
315 /// Ex: vector.transfer_read ...
316 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
317 /// into:
318 /// %v = vector.transfer_read ...
319 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
320 /// vector.broadcast %v
321 struct TransferOpReduceRank
322  : public MaskableOpRewritePattern<vector::TransferReadOp> {
323  using MaskableOpRewritePattern::MaskableOpRewritePattern;
324 
325  FailureOr<mlir::Value>
326  matchAndRewriteMaskableOp(vector::TransferReadOp op,
327  MaskingOpInterface maskOp,
328  PatternRewriter &rewriter) const override {
329  // TODO: support 0-d corner case.
330  if (op.getTransferRank() == 0)
331  return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
332  // TODO: support masked case.
333  if (maskOp)
334  return rewriter.notifyMatchFailure(op, "Masked case not supported");
335 
336  AffineMap map = op.getPermutationMap();
337  unsigned numLeadingBroadcast = 0;
338  for (auto expr : map.getResults()) {
339  auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
340  if (!dimExpr || dimExpr.getValue() != 0)
341  break;
342  numLeadingBroadcast++;
343  }
344  // If there are no leading zeros in the map there is nothing to do.
345  if (numLeadingBroadcast == 0)
346  return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");
347 
348  VectorType originalVecType = op.getVectorType();
349  unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
350  // Calculate new map, vector type and masks without the leading zeros.
351  AffineMap newMap = AffineMap::get(
352  map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
353  op.getContext());
354  // Only remove the leading zeros if the rest of the map is a minor identity
355  // with broadasting. Otherwise we first want to permute the map.
356  if (!newMap.isMinorIdentityWithBroadcasting()) {
357  return rewriter.notifyMatchFailure(
358  op, "map is not a minor identity with broadcasting");
359  }
360 
361  SmallVector<int64_t> newShape(
362  originalVecType.getShape().take_back(reducedShapeRank));
363  SmallVector<bool> newScalableDims(
364  originalVecType.getScalableDims().take_back(reducedShapeRank));
365 
366  VectorType newReadType = VectorType::get(
367  newShape, originalVecType.getElementType(), newScalableDims);
368  ArrayAttr newInBoundsAttr =
369  op.getInBounds()
370  ? rewriter.getArrayAttr(
371  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
372  : ArrayAttr();
373  Value newRead = rewriter.create<vector::TransferReadOp>(
374  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
375  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
376  newInBoundsAttr);
377  return rewriter
378  .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
379  .getVector();
380  }
381 };
382 
383 } // namespace
384 
387  patterns
388  .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
389  TransferOpReduceRank, TransferWriteNonPermutationLowering>(
390  patterns.getContext(), benefit);
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // populateVectorTransferLoweringPatterns
395 //===----------------------------------------------------------------------===//
396 
397 namespace {
398 /// Progressive lowering of transfer_read. This pattern supports lowering of
399 /// `vector.transfer_read` to a combination of `vector.load` and
400 /// `vector.broadcast` if all of the following hold:
401 /// - Stride of most minor memref dimension must be 1.
402 /// - Out-of-bounds masking is not required.
403 /// - If the memref's element type is a vector type then it coincides with the
404 /// result type.
405 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
406 struct TransferReadToVectorLoadLowering
407  : public MaskableOpRewritePattern<vector::TransferReadOp> {
408  TransferReadToVectorLoadLowering(MLIRContext *context,
409  std::optional<unsigned> maxRank,
410  PatternBenefit benefit = 1)
411  : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
412  maxTransferRank(maxRank) {}
413 
414  FailureOr<mlir::Value>
415  matchAndRewriteMaskableOp(vector::TransferReadOp read,
416  MaskingOpInterface maskOp,
417  PatternRewriter &rewriter) const override {
418  if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
419  return rewriter.notifyMatchFailure(
420  read, "vector type is greater than max transfer rank");
421  }
422 
423  if (maskOp)
424  return rewriter.notifyMatchFailure(read, "Masked case not supported");
425  SmallVector<unsigned> broadcastedDims;
426  // Permutations are handled by VectorToSCF or
427  // populateVectorTransferPermutationMapLoweringPatterns.
428  // We let the 0-d corner case pass-through as it is supported.
429  if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
430  &broadcastedDims))
431  return rewriter.notifyMatchFailure(read, "not minor identity + bcast");
432 
433  auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
434  if (!memRefType)
435  return rewriter.notifyMatchFailure(read, "not a memref source");
436 
437  // Non-unit strides are handled by VectorToSCF.
438  if (!isLastMemrefDimUnitStride(memRefType))
439  return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
440 
441  // If there is broadcasting involved then we first load the unbroadcasted
442  // vector, and then broadcast it with `vector.broadcast`.
443  ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
444  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
445  for (unsigned i : broadcastedDims)
446  unbroadcastedVectorShape[i] = 1;
447  VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
448  unbroadcastedVectorShape, read.getVectorType().getElementType());
449 
450  // `vector.load` supports vector types as memref's elements only when the
451  // resulting vector type is the same as the element type.
452  auto memrefElTy = memRefType.getElementType();
453  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
454  return rewriter.notifyMatchFailure(read, "incompatible element type");
455 
456  // Otherwise, element types of the memref and the vector must match.
457  if (!isa<VectorType>(memrefElTy) &&
458  memrefElTy != read.getVectorType().getElementType())
459  return rewriter.notifyMatchFailure(read, "non-matching element type");
460 
461  // Out-of-bounds dims are handled by MaterializeTransferMask.
462  if (read.hasOutOfBoundsDim())
463  return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
464 
465  // Create vector load op.
466  Operation *res;
467  if (read.getMask()) {
468  if (read.getVectorType().getRank() != 1)
469  // vector.maskedload operates on 1-D vectors.
470  return rewriter.notifyMatchFailure(
471  read, "vector type is not rank 1, can't create masked load, needs "
472  "VectorToSCF");
473 
474  Value fill = rewriter.create<vector::SplatOp>(
475  read.getLoc(), unbroadcastedVectorType, read.getPadding());
476  res = rewriter.create<vector::MaskedLoadOp>(
477  read.getLoc(), unbroadcastedVectorType, read.getSource(),
478  read.getIndices(), read.getMask(), fill);
479  } else {
480  res = rewriter.create<vector::LoadOp>(
481  read.getLoc(), unbroadcastedVectorType, read.getSource(),
482  read.getIndices());
483  }
484 
485  // Insert a broadcasting op if required.
486  if (!broadcastedDims.empty())
487  res = rewriter.create<vector::BroadcastOp>(
488  read.getLoc(), read.getVectorType(), res->getResult(0));
489  return res->getResult(0);
490  }
491 
492  std::optional<unsigned> maxTransferRank;
493 };
494 
495 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
496 // TODO: we shouldn't cross the vector/scalar domains just for this
497 // but atm we lack the infra to avoid it. Possible solutions include:
498 // - go directly to LLVM + bitcast
499 // - introduce a bitcast op and likely a new pointer dialect
500 // - let memref.load/store additionally support the 0-d vector case
501 // There are still deeper data layout issues lingering even in this
502 // trivial case (for architectures for which this matters).
503 struct VectorLoadToMemrefLoadLowering
504  : public OpRewritePattern<vector::LoadOp> {
506 
507  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
508  PatternRewriter &rewriter) const override {
509  auto vecType = loadOp.getVectorType();
510  if (vecType.getNumElements() != 1)
511  return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
512 
513  auto memrefLoad = rewriter.create<memref::LoadOp>(
514  loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
515  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
516  memrefLoad);
517  return success();
518  }
519 };
520 
521 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
522 struct VectorStoreToMemrefStoreLowering
523  : public OpRewritePattern<vector::StoreOp> {
525 
526  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
527  PatternRewriter &rewriter) const override {
528  auto vecType = storeOp.getVectorType();
529  if (vecType.getNumElements() != 1)
530  return rewriter.notifyMatchFailure(storeOp, "not single element vector");
531 
532  Value extracted;
533  if (vecType.getRank() == 0) {
534  // TODO: Unifiy once ExtractOp supports 0-d vectors.
535  extracted = rewriter.create<vector::ExtractElementOp>(
536  storeOp.getLoc(), storeOp.getValueToStore());
537  } else {
538  SmallVector<int64_t> indices(vecType.getRank(), 0);
539  extracted = rewriter.create<vector::ExtractOp>(
540  storeOp.getLoc(), storeOp.getValueToStore(), indices);
541  }
542 
543  rewriter.replaceOpWithNewOp<memref::StoreOp>(
544  storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
545  return success();
546  }
547 };
548 
549 /// Progressive lowering of transfer_write. This pattern supports lowering of
550 /// `vector.transfer_write` to `vector.store` if all of the following hold:
551 /// - Stride of most minor memref dimension must be 1.
552 /// - Out-of-bounds masking is not required.
553 /// - If the memref's element type is a vector type then it coincides with the
554 /// type of the written value.
555 /// - The permutation map is the minor identity map (neither permutation nor
556 /// broadcasting is allowed).
557 struct TransferWriteToVectorStoreLowering
558  : public MaskableOpRewritePattern<vector::TransferWriteOp> {
559  TransferWriteToVectorStoreLowering(MLIRContext *context,
560  std::optional<unsigned> maxRank,
561  PatternBenefit benefit = 1)
562  : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
563  maxTransferRank(maxRank) {}
564 
565  FailureOr<mlir::Value>
566  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
567  MaskingOpInterface maskOp,
568  PatternRewriter &rewriter) const override {
569  if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
570  return rewriter.notifyMatchFailure(
571  write, "vector type is greater than max transfer rank");
572  }
573  if (maskOp)
574  return rewriter.notifyMatchFailure(write, "Masked case not supported");
575 
576  // Permutations are handled by VectorToSCF or
577  // populateVectorTransferPermutationMapLoweringPatterns.
578  if ( // pass-through for the 0-d corner case.
579  !write.getPermutationMap().isMinorIdentity())
580  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
581  diag << "permutation map is not minor identity: " << write;
582  });
583 
584  auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
585  if (!memRefType)
586  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
587  diag << "not a memref type: " << write;
588  });
589 
590  // Non-unit strides are handled by VectorToSCF.
591  if (!isLastMemrefDimUnitStride(memRefType))
592  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
593  diag << "most minor stride is not 1: " << write;
594  });
595 
596  // `vector.store` supports vector types as memref's elements only when the
597  // type of the vector value being written is the same as the element type.
598  auto memrefElTy = memRefType.getElementType();
599  if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
600  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
601  diag << "elemental type mismatch: " << write;
602  });
603 
604  // Otherwise, element types of the memref and the vector must match.
605  if (!isa<VectorType>(memrefElTy) &&
606  memrefElTy != write.getVectorType().getElementType())
607  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
608  diag << "elemental type mismatch: " << write;
609  });
610 
611  // Out-of-bounds dims are handled by MaterializeTransferMask.
612  if (write.hasOutOfBoundsDim())
613  return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
614  diag << "out of bounds dim: " << write;
615  });
616  if (write.getMask()) {
617  if (write.getVectorType().getRank() != 1)
618  // vector.maskedstore operates on 1-D vectors.
619  return rewriter.notifyMatchFailure(
620  write.getLoc(), [=](Diagnostic &diag) {
621  diag << "vector type is not rank 1, can't create masked store, "
622  "needs VectorToSCF: "
623  << write;
624  });
625 
626  rewriter.create<vector::MaskedStoreOp>(
627  write.getLoc(), write.getSource(), write.getIndices(),
628  write.getMask(), write.getVector());
629  } else {
630  rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
631  write.getSource(), write.getIndices());
632  }
633  // There's no return value for StoreOps. Use Value() to signal success to
634  // matchAndRewrite.
635  return Value();
636  }
637 
638  std::optional<unsigned> maxTransferRank;
639 };
640 } // namespace
641 
643  RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
644  PatternBenefit benefit) {
645  patterns.add<TransferReadToVectorLoadLowering,
646  TransferWriteToVectorStoreLowering>(patterns.getContext(),
647  maxTransferRank, benefit);
648  patterns
649  .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
650  patterns.getContext(), benefit);
651 }
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static std::optional< VectorShape > vectorShape(Type type)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:155
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
Definition: AffineMap.cpp:176
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:216
unsigned getNumResults() const
Definition: AffineMap.cpp:402
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:345
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
const FrozenRewritePatternSet & patterns
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:157