MLIR 22.0.0git
VectorTransferOpTransforms.cpp
1//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with optimizing transfer_read and
10// transfer_write ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Utils/IndexingUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24#include "mlir/IR/Dominance.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/Interfaces/SideEffectInterfaces.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/StringRef.h"
29#include "llvm/Support/DebugLog.h"
30
31#define DEBUG_TYPE "vector-transfer-opt"
32
33using namespace mlir;
34
35/// Return the ancestor of `op` whose parent region is `region`, or nullptr if
36/// `region` is not an ancestor of `op`.
37static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
38 LDBG() << " Finding ancestor of " << *op << " in region";
39 for (; op != nullptr && op->getParentRegion() != region;
40 op = op->getParentOp())
41 ;
42 if (op) {
43 LDBG() << " -> Ancestor: " << *op;
44 } else {
45 LDBG() << " -> Ancestor: nullptr";
46 }
47 return op;
48}
49
50namespace {
51
52class TransferOptimization {
53public:
54 TransferOptimization(RewriterBase &rewriter, Operation *op)
55 : rewriter(rewriter), dominators(op), postDominators(op) {}
56 void deadStoreOp(vector::TransferWriteOp);
57 void storeToLoadForwarding(vector::TransferReadOp);
58 void removeDeadOp() {
59 LDBG() << "Removing " << opToErase.size() << " dead operations";
60 for (Operation *op : opToErase) {
61 LDBG() << " -> Erasing: " << *op;
62 rewriter.eraseOp(op);
63 }
64 opToErase.clear();
65 }
66
67private:
68 RewriterBase &rewriter;
69 bool isReachable(Operation *start, Operation *dest);
70 DominanceInfo dominators;
71 PostDominanceInfo postDominators;
72 std::vector<Operation *> opToErase;
73};
74
75} // namespace
76/// Return true if there is a path from the start operation to the dest
77/// operation, otherwise return false. Both operations must be in the same region.
78bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
79 LDBG() << " Checking reachability from " << *start << " to " << *dest;
80 assert(start->getParentRegion() == dest->getParentRegion() &&
81         "This function only works for ops in the same region");
82  // Simple case where the start op dominates the destination.
83 if (dominators.dominates(start, dest)) {
84 LDBG() << " -> Start dominates dest, reachable";
85 return true;
86 }
87 bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
88 LDBG() << " -> Block reachable: " << blockReachable;
89 return blockReachable;
90}
91
92/// For a transfer_write to fully overwrite another transfer_write, it must:
93/// 1. Access the same memref with the same indices and vector type.
94/// 2. Post-dominate the other transfer_write operation.
95/// If several candidates are available, one must be post-dominated by all the
96/// others since they all post-dominate the same transfer_write. We only
97/// consider the transfer_write post-dominated by all the other candidates, as
98/// it is the first transfer_write executed after the potentially dead
99/// transfer_write.
100/// If we find such an overwriting transfer_write, we know that the original
101/// transfer_write is dead if all reads reachable from the potentially dead
102/// transfer_write are dominated by the overwriting transfer_write.
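/// Illustrative example (editorial addition, not from the original source; the
/// SSA names and types are hypothetical): a store whose data is fully
/// overwritten before any read can be erased.
///
///   vector.transfer_write %v0, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   // ... no read of %buf on any path in between ...
///   vector.transfer_write %v1, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///
/// The first transfer_write is dead and gets queued in `opToErase`.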
103void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
104 LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
105 llvm::SmallVector<Operation *, 8> blockingAccesses;
106 Operation *firstOverwriteCandidate = nullptr;
107 Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
108 LDBG() << "Source memref (after skipping view-like ops): " << source;
109 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
110 source.getUsers().end());
111 LDBG() << "Found " << users.size() << " users of source memref";
112 llvm::SmallDenseSet<Operation *, 32> processed;
113 while (!users.empty()) {
114 Operation *user = users.pop_back_val();
115 LDBG() << "Processing user: " << *user;
116    // If the user has already been processed, skip it.
117 if (!processed.insert(user).second) {
118 LDBG() << " -> Already processed, skipping";
119 continue;
120 }
121 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
122 LDBG() << " -> View-like operation, following to destination";
123 Value viewDest = viewLike.getViewDest();
124 users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
125 continue;
126 }
127 if (isMemoryEffectFree(user)) {
128 LDBG() << " -> Memory effect free, skipping";
129 continue;
130 }
131 if (user == write.getOperation()) {
132 LDBG() << " -> Same as write operation, skipping";
133 continue;
134 }
135 if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
136 LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
137      // Check whether this candidate can overwrite the store.
138 bool sameView = memref::isSameViewOrTrivialAlias(
139 cast<MemrefValue>(nextWrite.getBase()),
140 cast<MemrefValue>(write.getBase()));
141 bool sameValue = checkSameValueWAW(nextWrite, write);
142 bool postDominates = postDominators.postDominates(nextWrite, write);
143 LDBG() << " -> Same view: " << sameView
144 << ", Same value: " << sameValue
145 << ", Post-dominates: " << postDominates;
146
147 if (sameView && sameValue && postDominates) {
148 LDBG() << " -> Valid overwrite candidate found";
149 if (firstOverwriteCandidate == nullptr ||
150 postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
151 LDBG() << " -> New first overwrite candidate: " << *nextWrite;
152 firstOverwriteCandidate = nextWrite;
153 } else {
154 LDBG() << " -> Keeping existing first overwrite candidate";
155 assert(
156 postDominators.postDominates(nextWrite, firstOverwriteCandidate));
157 }
158 continue;
159 }
160 LDBG() << " -> Not a valid overwrite candidate";
161 }
162 if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
163 LDBG() << " -> Found vector transfer operation: " << *transferOp;
164 // Don't need to consider disjoint accesses.
165 bool isDisjoint = vector::isDisjointTransferSet(
166 cast<VectorTransferOpInterface>(write.getOperation()),
167 cast<VectorTransferOpInterface>(transferOp.getOperation()),
168 /*testDynamicValueUsingBounds=*/true);
169 LDBG() << " -> Is disjoint: " << isDisjoint;
170 if (isDisjoint) {
171 LDBG() << " -> Skipping disjoint access";
172 continue;
173 }
174 }
175 LDBG() << " -> Adding to blocking accesses: " << *user;
176 blockingAccesses.push_back(user);
177 }
178 LDBG() << "Finished processing users. Found " << blockingAccesses.size()
179 << " blocking accesses";
180
181 if (firstOverwriteCandidate == nullptr) {
182 LDBG() << "No overwrite candidate found, store is not dead";
183 return;
184 }
185
186 LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
187 Region *topRegion = firstOverwriteCandidate->getParentRegion();
188 Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
189 assert(writeAncestor &&
190 "write op should be recursively part of the top region");
191 LDBG() << "Write ancestor in top region: " << *writeAncestor;
192
193 LDBG() << "Checking " << blockingAccesses.size()
194 << " blocking accesses for reachability";
195 for (Operation *access : blockingAccesses) {
196 LDBG() << "Checking blocking access: " << *access;
197 Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
198    // TODO: if the access and write have the same ancestor, we could recurse
199    // into the region to check more precisely whether the access is reachable.
200 if (accessAncestor == nullptr) {
201 LDBG() << " -> No ancestor in top region, skipping";
202 continue;
203 }
204
205 bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
206 LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
207 if (!isReachableFromWrite) {
208 LDBG() << " -> Not reachable, skipping";
209 continue;
210 }
211
212 bool overwriteDominatesAccess =
213 dominators.dominates(firstOverwriteCandidate, accessAncestor);
214 LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
215 if (!overwriteDominatesAccess) {
216 LDBG() << "Store may not be dead due to op: " << *accessAncestor;
217 return;
218 }
219 LDBG() << " -> Access is dominated by overwrite, continuing";
220 }
221 LDBG() << "Found dead store: " << *write.getOperation()
222 << " overwritten by: " << *firstOverwriteCandidate;
223 opToErase.push_back(write.getOperation());
224}
225
226/// A transfer_write candidate for store-to-load forwarding must:
227/// 1. Access the same memref with the same indices and vector type as the
228/// transfer_read.
229/// 2. Dominate the transfer_read operation.
230/// If several candidates are available, one must be dominated by all the others
231/// since they all dominate the same transfer_read. We only consider the
232/// transfer_write dominated by all the other candidates, as it is the last
233/// transfer_write executed before the transfer_read.
234/// If we find such a candidate, we can forward the value if all the other
235/// potentially aliasing ops that may reach the transfer_read are post-dominated
236/// by the transfer_write.
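/// Illustrative example (editorial addition, not from the original source; the
/// SSA names and types are hypothetical):
///
///   vector.transfer_write %v, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad
///     : memref<4x4xf32>, vector<4xf32>
///
/// With no aliasing write reaching the read in between, all uses of %r are
/// replaced by %v and the transfer_read is queued for erasure.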
237void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
238 LDBG() << "=== Starting storeToLoadForwarding analysis for: "
239 << *read.getOperation();
240 if (read.hasOutOfBoundsDim()) {
241 LDBG() << "Read has out-of-bounds dimensions, skipping";
242 return;
243 }
244 SmallVector<Operation *, 8> blockingWrites;
245 vector::TransferWriteOp lastwrite = nullptr;
246 Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
247 LDBG() << "Source memref (after skipping view-like ops): " << source;
248 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
249 source.getUsers().end());
250 LDBG() << "Found " << users.size() << " users of source memref";
251 llvm::SmallDenseSet<Operation *, 32> processed;
252 while (!users.empty()) {
253 Operation *user = users.pop_back_val();
254 LDBG() << "Processing user: " << *user;
255 // If the user has already been processed skip.
256 if (!processed.insert(user).second) {
257 LDBG() << " -> Already processed, skipping";
258 continue;
259 }
260 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
261 LDBG() << " -> View-like operation, following to destination";
262 Value viewDest = viewLike.getViewDest();
263 users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
264 continue;
265 }
266 if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
267 LDBG() << " -> Memory effect free or transfer_read, skipping";
268 continue;
269 }
270 if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
271 LDBG() << " -> Found transfer_write candidate: " << *write;
272      // If there is a write, but we can prove that it is disjoint, we can
273      // ignore it.
274 bool isDisjoint = vector::isDisjointTransferSet(
275 cast<VectorTransferOpInterface>(write.getOperation()),
276 cast<VectorTransferOpInterface>(read.getOperation()),
277 /*testDynamicValueUsingBounds=*/true);
278 LDBG() << " -> Is disjoint: " << isDisjoint;
279 if (isDisjoint) {
280 LDBG() << " -> Skipping disjoint write";
281 continue;
282 }
283
284 bool sameView =
285 memref::isSameViewOrTrivialAlias(cast<MemrefValue>(read.getBase()),
286 cast<MemrefValue>(write.getBase()));
287 bool dominates = dominators.dominates(write, read);
288 bool sameValue = checkSameValueRAW(write, read);
289 LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
290 << ", Same value: " << sameValue;
291
292 if (sameView && dominates && sameValue) {
293 LDBG() << " -> Valid forwarding candidate found";
294 if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
295 LDBG() << " -> New last write candidate: " << *write;
296 lastwrite = write;
297 } else {
298 LDBG() << " -> Keeping existing last write candidate";
299 assert(dominators.dominates(write, lastwrite));
300 }
301 continue;
302 }
303 LDBG() << " -> Not a valid forwarding candidate";
304 }
305 LDBG() << " -> Adding to blocking writes: " << *user;
306 blockingWrites.push_back(user);
307 }
308 LDBG() << "Finished processing users. Found " << blockingWrites.size()
309 << " blocking writes";
310
311 if (lastwrite == nullptr) {
312 LDBG() << "No last write candidate found, cannot forward";
313 return;
314 }
315
316 LDBG() << "Last write candidate: " << *lastwrite;
317 Region *topRegion = lastwrite->getParentRegion();
318 Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
319 assert(readAncestor &&
320 "read op should be recursively part of the top region");
321 LDBG() << "Read ancestor in top region: " << *readAncestor;
322
323 LDBG() << "Checking " << blockingWrites.size()
324 << " blocking writes for post-dominance";
325 for (Operation *write : blockingWrites) {
326 LDBG() << "Checking blocking write: " << *write;
327 Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
328 if (writeAncestor) {
329 LDBG() << " -> Write ancestor: " << *writeAncestor;
330 } else {
331 LDBG() << " -> Write ancestor: nullptr";
332 }
333
334    // TODO: if the store and read have the same ancestor, we could recurse
335    // into the region to check more precisely whether the read is reachable.
336 if (writeAncestor == nullptr) {
337 LDBG() << " -> No ancestor in top region, skipping";
338 continue;
339 }
340
341 bool isReachableToRead = isReachable(writeAncestor, readAncestor);
342 LDBG() << " -> Is reachable to read: " << isReachableToRead;
343 if (!isReachableToRead) {
344 LDBG() << " -> Not reachable, skipping";
345 continue;
346 }
347
348 bool lastWritePostDominates =
349 postDominators.postDominates(lastwrite, write);
350 LDBG() << " -> Last write post-dominates blocking write: "
351 << lastWritePostDominates;
352 if (!lastWritePostDominates) {
353 LDBG() << "Fail to do write to read forwarding due to op: " << *write;
354 return;
355 }
356 LDBG() << " -> Blocking write is post-dominated, continuing";
357 }
358
359 LDBG() << "Forward value from " << *lastwrite.getOperation()
360 << " to: " << *read.getOperation();
361 read.replaceAllUsesWith(lastwrite.getVector());
362 opToErase.push_back(read.getOperation());
363}
364
365/// Converts OpFoldResults to int64_t shape without unit dims.
366static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
367 SmallVector<int64_t> reducedShape;
368 for (const auto size : mixedSizes) {
369 if (llvm::dyn_cast_if_present<Value>(size)) {
370 reducedShape.push_back(ShapedType::kDynamic);
371 continue;
372 }
373
374 auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
375 if (value == 1)
376 continue;
377 reducedShape.push_back(value.getSExtValue());
378 }
379 return reducedShape;
380}
381
382/// Drops unit dimensions from the input MemRefType.
383static MemRefType dropUnitDims(MemRefType inputType,
384                               ArrayRef<OpFoldResult> offsets,
385                               ArrayRef<OpFoldResult> sizes,
386                               ArrayRef<OpFoldResult> strides) {
387 auto targetShape = getReducedShape(sizes);
388 MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
389 targetShape, inputType, offsets, sizes, strides);
390 return rankReducedType.canonicalizeStridedLayout();
391}
392
393/// Creates a rank-reducing memref.subview op that drops unit dims from its
394/// input, or just returns the input if it already has no unit dims.
395static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
396                                                 mlir::Location loc,
397 Value input) {
398 MemRefType inputType = cast<MemRefType>(input.getType());
399 SmallVector<OpFoldResult> offsets(inputType.getRank(),
400 rewriter.getIndexAttr(0));
401 SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
402 SmallVector<OpFoldResult> strides(inputType.getRank(),
403 rewriter.getIndexAttr(1));
404 MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
405
406 if (resultType.canonicalizeStridedLayout() ==
407 inputType.canonicalizeStridedLayout())
408 return input;
409 return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
410 sizes, strides);
411}
412
413/// Returns the number of dims that aren't unit dims.
414static int getReducedRank(ArrayRef<int64_t> shape) {
415 return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
416}
417
418/// Trims non-scalable unit dimensions from `oldType` and returns the result
419/// type.
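///
/// Illustrative example (editorial addition): vector<1x8x[1]x1xf32> becomes
/// vector<8x[1]xf32>; scalable unit dims such as `[1]` are kept.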
420static VectorType trimNonScalableUnitDims(VectorType oldType) {
421 SmallVector<int64_t> newShape;
422 SmallVector<bool> newScalableDims;
423 for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
424 if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
425 continue;
426 newShape.push_back(dimSize);
427 newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
428 }
429 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
430}
431
432// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
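// Illustrative example (editorial addition; %dim and %c1 are hypothetical
// values, with %c1 a constant 1):
//   %m = vector.create_mask %c1, %dim, %c1 : vector<1x8x1xi1>
// becomes
//   %m = vector.create_mask %dim : vector<8xi1>
// The rewrite fails if a dropped unit dim is not masked by a constant 1.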
433static FailureOr<Value>
434createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
435                                  vector::CreateMaskOp op) {
436 auto type = op.getType();
437 VectorType reducedType = trimNonScalableUnitDims(type);
438 if (reducedType.getRank() == type.getRank())
439 return failure();
440
441 SmallVector<Value> reducedOperands;
442 for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
443 type.getShape(), type.getScalableDims(), op.getOperands())) {
444 if (dim == 1 && !dimIsScalable) {
445 // If the mask for the unit dim is not a constant of 1, do nothing.
446 auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
447 if (!constant || (constant.value() != 1))
448 return failure();
449 continue;
450 }
451 reducedOperands.push_back(operand);
452 }
453 return vector::CreateMaskOp::create(rewriter, loc, reducedType,
454 reducedOperands)
455 .getResult();
456}
457
458namespace {
459
460/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
461/// inserting a memref.subview dropping those unit dims. The vector shapes are
462/// also reduced accordingly.
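///
/// Illustrative example (editorial addition; SSA names are hypothetical and
/// in_bounds attributes are omitted for brevity):
///
///   %r = vector.transfer_read %src[%c0, %c0, %c0], %pad
///     : memref<1x8x1xf32>, vector<1x8x1xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
///     : memref<1x8x1xf32> to memref<8xf32>
///   %r0 = vector.transfer_read %sv[%c0], %pad : memref<8xf32>, vector<8xf32>
///   %r = vector.shape_cast %r0 : vector<8xf32> to vector<1x8x1xf32>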
463class TransferReadDropUnitDimsPattern
464 : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
465 using MaskableOpRewritePattern::MaskableOpRewritePattern;
466
467 FailureOr<Value>
468 matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
469 vector::MaskingOpInterface maskingOp,
470 PatternRewriter &rewriter) const override {
471 LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
472 << *transferReadOp;
473 auto loc = transferReadOp.getLoc();
474 Value vector = transferReadOp.getVector();
475 VectorType vectorType = cast<VectorType>(vector.getType());
476 Value source = transferReadOp.getBase();
477 MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
478 // TODO: support tensor types.
479 if (!sourceType) {
480 LDBG() << " -> Not a MemRefType, skipping";
481 return failure();
482 }
483 // TODO: generalize this pattern, relax the requirements here.
484 if (transferReadOp.hasOutOfBoundsDim()) {
485 LDBG() << " -> Has out-of-bounds dimensions, skipping";
486 return failure();
487 }
488 if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
489 LDBG() << " -> Not minor identity permutation map, skipping";
490 return failure();
491 }
492 // Check if the source shape can be further reduced.
493 int reducedRank = getReducedRank(sourceType.getShape());
494 LDBG() << " -> Source rank: " << sourceType.getRank()
495 << ", Reduced rank: " << reducedRank;
496 if (reducedRank == sourceType.getRank()) {
497 LDBG() << " -> No unit dimensions to drop, skipping";
498 return failure();
499 }
500 // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
501 // out.
502 if (reducedRank == 0 && maskingOp) {
503 LDBG() << " -> 0-d vector with masking not supported, skipping";
504 return failure();
505 }
506 // Check if the reduced vector shape matches the reduced source shape.
507 // Otherwise, this case is not supported yet.
508 VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
509 LDBG() << " -> Vector type: " << vectorType
510 << ", Reduced vector type: " << reducedVectorType;
511 if (reducedRank != reducedVectorType.getRank()) {
512 LDBG() << " -> Reduced ranks don't match, skipping";
513 return failure();
514 }
515 if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
516 return getConstantIntValue(v) != static_cast<int64_t>(0);
517 })) {
518 LDBG() << " -> Non-zero indices found, skipping";
519 return failure();
520 }
521
522 Value maskOp = transferReadOp.getMask();
523 if (maskOp) {
524 LDBG() << " -> Processing mask operation";
525 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
526 if (!createMaskOp) {
527 LDBG()
528 << " -> Unsupported mask op, only 'vector.create_mask' supported";
529 return rewriter.notifyMatchFailure(
530 transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
531 "currently supported");
532 }
533 FailureOr<Value> rankReducedCreateMask =
534 createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
535 if (failed(rankReducedCreateMask)) {
536 LDBG() << " -> Failed to reduce mask dimensions";
537 return failure();
538 }
539 maskOp = *rankReducedCreateMask;
540 LDBG() << " -> Successfully reduced mask dimensions";
541 }
542
543 LDBG() << " -> Creating rank-reduced subview and new transfer_read";
544 Value reducedShapeSource =
545 rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
546 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
547 SmallVector<Value> zeros(reducedRank, c0);
548 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
549 SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
550 Operation *newTransferReadOp = vector::TransferReadOp::create(
551 rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
552 identityMap, transferReadOp.getPadding(), maskOp,
553 rewriter.getBoolArrayAttr(inBounds));
554 LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;
555
556 if (maskingOp) {
557 LDBG() << " -> Applying masking operation";
558 auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
559 loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
560 maskingOp.getMask());
561 newTransferReadOp = mlir::vector::maskOperation(
562 rewriter, newTransferReadOp, shapeCastMask);
563 }
564
565 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
566 loc, vectorType, newTransferReadOp->getResults()[0]);
567 LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
568 LDBG() << " -> Pattern match successful, returning result";
569
570 return shapeCast;
571 }
572};
573
574/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
575/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
576/// vector shapes are also reduced accordingly.
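///
/// Illustrative example (editorial addition; SSA names are hypothetical and
/// in_bounds attributes are omitted for brevity):
///
///   vector.transfer_write %v, %dest[%c0, %c0, %c0]
///     : vector<1x8x1xf32>, memref<1x8x1xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %dest[0, 0, 0] [1, 8, 1] [1, 1, 1]
///     : memref<1x8x1xf32> to memref<8xf32>
///   %v0 = vector.shape_cast %v : vector<1x8x1xf32> to vector<8xf32>
///   vector.transfer_write %v0, %sv[%c0] : vector<8xf32>, memref<8xf32>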
577class TransferWriteDropUnitDimsPattern
578 : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
579 using MaskableOpRewritePattern::MaskableOpRewritePattern;
580
581 FailureOr<Value>
582 matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
583 vector::MaskingOpInterface maskingOp,
584 PatternRewriter &rewriter) const override {
585 LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
586 << *transferWriteOp;
587 auto loc = transferWriteOp.getLoc();
588 Value vector = transferWriteOp.getVector();
589 VectorType vectorType = cast<VectorType>(vector.getType());
590 Value source = transferWriteOp.getBase();
591 MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
592 // TODO: support tensor type.
593 if (!sourceType) {
594 LDBG() << " -> Not a MemRefType, skipping";
595 return failure();
596 }
597 // TODO: generalize this pattern, relax the requirements here.
598 if (transferWriteOp.hasOutOfBoundsDim()) {
599 LDBG() << " -> Has out-of-bounds dimensions, skipping";
600 return failure();
601 }
602 if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
603 LDBG() << " -> Not minor identity permutation map, skipping";
604 return failure();
605 }
606 // Check if the destination shape can be further reduced.
607 int reducedRank = getReducedRank(sourceType.getShape());
608 LDBG() << " -> Source rank: " << sourceType.getRank()
609 << ", Reduced rank: " << reducedRank;
610 if (reducedRank == sourceType.getRank()) {
611 LDBG() << " -> No unit dimensions to drop, skipping";
612 return failure();
613 }
614 // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
615 // out.
616 if (reducedRank == 0 && maskingOp) {
617 LDBG() << " -> 0-d vector with masking not supported, skipping";
618 return failure();
619 }
620 // Check if the reduced vector shape matches the reduced destination shape.
621 // Otherwise, this case is not supported yet.
622 VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
623 LDBG() << " -> Vector type: " << vectorType
624 << ", Reduced vector type: " << reducedVectorType;
625 if (reducedRank != reducedVectorType.getRank()) {
626 LDBG() << " -> Reduced ranks don't match, skipping";
627 return failure();
628 }
629 if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
630 return getConstantIntValue(v) != static_cast<int64_t>(0);
631 })) {
632 LDBG() << " -> Non-zero indices found, skipping";
633 return failure();
634 }
635
636 Value maskOp = transferWriteOp.getMask();
637 if (maskOp) {
638 LDBG() << " -> Processing mask operation";
639 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
640 if (!createMaskOp) {
641 LDBG()
642 << " -> Unsupported mask op, only 'vector.create_mask' supported";
643 return rewriter.notifyMatchFailure(
644 transferWriteOp,
645 "unsupported mask op, only 'vector.create_mask' is "
646 "currently supported");
647 }
648 FailureOr<Value> rankReducedCreateMask =
649 createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
650 if (failed(rankReducedCreateMask)) {
651 LDBG() << " -> Failed to reduce mask dimensions";
652 return failure();
653 }
654 maskOp = *rankReducedCreateMask;
655 LDBG() << " -> Successfully reduced mask dimensions";
656 }
657 LDBG() << " -> Creating rank-reduced subview and new transfer_write";
658 Value reducedShapeSource =
659 rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
660 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
661 SmallVector<Value> zeros(reducedRank, c0);
662 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
663 SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
664 auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
665 loc, reducedVectorType, vector);
666 Operation *newXferWrite = vector::TransferWriteOp::create(
667 rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
668 identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
669 LDBG() << " -> Created new transfer_write: " << *newXferWrite;
670
671 if (maskingOp) {
672 LDBG() << " -> Applying masking operation";
673 auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
674 loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
675 maskingOp.getMask());
676 newXferWrite =
677 mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
678 }
679
680 if (transferWriteOp.hasPureTensorSemantics()) {
681 LDBG() << " -> Pattern match successful (tensor semantics), returning "
682 "result";
683 return newXferWrite->getResults()[0];
684 }
685
686 // With Memref semantics, there's no return value. Use empty value to signal
687 // success.
688 LDBG() << " -> Pattern match successful (memref semantics)";
689 return Value();
690 }
691};
692
693} // namespace
694
695/// Creates a memref.collapse_shape collapsing all inner dimensions of the
696/// input starting at `firstDimToCollapse`.
697static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
698                               Value input, int64_t firstDimToCollapse) {
699 ShapedType inputType = cast<ShapedType>(input.getType());
700 if (inputType.getRank() == 1)
701 return input;
702  SmallVector<ReassociationIndices> reassociation;
703  for (int64_t i = 0; i < firstDimToCollapse; ++i)
704 reassociation.push_back(ReassociationIndices{i});
705 ReassociationIndices collapsedIndices;
706 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
707 collapsedIndices.push_back(i);
708 reassociation.push_back(collapsedIndices);
709 return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
710}
711
712/// Returns the new indices that collapse the inner dimensions starting from
713/// the `firstDimToCollapse` dimension.
714static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
715                                              Location loc,
716                                              ArrayRef<int64_t> shape,
717                                              ValueRange indices,
718                                              int64_t firstDimToCollapse) {
719 assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
720
721 // If all the collapsed indices are zero then no extra logic is needed.
722 // Otherwise, a new offset/index has to be computed.
723 SmallVector<Value> indicesAfterCollapsing(
724 indices.begin(), indices.begin() + firstDimToCollapse);
725 SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
726 indices.end());
727 if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
728 indicesAfterCollapsing.push_back(indicesToCollapse[0]);
729 return indicesAfterCollapsing;
730 }
731
732 // Compute the remaining trailing index/offset required for reading from
733 // the collapsed memref:
734 //
735 // offset = 0
736 // for (i = firstDimToCollapse; i < outputRank; ++i)
737 // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
738 //
739 // For this example:
740 // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
741 // memref<1x43x2xi32>, vector<1x2xi32>
742 // which would be collapsed to:
743 // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
744 // memref<1x86xi32>, vector<2xi32>
745 // one would get the following offset:
746 // %offset = %arg0 * 43
747 OpFoldResult collapsedOffset =
748 arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
749
750 auto collapsedStrides = computeSuffixProduct(
751 ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
752
753 // Compute the collapsed offset.
754 auto &&[collapsedExpr, collapsedVals] =
755 computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
756  collapsedOffset = affine::makeComposedFoldedAffineApply(
757      rewriter, loc, collapsedExpr, collapsedVals);
758
759 if (auto value = dyn_cast<Value>(collapsedOffset)) {
760 indicesAfterCollapsing.push_back(value);
761 } else {
762 indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
763 rewriter, loc, *getConstantIntValue(collapsedOffset)));
764 }
765
766 return indicesAfterCollapsing;
767}
768
769namespace {
770/// Rewrites contiguous row-major vector.transfer_read ops by inserting
771/// memref.collapse_shape on the source so that the resulting
772/// vector.transfer_read has a 1D source. Requires the source shape to be
773/// already reduced i.e. without unit dims.
774///
775/// If `targetVectorBitwidth` is provided, the flattening will only happen if
776/// the trailing dimension of the vector read is smaller than the provided
777/// bitwidth.
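///
/// Illustrative example (editorial addition; SSA names are hypothetical and
/// in_bounds attributes are omitted for brevity):
///
///   %r = vector.transfer_read %src[%c0, %c0], %pad
///     : memref<4x8xf32>, vector<4x8xf32>
///
/// is rewritten to roughly:
///
///   %flat = memref.collapse_shape %src [[0, 1]]
///     : memref<4x8xf32> into memref<32xf32>
///   %r0 = vector.transfer_read %flat[%c0], %pad
///     : memref<32xf32>, vector<32xf32>
///   %r = vector.shape_cast %r0 : vector<32xf32> to vector<4x8xf32>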
778class FlattenContiguousRowMajorTransferReadPattern
779 : public OpRewritePattern<vector::TransferReadOp> {
780public:
781 FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
782 unsigned vectorBitwidth,
783 PatternBenefit benefit)
784 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
785 targetVectorBitwidth(vectorBitwidth) {}
786
787 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
788 PatternRewriter &rewriter) const override {
789 LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
790 << *transferReadOp;
791 auto loc = transferReadOp.getLoc();
792 Value vector = transferReadOp.getVector();
793 VectorType vectorType = cast<VectorType>(vector.getType());
794 auto source = transferReadOp.getBase();
795 MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
796
797 // 0. Check pre-conditions
798    // The contiguity check below is only valid for memrefs; bail on tensors.
799 if (!sourceType) {
800 LDBG() << " -> Not a MemRefType, skipping";
801 return failure();
802 }
803 // If this is already 0D/1D, there's nothing to do.
804 if (vectorType.getRank() <= 1) {
805 LDBG() << " -> Already 0D/1D, skipping";
806 return failure();
807 }
808 if (!vectorType.getElementType().isSignlessIntOrFloat()) {
809 LDBG() << " -> Not signless int or float, skipping";
810 return failure();
811 }
812 unsigned trailingVectorDimBitwidth =
813 vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
814 LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
815 << ", target: " << targetVectorBitwidth;
816 if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
817 LDBG() << " -> Trailing dim bitwidth >= target, skipping";
818 return failure();
819 }
820 if (!vector::isContiguousSlice(sourceType, vectorType)) {
821 LDBG() << " -> Not contiguous slice, skipping";
822 return failure();
823 }
824 // TODO: generalize this pattern, relax the requirements here.
825 if (transferReadOp.hasOutOfBoundsDim()) {
826 LDBG() << " -> Has out-of-bounds dimensions, skipping";
827 return failure();
828 }
829 if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
830 LDBG() << " -> Not minor identity permutation map, skipping";
831 return failure();
832 }
833 if (transferReadOp.getMask()) {
834 LDBG() << " -> Has mask, skipping";
835 return failure();
836 }
837
838 // Determine the first memref dimension to collapse - just enough so we can
839 // read a flattened vector.
840 int64_t firstDimToCollapse =
841 sourceType.getRank() -
842 vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
843 LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;
844
845 // 1. Collapse the source memref
846 LDBG() << " -> Collapsing source memref";
847 Value collapsedSource =
848 collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
849 MemRefType collapsedSourceType =
850 cast<MemRefType>(collapsedSource.getType());
851 int64_t collapsedRank = collapsedSourceType.getRank();
852 assert(collapsedRank == firstDimToCollapse + 1);
853 LDBG() << " -> Collapsed source type: " << collapsedSourceType;
854
855 // 2. Generate input args for a new vector.transfer_read that will read
856 // from the collapsed memref.
857 // 2.1. New dim exprs + affine map
858 SmallVector<AffineExpr, 1> dimExprs{
859 getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
860 auto collapsedMap =
861 AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
862
863 // 2.2 New indices
864 SmallVector<Value> collapsedIndices =
865 getCollapsedIndices(rewriter, loc, sourceType.getShape(),
866 transferReadOp.getIndices(), firstDimToCollapse);
867
868 // 3. Create new vector.transfer_read that reads from the collapsed memref
869 VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
870 vectorType.getElementType());
871 LDBG() << " -> Creating flattened vector type: " << flatVectorType;
872 vector::TransferReadOp flatRead = vector::TransferReadOp::create(
873 rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
874 transferReadOp.getPadding(), collapsedMap);
875 flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
876 LDBG() << " -> Created flat transfer_read: " << *flatRead;
877
878 // 4. Replace the old transfer_read with the new one reading from the
879 // collapsed shape
880 LDBG() << " -> Replacing with shape cast";
881 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
882 transferReadOp, cast<VectorType>(vector.getType()), flatRead);
883 LDBG() << " -> Pattern match successful";
884 return success();
885 }
886
887private:
888 // Minimum bitwidth that the trailing vector dimension should have after
889 // flattening.
890 unsigned targetVectorBitwidth;
891};
892
893/// Rewrites contiguous row-major vector.transfer_write ops by inserting
894/// memref.collapse_shape on the source so that the resulting
895/// vector.transfer_write has a 1D source. Requires the source shape to be
896/// already reduced i.e. without unit dims.
897///
898/// If `targetVectorBitwidth` is provided, the flattening will only happen if
899/// the trailing dimension of the vector write is smaller than the provided
900/// bitwidth.
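///
/// Illustrative example (editorial addition; SSA names are hypothetical and
/// in_bounds attributes are omitted for brevity):
///
///   vector.transfer_write %v, %dest[%c0, %c0]
///     : vector<4x8xf32>, memref<4x8xf32>
///
/// is rewritten to roughly:
///
///   %flat = memref.collapse_shape %dest [[0, 1]]
///     : memref<4x8xf32> into memref<32xf32>
///   %v0 = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %v0, %flat[%c0] : vector<32xf32>, memref<32xf32>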
901class FlattenContiguousRowMajorTransferWritePattern
902 : public OpRewritePattern<vector::TransferWriteOp> {
903public:
904 FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
905 unsigned vectorBitwidth,
906 PatternBenefit benefit)
907 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
908 targetVectorBitwidth(vectorBitwidth) {}
909
910 LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
911 PatternRewriter &rewriter) const override {
912 auto loc = transferWriteOp.getLoc();
913 Value vector = transferWriteOp.getVector();
914 VectorType vectorType = cast<VectorType>(vector.getType());
915 Value source = transferWriteOp.getBase();
916 MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
917
918 // 0. Check pre-conditions
919    // The contiguity check below is only valid for memrefs; bail on tensors.
920 if (!sourceType)
921 return failure();
922 // If this is already 0D/1D, there's nothing to do.
923 if (vectorType.getRank() <= 1)
924 // Already 0D/1D, nothing to do.
925 return failure();
926 if (!vectorType.getElementType().isSignlessIntOrFloat())
927 return failure();
928 unsigned trailingVectorDimBitwidth =
929 vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
930 if (trailingVectorDimBitwidth >= targetVectorBitwidth)
931 return failure();
932 if (!vector::isContiguousSlice(sourceType, vectorType))
933 return failure();
934 // TODO: generalize this pattern, relax the requirements here.
935 if (transferWriteOp.hasOutOfBoundsDim())
936 return failure();
937 if (!transferWriteOp.getPermutationMap().isMinorIdentity())
938 return failure();
939 if (transferWriteOp.getMask())
940 return failure();
941
942 // Determine the first memref dimension to collapse - just enough so we can
943    // write a flattened vector.
944 int64_t firstDimToCollapse =
945 sourceType.getRank() -
946 vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
947
948 // 1. Collapse the source memref
949 Value collapsedSource =
950 collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
951 MemRefType collapsedSourceType =
952 cast<MemRefType>(collapsedSource.getType());
953 int64_t collapsedRank = collapsedSourceType.getRank();
954 assert(collapsedRank == firstDimToCollapse + 1);
955
956    // 2. Generate input args for a new vector.transfer_write that will write
957    // to the collapsed memref.
958 // 2.1. New dim exprs + affine map
959 SmallVector<AffineExpr, 1> dimExprs{
960 getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
961 auto collapsedMap =
962 AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
963
964 // 2.2 New indices
965 SmallVector<Value> collapsedIndices =
966 getCollapsedIndices(rewriter, loc, sourceType.getShape(),
967 transferWriteOp.getIndices(), firstDimToCollapse);
968
969 // 3. Create new vector.transfer_write that writes to the collapsed memref
970 VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
971 vectorType.getElementType());
972 Value flatVector =
973 vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
974 vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
975 rewriter, loc, flatVector, collapsedSource, collapsedIndices,
976 collapsedMap);
977 flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
978
979    // 4. Erase the old transfer_write; the new one writing to the collapsed
980    // shape replaces it.
981 rewriter.eraseOp(transferWriteOp);
982 return success();
983 }
984
985private:
986 // Minimum bitwidth that the trailing vector dimension should have after
987 // flattening.
988 unsigned targetVectorBitwidth;
989};
990
991/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
992///
993/// All the users of the transfer op must be `vector.extract` ops. If
994/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
995/// users. Otherwise, rewrite only if the extract op is the single user of the
996/// transfer op. Rewriting a single vector load with multiple scalar loads may
997/// negatively affect performance.
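///
/// Illustrative example (editorial addition; SSA names are hypothetical and
/// the exact affine folding may differ):
///
///   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true]}
///     : memref<?x?xf32>, vector<4xf32>
///   %e = vector.extract %v[2] : f32 from vector<4xf32>
///
/// is rewritten to roughly:
///
///   %j2 = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%j]
///   %e = memref.load %src[%i, %j2] : memref<?x?xf32>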
998class RewriteScalarExtractOfTransferRead
999 : public OpRewritePattern<vector::ExtractOp> {
1000public:
1001 RewriteScalarExtractOfTransferRead(MLIRContext *context,
1002 PatternBenefit benefit,
1003 bool allowMultipleUses)
1004 : OpRewritePattern(context, benefit),
1005 allowMultipleUses(allowMultipleUses) {}
1006
1007 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1008 PatternRewriter &rewriter) const override {
1009 // Match phase.
1010 auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
1011 if (!xferOp)
1012 return failure();
1013 // Check that we are extracting a scalar and not a sub-vector.
1014 if (isa<VectorType>(extractOp.getResult().getType()))
1015 return failure();
1016 // If multiple uses are not allowed, check if xfer has a single use.
1017 if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
1018 return failure();
1019 // If multiple uses are allowed, check if all the xfer uses are extract ops.
1020 if (allowMultipleUses &&
1021 !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
1022 return isa<vector::ExtractOp>(use.getOwner());
1023 }))
1024 return failure();
1025 // Mask not supported.
1026 if (xferOp.getMask())
1027 return failure();
1028 // Map not supported.
1029 if (!xferOp.getPermutationMap().isMinorIdentity())
1030 return failure();
1031 // Cannot rewrite if the indices may be out of bounds.
1032 if (xferOp.hasOutOfBoundsDim())
1033 return failure();
1034
1035 // Rewrite phase: construct scalar load.
1036 SmallVector<Value> newIndices(xferOp.getIndices().begin(),
1037 xferOp.getIndices().end());
1038 for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
1039 int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
1040
1041 // Compute affine expression `newIndices[idx] + pos` where `pos` can be
1042 // either a constant or a value.
1043 OpFoldResult composedIdx;
1044 if (auto attr = dyn_cast<Attribute>(pos)) {
1045 int64_t offset = cast<IntegerAttr>(attr).getInt();
1046        composedIdx = affine::makeComposedFoldedAffineApply(
1047            rewriter, extractOp.getLoc(),
1048 rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
1049 } else {
1050 Value dynamicOffset = cast<Value>(pos);
1051 AffineExpr sym0, sym1;
1052 bindSymbols(rewriter.getContext(), sym0, sym1);
1053        composedIdx = affine::makeComposedFoldedAffineApply(
1054            rewriter, extractOp.getLoc(), sym0 + sym1,
1055 {newIndices[idx], dynamicOffset});
1056 }
1057
1058 // Update the corresponding index with the folded result.
1059 if (auto value = dyn_cast<Value>(composedIdx)) {
1060 newIndices[idx] = value;
1061 } else {
1062 newIndices[idx] = arith::ConstantIndexOp::create(
1063 rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
1064 }
1065 }
1066 if (isa<MemRefType>(xferOp.getBase().getType())) {
1067 rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
1068 newIndices);
1069 } else {
1070 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1071 extractOp, xferOp.getBase(), newIndices);
1072 }
1073
1074 return success();
1075 }
1076
1077private:
1078 bool allowMultipleUses;
1079};
1080
1081/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
1082/// to memref.store.
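///
/// Illustrative example (editorial addition; SSA names are hypothetical; for
/// tensor destinations, tensor.insert is used instead of memref.store):
///
///   vector.transfer_write %v, %buf[%i, %j] : vector<1x1xf32>, memref<?x?xf32>
///
/// is rewritten to roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %buf[%i, %j] : memref<?x?xf32>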
1083class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
1084 using Base::Base;
1085
1086 LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
1087 PatternRewriter &rewriter) const override {
1088 // Must be a scalar write.
1089 auto vecType = xferOp.getVectorType();
1090 if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
1091 return failure();
1092 // Mask not supported.
1093 if (xferOp.getMask())
1094 return failure();
1095 // Map not supported.
1096 if (!xferOp.getPermutationMap().isMinorIdentity())
1097 return failure();
1098 // Only float and integer element types are supported.
1099 Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
1100 xferOp.getVector());
1101 // Construct a scalar store.
1102 if (isa<MemRefType>(xferOp.getBase().getType())) {
1103 rewriter.replaceOpWithNewOp<memref::StoreOp>(
1104 xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
1105 } else {
1106 rewriter.replaceOpWithNewOp<tensor::InsertOp>(
1107 xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
1108 }
1109 return success();
1110 }
1111};
1112
1113} // namespace
1114
1115void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
1116 Operation *rootOp) {
1117 LDBG() << "=== Starting transferOpflowOpt on root operation: "
1118 << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
1119 TransferOptimization opt(rewriter, rootOp);
1120
1121 // Run store to load forwarding first since it can expose more dead store
1122  // opportunities.
1123 LDBG() << "Phase 1: Store-to-load forwarding";
1124 int readCount = 0;
1125 rootOp->walk([&](vector::TransferReadOp read) {
1126 if (isa<MemRefType>(read.getShapedType())) {
1127 LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
1128 opt.storeToLoadForwarding(read);
1129 }
1130 });
1131 LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
1132 opt.removeDeadOp();
1133
1134 LDBG() << "Phase 2: Dead store elimination";
1135 int writeCount = 0;
1136 rootOp->walk([&](vector::TransferWriteOp write) {
1137 if (isa<MemRefType>(write.getShapedType())) {
1138 LDBG() << "Processing transfer_write #" << ++writeCount << ": " << *write;
1139 opt.deadStoreOp(write);
1140 }
1141 });
1142 LDBG() << "Phase 2 complete. Removing dead operations from dead store "
1143 "elimination";
1144 opt.removeDeadOp();
1145 LDBG() << "=== transferOpflowOpt complete";
1146}
1147
1148void mlir::vector::populateScalarVectorTransferLoweringPatterns(
1149    RewritePatternSet &patterns, PatternBenefit benefit,
1150    bool allowMultipleUses) {
1151 patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
1152 benefit, allowMultipleUses);
1153 patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
1154}
1155
1156void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
1157    RewritePatternSet &patterns, PatternBenefit benefit) {
1158  patterns
1159 .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
1160 patterns.getContext(), benefit);
1161}
1162
1163void mlir::vector::populateFlattenVectorTransferPatterns(
1164 RewritePatternSet &patterns, unsigned targetVectorBitwidth,
1165 PatternBenefit benefit) {
1166 patterns.add<FlattenContiguousRowMajorTransferReadPattern,
1167 FlattenContiguousRowMajorTransferWritePattern>(
1168 patterns.getContext(), targetVectorBitwidth, benefit);
1169 populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
1170}
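// Minimal usage sketch (editorial addition; assumes the standard greedy
// rewrite driver from "mlir/Transforms/GreedyPatternRewriteDriver.h" and
// default pattern benefits; not part of this file's API):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128);
//   (void)applyPatternsGreedily(rootOp, std::move(patterns));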