MLIR 23.0.0git
VectorTransferOpTransforms.cpp
Go to the documentation of this file.
1//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with optimizing transfer_read and
10// transfer_write ops.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/Dominance.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/Operation.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/StringRef.h"
30#include "llvm/Support/DebugLog.h"
31
32#define DEBUG_TYPE "vector-transfer-opt"
33
34using namespace mlir;
35
36/// Return the ancestor op in the region or nullptr if the region is not
37/// an ancestor of the op.
39 LDBG() << " Finding ancestor of " << *op << " in region";
40 for (; op != nullptr && op->getParentRegion() != region;
41 op = op->getParentOp())
42 ;
43 if (op) {
44 LDBG() << " -> Ancestor: " << *op;
45 } else {
46 LDBG() << " -> Ancestor: nullptr";
47 }
48 return op;
49}
50
51namespace {
52
/// Driver for the dead-store-elimination and store-to-load-forwarding
/// rewrites on vector transfer ops. (Post-)dominance information is computed
/// once for `op` at construction time and reused by every query.
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  // Queue `write` for erasure if it is provably overwritten before any read.
  void deadStoreOp(vector::TransferWriteOp);
  // Replace the result of `read` with the vector of a dominating write.
  void storeToLoadForwarding(vector::TransferReadOp);
  // Erase all queued ops. Erasure is deferred so that the dominance info
  // computed above stays valid while the analyses run.
  void removeDeadOp() {
    LDBG() << "Removing " << opToErase.size() << " dead operations";
    for (Operation *op : opToErase) {
      LDBG() << " -> Erasing: " << *op;
      rewriter.eraseOp(op);
    }
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  // Ops scheduled for erasure by the analyses; flushed by removeDeadOp().
  std::vector<Operation *> opToErase;
};
75
76} // namespace
77/// Return true if there is a path from start operation to dest operation,
78/// otherwise return false. The operations have to be in the same region.
79bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
80 LDBG() << " Checking reachability from " << *start << " to " << *dest;
81 assert(start->getParentRegion() == dest->getParentRegion() &&
82 "This function only works for ops i the same region");
83 // Simple case where the start op dominate the destination.
84 if (dominators.dominates(start, dest)) {
85 LDBG() << " -> Start dominates dest, reachable";
86 return true;
87 }
88 bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
89 LDBG() << " -> Block reachable: " << blockReachable;
90 return blockReachable;
91}
92
93/// For transfer_write to overwrite fully another transfer_write must:
94/// 1. Access the same memref with the same indices and vector type.
95/// 2. Post-dominate the other transfer_write operation.
96/// If several candidates are available, one must be post-dominated by all the
97/// others since they are all post-dominating the same transfer_write. We only
98/// consider the transfer_write post-dominated by all the other candidates as
99/// this will be the first transfer_write executed after the potentially dead
100/// transfer_write.
101/// If we found such an overwriting transfer_write we know that the original
102/// transfer_write is dead if all reads that can be reached from the potentially
103/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  // Worklist over all users of the source memref (transitively following
  // view-like ops): classify each user either as a candidate overwriting
  // transfer_write or as a blocking access.
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed skip.
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user)) {
      LDBG() << " -> Memory effect free, skipping";
      continue;
    }
    if (user == write.getOperation()) {
      LDBG() << " -> Same as write operation, skipping";
      continue;
    }
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
      // Check candidate that can override the store.
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(nextWrite.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool sameValue = checkSameValueWAW(nextWrite, write);
      bool postDominates = postDominators.postDominates(nextWrite, write);
      LDBG() << " -> Same view: " << sameView
             << ", Same value: " << sameValue
             << ", Post-dominates: " << postDominates;

      if (sameView && sameValue && postDominates) {
        LDBG() << " -> Valid overwrite candidate found";
        // Keep the candidate that is post-dominated by all the others: it is
        // the first overwrite executed after the potentially dead store.
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
          LDBG() << " -> New first overwrite candidate: " << *nextWrite;
          firstOverwriteCandidate = nextWrite;
        } else {
          LDBG() << " -> Keeping existing first overwrite candidate";
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        }
        continue;
      }
      LDBG() << " -> Not a valid overwrite candidate";
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      LDBG() << " -> Found vector transfer operation: " << *transferOp;
      // Don't need to consider disjoint accesses.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(transferOp.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint access";
        continue;
      }
    }
    LDBG() << " -> Adding to blocking accesses: " << *user;
    blockingAccesses.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
         << " blocking accesses";

  if (firstOverwriteCandidate == nullptr) {
    LDBG() << "No overwrite candidate found, store is not dead";
    return;
  }

  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  LDBG() << "Write ancestor in top region: " << *writeAncestor;

  // The store is dead only if every blocking access reachable from it is
  // dominated by the overwriting transfer_write.
  LDBG() << "Checking " << blockingAccesses.size()
         << " blocking accesses for reachability";
  for (Operation *access : blockingAccesses) {
    LDBG() << "Checking blocking access: " << *access;
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }

    bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
    LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
    if (!isReachableFromWrite) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }

    bool overwriteDominatesAccess =
        dominators.dominates(firstOverwriteCandidate, accessAncestor);
    LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
    if (!overwriteDominatesAccess) {
      LDBG() << "Store may not be dead due to op: " << *accessAncestor;
      return;
    }
    LDBG() << " -> Access is dominated by overwrite, continuing";
  }
  LDBG() << "Found dead store: " << *write.getOperation()
         << " overwritten by: " << *firstOverwriteCandidate;
  opToErase.push_back(write.getOperation());
}
226
227/// A transfer_write candidate to storeToLoad forwarding must:
228/// 1. Access the same memref with the same indices and vector type as the
229/// transfer_read.
230/// 2. Dominate the transfer_read operation.
231/// If several candidates are available, one must be dominated by all the others
232/// since they are all dominating the same transfer_read. We only consider the
233/// transfer_write dominated by all the other candidates as this will be the
234/// last transfer_write executed before the transfer_read.
235/// If we found such a candidate we can do the forwarding if all the other
236/// potentially aliasing ops that may reach the transfer_read are post-dominated
237/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
         << *read.getOperation();
  if (read.hasOutOfBoundsDim()) {
    LDBG() << "Read has out-of-bounds dimensions, skipping";
    return;
  }
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  // Worklist over all users of the source memref (transitively following
  // view-like ops): classify each user either as a forwarding candidate or as
  // a potentially aliasing blocking write.
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed skip.
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
      LDBG() << " -> Memory effect free or transfer_read, skipping";
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *write;
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint write";
        continue;
      }

      bool sameView =
          memref::isSameViewOrTrivialAlias(cast<MemrefValue>(read.getBase()),
                                           cast<MemrefValue>(write.getBase()));
      bool dominates = dominators.dominates(write, read);
      bool sameValue = checkSameValueRAW(write, read);
      LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
             << ", Same value: " << sameValue;

      if (sameView && dominates && sameValue) {
        LDBG() << " -> Valid forwarding candidate found";
        // Keep the candidate dominated by all the others: it is the last
        // write executed before the transfer_read.
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
          LDBG() << " -> New last write candidate: " << *write;
          lastwrite = write;
        } else {
          LDBG() << " -> Keeping existing last write candidate";
          assert(dominators.dominates(write, lastwrite));
        }
        continue;
      }
      LDBG() << " -> Not a valid forwarding candidate";
    }
    LDBG() << " -> Adding to blocking writes: " << *user;
    blockingWrites.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingWrites.size()
         << " blocking writes";

  if (lastwrite == nullptr) {
    LDBG() << "No last write candidate found, cannot forward";
    return;
  }

  LDBG() << "Last write candidate: " << *lastwrite;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");
  LDBG() << "Read ancestor in top region: " << *readAncestor;

  // Forwarding is legal only if every blocking write that can reach the read
  // is post-dominated by the forwarded transfer_write.
  LDBG() << "Checking " << blockingWrites.size()
         << " blocking writes for post-dominance";
  for (Operation *write : blockingWrites) {
    LDBG() << "Checking blocking write: " << *write;
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor) {
      LDBG() << " -> Write ancestor: " << *writeAncestor;
    } else {
      LDBG() << " -> Write ancestor: nullptr";
    }

    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }

    bool isReachableToRead = isReachable(writeAncestor, readAncestor);
    LDBG() << " -> Is reachable to read: " << isReachableToRead;
    if (!isReachableToRead) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }

    bool lastWritePostDominates =
        postDominators.postDominates(lastwrite, write);
    LDBG() << " -> Last write post-dominates blocking write: "
           << lastWritePostDominates;
    if (!lastWritePostDominates) {
      LDBG() << "Fail to do write to read forwarding due to op: " << *write;
      return;
    }
    LDBG() << " -> Blocking write is post-dominated, continuing";
  }

  LDBG() << "Forward value from " << *lastwrite.getOperation()
         << " to: " << *read.getOperation();
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
365
366/// Converts OpFoldResults to int64_t shape without unit dims.
368 SmallVector<int64_t> reducedShape;
369 for (const auto size : mixedSizes) {
370 if (llvm::dyn_cast_if_present<Value>(size)) {
371 reducedShape.push_back(ShapedType::kDynamic);
372 continue;
373 }
374
375 auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
376 if (value == 1)
377 continue;
378 reducedShape.push_back(value.getSExtValue());
379 }
380 return reducedShape;
381}
382
383/// Drops unit dimensions from the input MemRefType.
384static MemRefType dropUnitDims(MemRefType inputType,
387 ArrayRef<OpFoldResult> strides) {
388 auto targetShape = getReducedShape(sizes);
389 MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
390 targetShape, inputType, offsets, sizes, strides);
391 return rankReducedType.canonicalizeStridedLayout();
392}
393
394/// Creates a rank-reducing memref.subview op that drops unit dims from its
395/// input. Or just returns the input if it was already without unit dims.
397 mlir::Location loc,
398 Value input) {
399 MemRefType inputType = cast<MemRefType>(input.getType());
400 SmallVector<OpFoldResult> offsets(inputType.getRank(),
401 rewriter.getIndexAttr(0));
402 SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
403 SmallVector<OpFoldResult> strides(inputType.getRank(),
404 rewriter.getIndexAttr(1));
405 MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
406
407 if (resultType.canonicalizeStridedLayout() ==
408 inputType.canonicalizeStridedLayout())
409 return input;
410 return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
411 sizes, strides);
412}
413
414/// Returns the number of dims that aren't unit dims.
416 return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
417}
418
419/// Trims non-scalable one dimensions from `oldType` and returns the result
420/// type.
421static VectorType trimNonScalableUnitDims(VectorType oldType) {
422 SmallVector<int64_t> newShape;
423 SmallVector<bool> newScalableDims;
424 for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
425 if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
426 continue;
427 newShape.push_back(dimSize);
428 newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
429 }
430 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
431}
432
// Returns true if a dynamic mask dimension size is the constant 1.
static bool isUnitDimMask(Value maskDimSize) {
  return matchPattern(maskDimSize, m_One());
}
// Returns true if a static mask dimension size is 1.
static bool isUnitDimMask(int64_t maskDimSize) { return maskDimSize == 1; }
437
438/// Rewrites a mask op to drop non-scalable unit dimensions.
439/// Supports vector.create_mask and vector.constant_mask.
template <typename MaskOp>
static FailureOr<Value> maskDropNonScalableUnitDims(PatternRewriter &rewriter,
                                                    Location loc, MaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  // Nothing to drop: fail so the caller can report no-match.
  if (reducedType.getRank() == type.getRank())
    return failure();

  // Element type of the mask dim sizes: Value for vector.create_mask, int64_t
  // for vector.constant_mask — deduced so one implementation serves both ops.
  using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
  SmallVector<ElemType> reduced;
  for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
    if (dim == 1 && !dimIsScalable) {
      // A dropped unit dim must have a unit mask size, otherwise the reduced
      // mask would not be equivalent to the original one.
      if (!isUnitDimMask(elem))
        return failure();
      continue;
    }
    reduced.push_back(elem);
  }
  return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
}
461
462namespace {
463
464/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
465/// inserting a memref.subview dropping those unit dims. The vector shapes are
466/// also reduced accordingly.
467class TransferReadDropUnitDimsPattern
468 : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
469 using MaskableOpRewritePattern::MaskableOpRewritePattern;
470
471 FailureOr<Value>
472 matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
473 vector::MaskingOpInterface maskingOp,
474 PatternRewriter &rewriter) const override {
475 LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
476 << *transferReadOp;
477 auto loc = transferReadOp.getLoc();
478 Value vector = transferReadOp.getVector();
479 VectorType vectorType = cast<VectorType>(vector.getType());
480 Value source = transferReadOp.getBase();
481 MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
482 // TODO: support tensor types.
483 if (!sourceType) {
484 LDBG() << " -> Not a MemRefType, skipping";
485 return failure();
486 }
487 // TODO: generalize this pattern, relax the requirements here.
488 if (transferReadOp.hasOutOfBoundsDim()) {
489 LDBG() << " -> Has out-of-bounds dimensions, skipping";
490 return failure();
491 }
492 if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
493 LDBG() << " -> Not minor identity permutation map, skipping";
494 return failure();
495 }
496 // Check if the source shape can be further reduced.
497 int reducedRank = getReducedRank(sourceType.getShape());
498 LDBG() << " -> Source rank: " << sourceType.getRank()
499 << ", Reduced rank: " << reducedRank;
500 if (reducedRank == sourceType.getRank()) {
501 LDBG() << " -> No unit dimensions to drop, skipping";
502 return failure();
503 }
504 // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
505 // out.
506 if (reducedRank == 0 && maskingOp) {
507 LDBG() << " -> 0-d vector with masking not supported, skipping";
508 return failure();
509 }
510 // Check if the reduced vector shape matches the reduced source shape.
511 // Otherwise, this case is not supported yet.
512 VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
513 LDBG() << " -> Vector type: " << vectorType
514 << ", Reduced vector type: " << reducedVectorType;
515 if (reducedRank != reducedVectorType.getRank()) {
516 LDBG() << " -> Reduced ranks don't match, skipping";
517 return failure();
518 }
519 if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
520 return getConstantIntValue(v) != static_cast<int64_t>(0);
521 })) {
522 LDBG() << " -> Non-zero indices found, skipping";
523 return failure();
524 }
525
526 Value maskOp = transferReadOp.getMask();
527 if (maskOp) {
528 LDBG() << " -> Processing mask operation";
529 FailureOr<Value> rankReducedMaskOp = failure();
530 if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
531 rankReducedMaskOp =
532 maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
533 else if (auto constantMaskOp =
534 maskOp.getDefiningOp<vector::ConstantMaskOp>())
535 rankReducedMaskOp =
536 maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
537
538 if (failed(rankReducedMaskOp)) {
539 LDBG() << " -> Failed to reduce mask dimensions";
540 return rewriter.notifyMatchFailure(
541 transferReadOp,
542 "unsupported mask op, only 'vector.create_mask' and "
543 "'vector.constant_mask' are currently supported");
544 }
545 maskOp = *rankReducedMaskOp;
546 LDBG() << " -> Successfully reduced mask dimensions";
547 }
548
549 LDBG() << " -> Creating rank-reduced subview and new transfer_read";
550 Value reducedShapeSource =
551 rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
552 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
553 Repeated<Value> zeros(reducedRank, c0);
554 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
555 SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
556 Operation *newTransferReadOp = vector::TransferReadOp::create(
557 rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
558 identityMap, transferReadOp.getPadding(), maskOp,
559 rewriter.getBoolArrayAttr(inBounds));
560 LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;
561
562 if (maskingOp) {
563 LDBG() << " -> Applying masking operation";
564 auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
565 loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
566 maskingOp.getMask());
567 newTransferReadOp = mlir::vector::maskOperation(
568 rewriter, newTransferReadOp, shapeCastMask);
569 }
570
571 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
572 loc, vectorType, newTransferReadOp->getResults()[0]);
573 LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
574 LDBG() << " -> Pattern match successful, returning result";
575
576 return shapeCast;
577 }
578};
579
580/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
581/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
582/// vector shapes are also reduced accordingly.
583class TransferWriteDropUnitDimsPattern
584 : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
585 using MaskableOpRewritePattern::MaskableOpRewritePattern;
586
587 FailureOr<Value>
588 matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
589 vector::MaskingOpInterface maskingOp,
590 PatternRewriter &rewriter) const override {
591 LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
592 << *transferWriteOp;
593 auto loc = transferWriteOp.getLoc();
594 Value vector = transferWriteOp.getVector();
595 VectorType vectorType = cast<VectorType>(vector.getType());
596 Value source = transferWriteOp.getBase();
597 MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
598 // TODO: support tensor type.
599 if (!sourceType) {
600 LDBG() << " -> Not a MemRefType, skipping";
601 return failure();
602 }
603 // TODO: generalize this pattern, relax the requirements here.
604 if (transferWriteOp.hasOutOfBoundsDim()) {
605 LDBG() << " -> Has out-of-bounds dimensions, skipping";
606 return failure();
607 }
608 if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
609 LDBG() << " -> Not minor identity permutation map, skipping";
610 return failure();
611 }
612 // Check if the destination shape can be further reduced.
613 int reducedRank = getReducedRank(sourceType.getShape());
614 LDBG() << " -> Source rank: " << sourceType.getRank()
615 << ", Reduced rank: " << reducedRank;
616 if (reducedRank == sourceType.getRank()) {
617 LDBG() << " -> No unit dimensions to drop, skipping";
618 return failure();
619 }
620 // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
621 // out.
622 if (reducedRank == 0 && maskingOp) {
623 LDBG() << " -> 0-d vector with masking not supported, skipping";
624 return failure();
625 }
626 // Check if the reduced vector shape matches the reduced destination shape.
627 // Otherwise, this case is not supported yet.
628 VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
629 LDBG() << " -> Vector type: " << vectorType
630 << ", Reduced vector type: " << reducedVectorType;
631 if (reducedRank != reducedVectorType.getRank()) {
632 LDBG() << " -> Reduced ranks don't match, skipping";
633 return failure();
634 }
635 if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
636 return getConstantIntValue(v) != static_cast<int64_t>(0);
637 })) {
638 LDBG() << " -> Non-zero indices found, skipping";
639 return failure();
640 }
641
642 Value maskOp = transferWriteOp.getMask();
643 if (maskOp) {
644 LDBG() << " -> Processing mask operation";
645 FailureOr<Value> rankReducedMask = failure();
646 if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
647 rankReducedMask =
648 maskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
649 else if (auto constantMaskOp =
650 maskOp.getDefiningOp<vector::ConstantMaskOp>())
651 rankReducedMask =
652 maskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
653
654 if (failed(rankReducedMask)) {
655 LDBG() << " -> Failed to reduce mask dimensions";
656 return rewriter.notifyMatchFailure(
657 transferWriteOp,
658 "unsupported mask op, only 'vector.create_mask' and "
659 "'vector.constant_mask' are currently supported");
660 }
661 maskOp = *rankReducedMask;
662 LDBG() << " -> Successfully reduced mask dimensions";
663 }
664 LDBG() << " -> Creating rank-reduced subview and new transfer_write";
665 Value reducedShapeSource =
666 rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
667 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
668 Repeated<Value> zeros(reducedRank, c0);
669 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
670 SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
671 auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
672 loc, reducedVectorType, vector);
673 Operation *newXferWrite = vector::TransferWriteOp::create(
674 rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
675 identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
676 LDBG() << " -> Created new transfer_write: " << *newXferWrite;
677
678 if (maskingOp) {
679 LDBG() << " -> Applying masking operation";
680 auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
681 loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
682 maskingOp.getMask());
683 newXferWrite =
684 mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
685 }
686
687 if (transferWriteOp.hasPureTensorSemantics()) {
688 LDBG() << " -> Pattern match successful (tensor semantics), returning "
689 "result";
690 return newXferWrite->getResults()[0];
691 }
692
693 // With Memref semantics, there's no return value. Use empty value to signal
694 // success.
695 LDBG() << " -> Pattern match successful (memref semantics)";
696 return Value();
697 }
698};
699
700} // namespace
701
702/// Creates a memref.collapse_shape collapsing all inner dimensions of the
703/// input starting at `firstDimToCollapse`.
705 Value input, int64_t firstDimToCollapse) {
706 ShapedType inputType = cast<ShapedType>(input.getType());
707 if (inputType.getRank() == 1)
708 return input;
710 for (int64_t i = 0; i < firstDimToCollapse; ++i)
711 reassociation.push_back(ReassociationIndices{i});
712 ReassociationIndices collapsedIndices;
713 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
714 collapsedIndices.push_back(i);
715 reassociation.push_back(collapsedIndices);
716 return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
717}
718
719/// Returns the new indices that collapses the inner dimensions starting from
720/// the `firstDimToCollapse` dimension.
722 Location loc,
725 int64_t firstDimToCollapse) {
726 assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
727
728 // If all the collapsed indices are zero then no extra logic is needed.
729 // Otherwise, a new offset/index has to be computed.
730 SmallVector<Value> indicesAfterCollapsing(
731 indices.begin(), indices.begin() + firstDimToCollapse);
732 SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
733 indices.end());
734 if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
735 indicesAfterCollapsing.push_back(indicesToCollapse[0]);
736 return indicesAfterCollapsing;
737 }
738
739 // Compute the remaining trailing index/offset required for reading from
740 // the collapsed memref:
741 //
742 // offset = 0
743 // for (i = firstDimToCollapse; i < outputRank; ++i)
744 // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
745 //
746 // For this example:
747 // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
748 // memref<1x43x2xi32>, vector<1x2xi32>
749 // which would be collapsed to:
750 // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
751 // memref<1x86xi32>, vector<2xi32>
752 // one would get the following offset:
753 // %offset = %arg0 * 43
754 OpFoldResult collapsedOffset =
755 arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
756
757 auto collapsedStrides = computeSuffixProduct(
758 ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
759
760 // Compute the collapsed offset.
761 auto &&[collapsedExpr, collapsedVals] =
762 computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
764 rewriter, loc, collapsedExpr, collapsedVals);
765
766 if (auto value = dyn_cast<Value>(collapsedOffset)) {
767 indicesAfterCollapsing.push_back(value);
768 } else {
769 indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
770 rewriter, loc, *getConstantIntValue(collapsedOffset)));
771 }
772
773 return indicesAfterCollapsing;
774}
775
776namespace {
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    // dyn_cast: the source may be a tensor, which is rejected below.
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs, so bail on tensor
    // sources.
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1) {
      LDBG() << " -> Already 0D/1D, skipping";
      return failure();
    }
    if (!vectorType.getElementType().isSignlessIntOrFloat()) {
      LDBG() << " -> Not signless int or float, skipping";
      return failure();
    }
    // Only flatten when the trailing vector dimension is narrower than the
    // requested target bitwidth (flattening is pointless otherwise).
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
           << ", target: " << targetVectorBitwidth;
    if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
      LDBG() << " -> Trailing dim bitwidth >= target, skipping";
      return failure();
    }
    if (!vector::isContiguousSlice(sourceType, vectorType)) {
      LDBG() << " -> Not contiguous slice, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    if (transferReadOp.getMask()) {
      LDBG() << " -> Has mask, skipping";
      return failure();
    }

    // Determine vector dimensions to collapse.
    // Ignore a leading sequence of adjacent unit dimensions in the vector.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
    // to a 1D single-element vector.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;

    // Determine the first memref dimension to collapse - just enough so we can
    // read a flattened vector.
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
    LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;

    // 1. Collapse the source memref
    LDBG() << " -> Collapsing source memref";
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    // collapseInnerDims folds everything from firstDimToCollapse onwards into
    // a single trailing dimension.
    assert(collapsedRank == firstDimToCollapse + 1);
    LDBG() << " -> Collapsed source type: " << collapsedSourceType;

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    LDBG() << " -> Creating flattened vector type: " << flatVectorType;
    vector::TransferReadOp flatRead = vector::TransferReadOp::create(
        rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    // Out-of-bounds dims were rejected above, so the flattened read can be
    // marked in-bounds.
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    LDBG() << " -> Created flat transfer_read: " << *flatRead;

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    LDBG() << " -> Replacing with shape cast";
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    LDBG() << " -> Pattern match successful";
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
907
/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector write is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    // dyn_cast: the destination may be a tensor, which is rejected below.
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs, so bail on tensor
    // destinations.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dimension is narrower than the
    // requested target bitwidth (flattening is pointless otherwise).
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // Determine vector dimensions to collapse.
    // Ignore a leading sequence of adjacent unit dimensions in the vector.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
    // to a 1D single-element vector.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;

    // Determine the first memref dimension to collapse - just enough so we can
    // write a flattened vector.
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    // collapseInnerDims folds everything from firstDimToCollapse onwards into
    // a single trailing dimension.
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
        rewriter, loc, flatVector, collapsedSource, collapsedIndices,
        collapsedMap);
    // Out-of-bounds dims were rejected above, so the flattened write can be
    // marked in-bounds.
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
1013
/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be `vector.extract` ops. If
/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
/// users. Otherwise, rewrite only if the extract op is the single user of the
/// transfer op. Rewriting a single vector load with multiple scalar loads may
/// negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
public:
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     PatternBenefit benefit,
                                     bool allowMultipleUses)
      : OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Match phase.
    auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp>(use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    // Rewrite phase: construct scalar load.
    // Start from the transfer's own indices; the trailing ones are then offset
    // by the extract position, one per extracted dimension.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      // Index of the transfer operand that corresponds to the i-th extract
      // position (the extract positions map to the trailing transfer indices).
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;

      // Compute affine expression `newIndices[idx] + pos` where `pos` can be
      // either a constant or a value.
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        // Static position: fold `index + constant`.
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(),
            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      } else {
        // Dynamic position: fold `index + position` with two symbols.
        Value dynamicOffset = cast<Value>(pos);
        AffineExpr sym0, sym1;
        bindSymbols(rewriter.getContext(), sym0, sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      }

      // Update the corresponding index with the folded result, materializing
      // a constant op if folding produced an attribute.
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = arith::ConstantIndexOp::create(
            rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
      }
    }
    // Memref sources become a scalar memref.load; tensor sources become a
    // tensor.extract.
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }

private:
  // If true, the transfer_read may have any number of users as long as each
  // one is a vector.extract; if false, it must have exactly one use.
  bool allowMultipleUses;
};
1103
1104/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
1105/// to memref.store.
1106class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
1107 using Base::Base;
1108
1109 LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
1110 PatternRewriter &rewriter) const override {
1111 // Must be a scalar write.
1112 auto vecType = xferOp.getVectorType();
1113 if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
1114 return failure();
1115 // Mask not supported.
1116 if (xferOp.getMask())
1117 return failure();
1118 // Map not supported.
1119 if (!xferOp.getPermutationMap().isMinorIdentity())
1120 return failure();
1121 // Only float and integer element types are supported.
1122 Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
1123 xferOp.getVector());
1124 // Construct a scalar store.
1125 if (isa<MemRefType>(xferOp.getBase().getType())) {
1126 rewriter.replaceOpWithNewOp<memref::StoreOp>(
1127 xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
1128 } else {
1129 rewriter.replaceOpWithNewOp<tensor::InsertOp>(
1130 xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
1131 }
1132 return success();
1133 }
1134};
1135
1136} // namespace
1137
1138void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
1139 Operation *rootOp) {
1140 LDBG() << "=== Starting transferOpflowOpt on root operation: "
1141 << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
1142 TransferOptimization opt(rewriter, rootOp);
1143
1144 // Run store to load forwarding first since it can expose more dead store
1145 // opportunity.
1146 LDBG() << "Phase 1: Store-to-load forwarding";
1147 int readCount = 0;
1148 rootOp->walk([&](vector::TransferReadOp read) {
1149 if (isa<MemRefType>(read.getShapedType())) {
1150 LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
1151 opt.storeToLoadForwarding(read);
1152 }
1153 });
1154 LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
1155 opt.removeDeadOp();
1156
1157 LDBG() << "Phase 2: Dead store elimination";
1158 int writeCount = 0;
1159 rootOp->walk([&](vector::TransferWriteOp write) {
1160 if (isa<MemRefType>(write.getShapedType())) {
1161 LDBG() << "Processing transfer_write #" << ++writeCount << ": " << *write;
1162 opt.deadStoreOp(write);
1163 }
1164 });
1165 LDBG() << "Phase 2 complete. Removing dead operations from dead store "
1166 "elimination";
1167 opt.removeDeadOp();
1168 LDBG() << "=== transferOpflowOpt complete";
1169}
1170
1172 RewritePatternSet &patterns, PatternBenefit benefit,
1173 bool allowMultipleUses) {
1174 patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
1175 benefit, allowMultipleUses);
1176 patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
1177}
1178
1179void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
1180 RewritePatternSet &patterns, PatternBenefit benefit) {
1181 patterns
1182 .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
1183 patterns.getContext(), benefit);
1184}
1185
1186void mlir::vector::populateFlattenVectorTransferPatterns(
1187 RewritePatternSet &patterns, unsigned targetVectorBitwidth,
1188 PatternBenefit benefit) {
1189 patterns.add<FlattenContiguousRowMajorTransferReadPattern,
1190 FlattenContiguousRowMajorTransferWritePattern>(
1191 patterns.getContext(), targetVectorBitwidth, benefit);
1192 populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
1193}
return success()
static SmallVector< int64_t > getReducedShape(ArrayRef< OpFoldResult > mixedSizes)
Converts OpFoldResults to int64_t shape without unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, mlir::Location loc, Value input)
Creates a rank-reducing memref.subview op that drops unit dims from its input.
static SmallVector< Value > getCollapsedIndices(RewriterBase &rewriter, Location loc, ArrayRef< int64_t > shape, ValueRange indices, int64_t firstDimToCollapse)
Returns the new indices that collapses the inner dimensions starting from the firstDimToCollapse dime...
static int getReducedRank(ArrayRef< int64_t > shape)
Returns the number of dims that aren't unit dims.
static FailureOr< Value > maskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, MaskOp op)
Rewrites a mask op to drop non-scalable unit dimensions.
static bool isUnitDimMask(Value maskDimSize)
static VectorType trimNonScalableUnitDims(VectorType oldType)
Trims non-scalable one dimensions from oldType and returns the result type.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse)
Creates a memref.collapse_shape collapsing all inner dimensions of the input starting at firstDimToCo...
static Operation * findAncestorOpInRegion(Region *region, Operation *op)
Return the ancestor op in the region or nullptr if the region is not an ancestor of the op.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides)
Drops unit dimensions from the input MemRefType.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isReachable(Block *other, SmallPtrSet< Block *, 16 > &&except={})
Return "true" if there is a path from this block to the given block (according to the successors rela...
Definition Block.cpp:368
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getResults()
Definition Operation.h:441
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool postDominates(Operation *a, Operation *b) const
Return true if operation A postdominates operation B.
Definition Dominance.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b)
Checks if two (memref) values are the same or statically known to alias the same region of memory.
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType)
Return true if vectorType is a contiguous slice of memrefType, in the sense that it can be read/writt...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.