MLIR 22.0.0git
LowerVectorTransfer.cpp
Go to the documentation of this file.
1//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewrite patterns for the permutation_map attribute of
10// vector.transfer operations.
11//
12//===----------------------------------------------------------------------===//
13
16
17using namespace mlir;
18using namespace mlir::vector;
19
20/// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
21/// permutation based on the given indices.
22static ArrayAttr
24 const SmallVector<unsigned> &permutation) {
25 SmallVector<bool> newInBoundsValues(permutation.size());
26 size_t index = 0;
27 for (unsigned pos : permutation)
28 newInBoundsValues[pos] =
29 cast<BoolAttr>(attr.getValue()[index++]).getValue();
30 return builder.getBoolArrayAttr(newInBoundsValues);
31}
32
33/// Extend the rank of a vector Value by `addedRanks` by adding outer unit
34/// dimensions.
35static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
36 int64_t addedRank) {
37 auto originalVecType = cast<VectorType>(vec.getType());
38 SmallVector<int64_t> newShape(addedRank, 1);
39 newShape.append(originalVecType.getShape().begin(),
40 originalVecType.getShape().end());
41
42 SmallVector<bool> newScalableDims(addedRank, false);
43 newScalableDims.append(originalVecType.getScalableDims().begin(),
44 originalVecType.getScalableDims().end());
45 VectorType newVecType = VectorType::get(
46 newShape, originalVecType.getElementType(), newScalableDims);
47 return vector::BroadcastOp::create(builder, loc, newVecType, vec);
48}
49
50/// Extend the rank of a vector Value by `addedRanks` by adding inner unit
51/// dimensions.
52static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
53 int64_t addedRank) {
54 Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
55 SmallVector<int64_t> permutation;
56 for (int64_t i = addedRank,
57 e = cast<VectorType>(broadcasted.getType()).getRank();
58 i < e; ++i)
59 permutation.push_back(i);
60 for (int64_t i = 0; i < addedRank; ++i)
61 permutation.push_back(i);
62 return vector::TransposeOp::create(builder, loc, broadcasted, permutation);
63}
64
65//===----------------------------------------------------------------------===//
66// populateVectorTransferPermutationMapLoweringPatterns
67//===----------------------------------------------------------------------===//
68
69namespace {
70/// Lower transfer_read op with permutation into a transfer_read with a
71/// permutation map composed of leading zeros followed by a minor identiy +
72/// vector.transpose op.
73/// Ex:
74/// vector.transfer_read ...
75/// permutation_map: (d0, d1, d2) -> (0, d1)
76/// into:
77/// %v = vector.transfer_read ...
78/// permutation_map: (d0, d1, d2) -> (d1, 0)
79/// vector.transpose %v, [1, 0]
80///
81/// vector.transfer_read ...
82/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
83/// into:
84/// %v = vector.transfer_read ...
85/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
86/// vector.transpose %v, [0, 1, 3, 2, 4]
87/// Note that an alternative is to transform it to linalg.transpose +
88/// vector.transfer_read to do the transpose in memory instead.
89struct TransferReadPermutationLowering
90 : public MaskableOpRewritePattern<vector::TransferReadOp> {
91 using MaskableOpRewritePattern::MaskableOpRewritePattern;
92
93 FailureOr<mlir::Value>
94 matchAndRewriteMaskableOp(vector::TransferReadOp op,
95 MaskingOpInterface maskOp,
96 PatternRewriter &rewriter) const override {
97 // TODO: support 0-d corner case.
98 if (op.getTransferRank() == 0)
99 return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
100 // TODO: Support transfer_read inside MaskOp case.
101 if (maskOp)
102 return rewriter.notifyMatchFailure(op, "Masked case not supported");
103
104 SmallVector<unsigned> permutation;
105 AffineMap map = op.getPermutationMap();
106 if (map.getNumResults() == 0)
107 return rewriter.notifyMatchFailure(op, "0 result permutation map");
108 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
109 return rewriter.notifyMatchFailure(
110 op, "map is not permutable to minor identity, apply another pattern");
111 }
112 AffineMap permutationMap =
113 map.getPermutationMap(permutation, op.getContext());
114 if (permutationMap.isIdentity())
115 return rewriter.notifyMatchFailure(op, "map is not identity");
116
117 permutationMap = map.getPermutationMap(permutation, op.getContext());
118 // Caluclate the map of the new read by applying the inverse permutation.
119 permutationMap = inversePermutation(permutationMap);
120 AffineMap newMap = permutationMap.compose(map);
121 // Apply the reverse transpose to deduce the type of the transfer_read.
122 ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
123 SmallVector<int64_t> newVectorShape(originalShape.size());
124 ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
125 SmallVector<bool> newScalableDims(originalShape.size());
126 for (const auto &pos : llvm::enumerate(permutation)) {
127 newVectorShape[pos.value()] = originalShape[pos.index()];
128 newScalableDims[pos.value()] = originalScalableDims[pos.index()];
129 }
130
131 // Transpose in_bounds attribute.
132 ArrayAttr newInBoundsAttr =
133 inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
134
135 // Generate new transfer_read operation.
136 VectorType newReadType = VectorType::get(
137 newVectorShape, op.getVectorType().getElementType(), newScalableDims);
138 Value newRead = vector::TransferReadOp::create(
139 rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
140 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
141 newInBoundsAttr);
142
143 // Transpose result of transfer_read.
144 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
145 return vector::TransposeOp::create(rewriter, op.getLoc(), newRead,
146 transposePerm)
147 .getResult();
148 }
149};
150
151/// Lower transfer_write op with permutation into a transfer_write with a
152/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
153/// Ex:
154/// vector.transfer_write %v ...
155/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
156/// into:
157/// %tmp = vector.transpose %v, [2, 0, 1]
158/// vector.transfer_write %tmp ...
159/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
160///
161/// vector.transfer_write %v ...
162/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
163/// into:
164/// %tmp = vector.transpose %v, [1, 0]
165/// %v = vector.transfer_write %tmp ...
166/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167struct TransferWritePermutationLowering
168 : public MaskableOpRewritePattern<vector::TransferWriteOp> {
169 using MaskableOpRewritePattern::MaskableOpRewritePattern;
170
171 FailureOr<mlir::Value>
172 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
173 MaskingOpInterface maskOp,
174 PatternRewriter &rewriter) const override {
175 // TODO: support 0-d corner case.
176 if (op.getTransferRank() == 0)
177 return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
178 // TODO: Support transfer_write inside MaskOp case.
179 if (maskOp)
180 return rewriter.notifyMatchFailure(op, "Masked case not supported");
181
182 SmallVector<unsigned> permutation;
183 AffineMap map = op.getPermutationMap();
184 if (map.isMinorIdentity())
185 return rewriter.notifyMatchFailure(op, "map is already minor identity");
186
187 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
188 return rewriter.notifyMatchFailure(
189 op, "map is not permutable to minor identity, apply another pattern");
190 }
191
192 // Remove unused dims from the permutation map. E.g.:
193 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
194 // comp = (d0, d1, d2) -> (d2, d0, d1)
195 auto comp = compressUnusedDims(map);
196 AffineMap permutationMap = inversePermutation(comp);
197 // Get positions of remaining result dims.
198 SmallVector<int64_t> indices;
199 llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
200 [](AffineExpr expr) {
201 return dyn_cast<AffineDimExpr>(expr).getPosition();
202 });
203
204 // Transpose in_bounds attribute.
205 ArrayAttr newInBoundsAttr =
206 inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
207
208 // Generate new transfer_write operation.
209 Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(),
210 op.getVector(), indices);
211 auto newMap = AffineMap::getMinorIdentityMap(
212 map.getNumDims(), map.getNumResults(), rewriter.getContext());
213 auto newWrite = vector::TransferWriteOp::create(
214 rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
215 AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
216 if (newWrite.hasPureTensorSemantics())
217 return newWrite.getResult();
218 // In the memref case there's no return value. Use empty value to signal
219 // success.
220 return Value();
221 }
222};
223
224/// Convert a transfer.write op with a map which isn't the permutation of a
225/// minor identity into a vector.broadcast + transfer_write with permutation of
226/// minor identity map by adding unit dim on inner dimension. Ex:
227/// ```
228/// vector.transfer_write %v
229/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
230/// vector<8x16xf32>
231/// ```
232/// into:
233/// ```
234/// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
235/// vector.transfer_write %v1
236/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
237/// vector<1x8x16xf32>
238/// ```
239struct TransferWriteNonPermutationLowering
240 : public MaskableOpRewritePattern<vector::TransferWriteOp> {
241 using MaskableOpRewritePattern::MaskableOpRewritePattern;
242
243 FailureOr<mlir::Value>
244 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
245 MaskingOpInterface maskOp,
246 PatternRewriter &rewriter) const override {
247 // TODO: support 0-d corner case.
248 if (op.getTransferRank() == 0)
249 return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
250 // TODO: Support transfer_write inside MaskOp case.
251 if (maskOp)
252 return rewriter.notifyMatchFailure(op, "Masked case not supported");
253
254 SmallVector<unsigned> permutation;
255 AffineMap map = op.getPermutationMap();
257 return rewriter.notifyMatchFailure(
258 op,
259 "map is already permutable to minor identity, apply another pattern");
260 }
261
262 // Missing outer dimensions are allowed, find the most outer existing
263 // dimension then deduce the missing inner dimensions.
264 SmallVector<bool> foundDim(map.getNumDims(), false);
265 for (AffineExpr exp : map.getResults())
266 foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
267 SmallVector<AffineExpr> exprs;
268 bool foundFirstDim = false;
269 SmallVector<int64_t> missingInnerDim;
270 for (size_t i = 0; i < foundDim.size(); i++) {
271 if (foundDim[i]) {
272 foundFirstDim = true;
273 continue;
274 }
275 if (!foundFirstDim)
276 continue;
277 // Once we found one outer dimension existing in the map keep track of all
278 // the missing dimensions after that.
279 missingInnerDim.push_back(i);
280 exprs.push_back(rewriter.getAffineDimExpr(i));
281 }
282 // Vector: add unit dims at the beginning of the shape.
283 Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
284 missingInnerDim.size());
285 // Mask: add unit dims at the end of the shape.
286 Value newMask;
287 if (op.getMask())
288 newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
289 missingInnerDim.size());
290 exprs.append(map.getResults().begin(), map.getResults().end());
291 AffineMap newMap =
292 AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
293 // All the new dimensions added are inbound.
294 SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
295 for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
296 newInBoundsValues.push_back(op.isDimInBounds(i));
297 }
298 ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
299 auto newWrite = vector::TransferWriteOp::create(
300 rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
301 AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
302 if (newWrite.hasPureTensorSemantics())
303 return newWrite.getResult();
304 // In the memref case there's no return value. Use empty value to signal
305 // success.
306 return Value();
307 }
308};
309
310/// Lower transfer_read op with broadcast in the leading dimensions into
311/// transfer_read of lower rank + vector.broadcast.
312/// Ex: vector.transfer_read ...
313/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
314/// into:
315/// %v = vector.transfer_read ...
316/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
317/// vector.broadcast %v
318struct TransferOpReduceRank
319 : public MaskableOpRewritePattern<vector::TransferReadOp> {
320 using MaskableOpRewritePattern::MaskableOpRewritePattern;
321
322 FailureOr<mlir::Value>
323 matchAndRewriteMaskableOp(vector::TransferReadOp op,
324 MaskingOpInterface maskOp,
325 PatternRewriter &rewriter) const override {
326 // TODO: support 0-d corner case.
327 if (op.getTransferRank() == 0)
328 return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
329 // TODO: support masked case.
330 if (maskOp)
331 return rewriter.notifyMatchFailure(op, "Masked case not supported");
332
333 AffineMap map = op.getPermutationMap();
334 unsigned numLeadingBroadcast = 0;
335 for (auto expr : map.getResults()) {
336 auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
337 if (!dimExpr || dimExpr.getValue() != 0)
338 break;
339 numLeadingBroadcast++;
340 }
341 // If there are no leading zeros in the map there is nothing to do.
342 if (numLeadingBroadcast == 0)
343 return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");
344
345 VectorType originalVecType = op.getVectorType();
346 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
347 // Calculate new map, vector type and masks without the leading zeros.
348 AffineMap newMap = AffineMap::get(
349 map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
350 op.getContext());
351 // Only remove the leading zeros if the rest of the map is a minor identity
352 // with broadasting. Otherwise we first want to permute the map.
353 if (!newMap.isMinorIdentityWithBroadcasting()) {
354 return rewriter.notifyMatchFailure(
355 op, "map is not a minor identity with broadcasting");
356 }
357
358 SmallVector<int64_t> newShape(
359 originalVecType.getShape().take_back(reducedShapeRank));
360 SmallVector<bool> newScalableDims(
361 originalVecType.getScalableDims().take_back(reducedShapeRank));
362
363 VectorType newReadType = VectorType::get(
364 newShape, originalVecType.getElementType(), newScalableDims);
365 ArrayAttr newInBoundsAttr =
366 op.getInBounds()
367 ? rewriter.getArrayAttr(
368 op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
369 : ArrayAttr();
370 Value newRead = vector::TransferReadOp::create(
371 rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
372 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
373 newInBoundsAttr);
374 return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
375 newRead)
376 .getVector();
377 }
378};
379
380} // namespace
381
385 .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
386 TransferOpReduceRank, TransferWriteNonPermutationLowering>(
387 patterns.getContext(), benefit);
388}
389
390//===----------------------------------------------------------------------===//
391// populateVectorTransferLoweringPatterns
392//===----------------------------------------------------------------------===//
393
394namespace {
395/// Progressive lowering of transfer_read. This pattern supports lowering of
396/// `vector.transfer_read` to a combination of `vector.load` and
397/// `vector.broadcast` if all of the following hold:
398/// - Stride of most minor memref dimension must be 1.
399/// - Out-of-bounds masking is not required.
400/// - If the memref's element type is a vector type then it coincides with the
401/// result type.
402/// - The permutation map doesn't perform permutation (broadcasting is allowed).
403struct TransferReadToVectorLoadLowering
404 : public MaskableOpRewritePattern<vector::TransferReadOp> {
405 TransferReadToVectorLoadLowering(MLIRContext *context,
406 std::optional<unsigned> maxRank,
407 PatternBenefit benefit = 1)
408 : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
409 maxTransferRank(maxRank) {}
410
411 FailureOr<mlir::Value>
412 matchAndRewriteMaskableOp(vector::TransferReadOp read,
413 MaskingOpInterface maskOp,
414 PatternRewriter &rewriter) const override {
415 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
416 return rewriter.notifyMatchFailure(
417 read, "vector type is greater than max transfer rank");
418 }
419
420 if (maskOp)
421 return rewriter.notifyMatchFailure(read, "Masked case not supported");
422 SmallVector<unsigned> broadcastedDims;
423 // Permutations are handled by VectorToSCF or
424 // populateVectorTransferPermutationMapLoweringPatterns.
425 // We let the 0-d corner case pass-through as it is supported.
426 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
427 &broadcastedDims))
428 return rewriter.notifyMatchFailure(read, "not minor identity + bcast");
429
430 auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
431 if (!memRefType)
432 return rewriter.notifyMatchFailure(read, "not a memref source");
433
434 // Non-unit strides are handled by VectorToSCF.
435 if (!memRefType.isLastDimUnitStride())
436 return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
437
438 // If there is broadcasting involved then we first load the unbroadcasted
439 // vector, and then broadcast it with `vector.broadcast`.
440 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
441 SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
442 for (unsigned i : broadcastedDims)
443 unbroadcastedVectorShape[i] = 1;
444 VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
445 unbroadcastedVectorShape, read.getVectorType().getElementType());
446
447 // `vector.load` supports vector types as memref's elements only when the
448 // resulting vector type is the same as the element type.
449 auto memrefElTy = memRefType.getElementType();
450 if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
451 return rewriter.notifyMatchFailure(read, "incompatible element type");
452
453 // Otherwise, element types of the memref and the vector must match.
454 if (!isa<VectorType>(memrefElTy) &&
455 memrefElTy != read.getVectorType().getElementType())
456 return rewriter.notifyMatchFailure(read, "non-matching element type");
457
458 // Out-of-bounds dims are handled by MaterializeTransferMask.
459 if (read.hasOutOfBoundsDim())
460 return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
461
462 // Create vector load op.
463 Operation *res;
464 if (read.getMask()) {
465 if (read.getVectorType().getRank() != 1)
466 // vector.maskedload operates on 1-D vectors.
467 return rewriter.notifyMatchFailure(
468 read, "vector type is not rank 1, can't create masked load, needs "
469 "VectorToSCF");
470
471 Value fill = vector::BroadcastOp::create(
472 rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
473 res = vector::MaskedLoadOp::create(
474 rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
475 read.getIndices(), read.getMask(), fill);
476 } else {
477 res = vector::LoadOp::create(rewriter, read.getLoc(),
478 unbroadcastedVectorType, read.getBase(),
479 read.getIndices());
480 }
481
482 // Insert a broadcasting op if required.
483 if (!broadcastedDims.empty())
484 res = vector::BroadcastOp::create(
485 rewriter, read.getLoc(), read.getVectorType(), res->getResult(0));
486 return res->getResult(0);
487 }
488
489 std::optional<unsigned> maxTransferRank;
490};
491
492/// Progressive lowering of transfer_write. This pattern supports lowering of
493/// `vector.transfer_write` to `vector.store` if all of the following hold:
494/// - Stride of most minor memref dimension must be 1.
495/// - Out-of-bounds masking is not required.
496/// - If the memref's element type is a vector type then it coincides with the
497/// type of the written value.
498/// - The permutation map is the minor identity map (neither permutation nor
499/// broadcasting is allowed).
500struct TransferWriteToVectorStoreLowering
501 : public MaskableOpRewritePattern<vector::TransferWriteOp> {
502 TransferWriteToVectorStoreLowering(MLIRContext *context,
503 std::optional<unsigned> maxRank,
504 PatternBenefit benefit = 1)
505 : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
506 maxTransferRank(maxRank) {}
507
508 FailureOr<mlir::Value>
509 matchAndRewriteMaskableOp(vector::TransferWriteOp write,
510 MaskingOpInterface maskOp,
511 PatternRewriter &rewriter) const override {
512 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
513 return rewriter.notifyMatchFailure(
514 write, "vector type is greater than max transfer rank");
515 }
516 if (maskOp)
517 return rewriter.notifyMatchFailure(write, "Masked case not supported");
518
519 // Permutations are handled by VectorToSCF or
520 // populateVectorTransferPermutationMapLoweringPatterns.
521 if ( // pass-through for the 0-d corner case.
522 !write.getPermutationMap().isMinorIdentity())
523 return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
524 diag << "permutation map is not minor identity: " << write;
525 });
526
527 auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
528 if (!memRefType)
529 return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
530 diag << "not a memref type: " << write;
531 });
532
533 // Non-unit strides are handled by VectorToSCF.
534 if (!memRefType.isLastDimUnitStride())
535 return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
536 diag << "most minor stride is not 1: " << write;
537 });
538
539 // `vector.store` supports vector types as memref's elements only when the
540 // type of the vector value being written is the same as the element type.
541 auto memrefElTy = memRefType.getElementType();
542 if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
543 return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
544 diag << "elemental type mismatch: " << write;
545 });
546
547 // Otherwise, element types of the memref and the vector must match.
548 if (!isa<VectorType>(memrefElTy) &&
549 memrefElTy != write.getVectorType().getElementType())
550 return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
551 diag << "elemental type mismatch: " << write;
552 });
553
554 // Out-of-bounds dims are handled by MaterializeTransferMask.
555 if (write.hasOutOfBoundsDim())
556 return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
557 diag << "out of bounds dim: " << write;
558 });
559 if (write.getMask()) {
560 if (write.getVectorType().getRank() != 1)
561 // vector.maskedstore operates on 1-D vectors.
562 return rewriter.notifyMatchFailure(
563 write.getLoc(), [=](Diagnostic &diag) {
564 diag << "vector type is not rank 1, can't create masked store, "
565 "needs VectorToSCF: "
566 << write;
567 });
568
569 vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(),
570 write.getIndices(), write.getMask(),
571 write.getVector());
572 } else {
573 vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(),
574 write.getBase(), write.getIndices());
575 }
576 // There's no return value for StoreOps. Use Value() to signal success to
577 // matchAndRewrite.
578 return Value();
579 }
580
581 std::optional<unsigned> maxTransferRank;
582};
583} // namespace
584
586 RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
587 PatternBenefit benefit) {
588 patterns.add<TransferReadToVectorLoadLowering,
589 TransferWriteToVectorStoreLowering>(patterns.getContext(),
590 maxTransferRank, benefit);
591}
ArrayAttr()
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static std::optional< VectorShape > vectorShape(Type type)
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:270
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.