#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "vector-transfer-opt"

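/// Walk up the parent chain of `op` and return its ancestor that lives
/// directly in `region`, or nullptr if `op` is not nested under the region.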
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  LDBG() << " Finding ancestor of " << *op << " in region";
  while (op != nullptr && op->getParentRegion() != region)
    op = op->getParentOp();
  if (op)
    LDBG() << " -> Ancestor: " << *op;
  else
    LDBG() << " -> Ancestor: nullptr";
  return op;
}

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    LDBG() << "Removing " << opToErase.size() << " dead operations";
    for (Operation *op : opToErase) {
      LDBG() << " -> Erasing: " << *op;
      rewriter.eraseOp(op);
    }
    opToErase.clear();
  }
private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

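/// Return true if there may be a path from `start` to `dest`, where both ops
/// live in the same region: either `start` dominates `dest`, or `dest`'s
/// block is reachable from `start`'s block.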
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  LDBG() << " Checking reachability from " << *start << " to " << *dest;
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest)) {
    LDBG() << " -> Start dominates dest, reachable";
    return true;
  }
  bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
  LDBG() << " -> Block reachable: " << blockReachable;
  return blockReachable;
}

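/// Attempt to prove that `write` is dead: another transfer_write to the same
/// view, writing a value that covers the same region, post-dominates it, and
/// every other access that may observe the written data is dominated by that
/// overwrite. Dead writes are queued in `opToErase`.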
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user)) {
      LDBG() << " -> Memory effect free, skipping";
      continue;
    }
    if (user == write.getOperation()) {
      LDBG() << " -> Same as write operation, skipping";
      continue;
    }
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(nextWrite.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool sameValue = vector::checkSameValueWAW(nextWrite, write);
      bool postDominates = postDominators.postDominates(nextWrite, write);
      LDBG() << " -> Same view: " << sameView
             << ", Same value: " << sameValue
             << ", Post-dominates: " << postDominates;
      if (sameView && sameValue && postDominates) {
        LDBG() << " -> Valid overwrite candidate found";
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
          LDBG() << " -> New first overwrite candidate: " << *nextWrite;
          firstOverwriteCandidate = nextWrite;
        } else {
          LDBG() << " -> Keeping existing first overwrite candidate";
          assert(postDominators.postDominates(nextWrite,
                                              firstOverwriteCandidate));
        }
        continue;
      }
      LDBG() << " -> Not a valid overwrite candidate";
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      LDBG() << " -> Found vector transfer operation: " << *transferOp;
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(transferOp.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint access";
        continue;
      }
    }
    LDBG() << " -> Adding to blocking accesses: " << *user;
    blockingAccesses.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
         << " blocking accesses";
  if (firstOverwriteCandidate == nullptr) {
    LDBG() << "No overwrite candidate found, store is not dead";
    return;
  }
  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor =
      findAncestorOpInRegion(topRegion, write.getOperation());
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  LDBG() << "Write ancestor in top region: " << *writeAncestor;
  LDBG() << "Checking " << blockingAccesses.size()
         << " blocking accesses for reachability";
  for (Operation *access : blockingAccesses) {
    LDBG() << "Checking blocking access: " << *access;
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    if (accessAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }
    bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
    LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
    if (!isReachableFromWrite) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }
    bool overwriteDominatesAccess =
        dominators.dominates(firstOverwriteCandidate, accessAncestor);
    LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
    if (!overwriteDominatesAccess) {
      LDBG() << "Store may not be dead due to op: " << *accessAncestor;
      return;
    }
    LDBG() << " -> Access is dominated by overwrite, continuing";
  }
  LDBG() << "Found dead store: " << *write.getOperation()
         << " overwritten by: " << *firstOverwriteCandidate;
  opToErase.push_back(write.getOperation());
}

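/// Attempt to forward the value stored by a dominating transfer_write to
/// `read`, when the write fully defines the data being read and no
/// conflicting write can occur in between, e.g. (illustrative):
///
///   vector.transfer_write %v, %m[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4xf32>
///   %r = vector.transfer_read %m[%c0], %pad {in_bounds = [true]}
///       : memref<4xf32>, vector<4xf32>
///
/// Here %r can be replaced by %v directly, with the transfer_read queued for
/// erasure.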
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
         << *read.getOperation();
  if (read.hasOutOfBoundsDim()) {
    LDBG() << "Read has out-of-bounds dimensions, skipping";
    return;
  }
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
      LDBG() << " -> Memory effect free or transfer_read, skipping";
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *write;
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint write";
        continue;
      }
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(read.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool dominates = dominators.dominates(write, read);
      bool sameValue = vector::checkSameValueRAW(write, read);
      LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
             << ", Same value: " << sameValue;
      if (sameView && dominates && sameValue) {
        LDBG() << " -> Valid forwarding candidate found";
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
          LDBG() << " -> New last write candidate: " << *write;
          lastwrite = write;
        } else {
          LDBG() << " -> Keeping existing last write candidate";
          assert(dominators.dominates(write, lastwrite));
        }
        continue;
      }
      LDBG() << " -> Not a valid forwarding candidate";
    }
    LDBG() << " -> Adding to blocking writes: " << *user;
    blockingWrites.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingWrites.size()
         << " blocking writes";
  if (lastwrite == nullptr) {
    LDBG() << "No last write candidate found, cannot forward";
    return;
  }
  LDBG() << "Last write candidate: " << *lastwrite;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor =
      findAncestorOpInRegion(topRegion, read.getOperation());
  assert(readAncestor &&
         "read op should be recursively part of the top region");
  LDBG() << "Read ancestor in top region: " << *readAncestor;
  LDBG() << "Checking " << blockingWrites.size()
         << " blocking writes for post-dominance";
  for (Operation *write : blockingWrites) {
    LDBG() << "Checking blocking write: " << *write;
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor)
      LDBG() << " -> Write ancestor: " << *writeAncestor;
    else
      LDBG() << " -> Write ancestor: nullptr";
    if (writeAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }
    bool isReachableToRead = isReachable(writeAncestor, readAncestor);
    LDBG() << " -> Is reachable to read: " << isReachableToRead;
    if (!isReachableToRead) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }
    bool lastWritePostDominates =
        postDominators.postDominates(lastwrite, write);
    LDBG() << " -> Last write post-dominates blocking write: "
           << lastWritePostDominates;
    if (!lastWritePostDominates) {
      LDBG() << "Fail to do write to read forwarding due to op: " << *write;
      return;
    }
    LDBG() << " -> Blocking write is post-dominated, continuing";
  }
  LDBG() << "Forward value from " << *lastwrite.getOperation()
         << " to: " << *read.getOperation();
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

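/// Returns the static shape corresponding to `mixedSizes` with unit dims
/// dropped; dynamic sizes become ShapedType::kDynamic.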
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

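/// Computes the memref type obtained by dropping the unit dims of `inputType`
/// via a rank-reducing subview with the given offsets/sizes/strides, with the
/// strided layout canonicalized.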
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}

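/// Creates a rank-reducing memref.subview that drops the unit dims of
/// `input`, or returns `input` unchanged when there is nothing to drop.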
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
                                   sizes, strides);
}

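/// Returns the number of dims of `shape` that are not unit dims.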
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

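/// Trims non-scalable unit dims from `oldType`; scalable unit dims (e.g. the
/// [1] in vector<4x[1]xf32>) are kept.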
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

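/// Rebuilds a vector.create_mask without its non-scalable unit dims. Fails if
/// there is nothing to drop, or if the mask operand of a dropped dim is not
/// the constant 1 (dropping it would change semantics).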
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // The mask operand of a dropped unit dim must be a constant 1.
      auto constant = getConstantIntValue(operand);
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return vector::CreateMaskOp::create(rewriter, loc, reducedType,
                                      reducedOperands)
      .getResult();
}

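/// Rewrites vector.transfer_read ops whose memref source has unit dims: the
/// source is replaced by a rank-reduced subview, the read is performed at the
/// reduced rank, and the result is shape-cast back to the original vector
/// type. Masks are rank-reduced alongside when they come from
/// vector.create_mask.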
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
      return failure();
    }
    // A 0-d vector produced by the rank reduction cannot be masked.
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // The reduced vector rank must match the reduced source rank.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << " -> Non-zero indices found, skipping";
      return failure();
    }
    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      LDBG() << " -> Processing mask operation";
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp) {
        LDBG()
            << " -> Unsupported mask op, only 'vector.create_mask' supported";
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      }
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask)) {
        LDBG() << " -> Failed to reduce mask dimensions";
        return failure();
      }
      maskOp = *rankReducedCreateMask;
      LDBG() << " -> Successfully reduced mask dimensions";
    }
    LDBG() << " -> Creating rank-reduced subview and new transfer_read";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = vector::TransferReadOp::create(
        rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
        identityMap, transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;

    if (maskingOp) {
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
    LDBG() << " -> Pattern match successful, returning result";
    return shapeCast;
  }
};

class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
           << *transferWriteOp;
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    if (transferWriteOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
      return failure();
    }
    // A 0-d vector produced by the rank reduction cannot be masked.
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // The reduced vector rank must match the reduced destination rank.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << " -> Non-zero indices found, skipping";
      return failure();
    }
    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      LDBG() << " -> Processing mask operation";
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp) {
        LDBG()
            << " -> Unsupported mask op, only 'vector.create_mask' supported";
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      }
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask)) {
        LDBG() << " -> Failed to reduce mask dimensions";
        return failure();
      }
      maskOp = *rankReducedCreateMask;
      LDBG() << " -> Successfully reduced mask dimensions";
    }
    LDBG() << " -> Creating rank-reduced subview and new transfer_write";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = vector::TransferWriteOp::create(
        rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
    LDBG() << " -> Created new transfer_write: " << *newXferWrite;

    if (maskingOp) {
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics()) {
      LDBG() << " -> Pattern match successful (tensor semantics), returning "
                "result";
      return newXferWrite->getResults()[0];
    }
    // With memref semantics there is no result; return an empty Value to
    // signal success.
    LDBG() << " -> Pattern match successful (memref semantics)";
    return Value();
  }
};

static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
}

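/// Returns transfer op indices adjusted for a source whose dims from
/// `firstDimToCollapse` onwards were collapsed: leading indices are preserved
/// and the collapsed ones are linearized into one trailing index.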
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Linearize the collapsed indices, using the suffix product of the
  // collapsed dim sizes as strides.
  OpFoldResult collapsedOffset = rewriter.getIndexAttr(0);
  SmallVector<int64_t> collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
        rewriter, loc, *getConstantIntValue(collapsedOffset)));
  }
  return indicesAfterCollapsing;
}

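/// Rewrites contiguous row-major vector.transfer_read ops by collapsing the
/// trailing dims of the source with memref.collapse_shape so the resulting
/// transfer_read is 1-D, then shape-casting back, e.g. (illustrative):
///
///   %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// becomes a vector<32xf32> read from the collapsed memref<32xf32> followed
/// by a vector.shape_cast to vector<4x8xf32>. Only fires when the trailing
/// vector dim is narrower than `targetVectorBitwidth`.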
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    if (vectorType.getRank() <= 1) {
      LDBG() << " -> Already 0D/1D, skipping";
      return failure();
    }
    if (!vectorType.getElementType().isSignlessIntOrFloat()) {
      LDBG() << " -> Not signless int or float, skipping";
      return failure();
    }
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
           << ", target: " << targetVectorBitwidth;
    if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
      LDBG() << " -> Trailing dim bitwidth >= target, skipping";
      return failure();
    }
    if (!vector::isContiguousSlice(sourceType, vectorType)) {
      LDBG() << " -> Not contiguous slice, skipping";
      return failure();
    }
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    if (transferReadOp.getMask()) {
      LDBG() << " -> Has mask, skipping";
      return failure();
    }
    // Leading unit dims of the vector do not participate in the collapse.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // If the vector is all unit dims, collapse into a 1-D read.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
    LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;

    // 1. Collapse the source memref.
    LDBG() << " -> Collapsing source memref";
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Build the permutation map for the new transfer_read.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 3. Collapse the indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 4. Create the new flattened transfer_read and shape_cast back.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    LDBG() << " -> Creating flattened vector type: " << flatVectorType;
    vector::TransferReadOp flatRead = vector::TransferReadOp::create(
        rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    LDBG() << " -> Created flat transfer_read: " << *flatRead;

    LDBG() << " -> Replacing with shape cast";
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    LDBG() << " -> Pattern match successful";
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

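/// Rewrites contiguous row-major vector.transfer_write ops analogously to the
/// read pattern above: the source memref is collapsed and the written vector
/// is shape-cast to 1-D first.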
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions (mirrors the read pattern above).
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // Leading unit dims of the vector do not participate in the collapse.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // If the vector is all unit dims, collapse into a 1-D write.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Build the permutation map for the new transfer_write.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 3. Collapse the indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 4. Shape-cast the vector to 1-D and create the flattened
    //    transfer_write.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
        rewriter, loc, flatVector, collapsedSource, collapsedIndices,
        collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 5. Erase the original transfer_write; the flattened one replaces it.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

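/// Rewrites a scalar vector.extract of a vector.transfer_read into a direct
/// memref.load (or tensor.extract), composing the extract position into the
/// transfer indices.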
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
public:
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     PatternBenefit benefit,
                                     bool allowMultipleUses)
      : OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if the transfer has a single
    // use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check that all uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp>(use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    // Compose the extract position into the transfer indices.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      // Compute the affine expression `newIndices[idx] + pos`, where `pos`
      // can be either a constant or a value.
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(),
            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      } else {
        Value dynamicOffset = cast<Value>(pos);
        AffineExpr sym0, sym1;
        bindSymbols(rewriter.getContext(), sym0, sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      }
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = arith::ConstantIndexOp::create(
            rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }
    return success();
  }

private:
  bool allowMultipleUses;
};

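/// Rewrites a vector.transfer_write of an all-unit-dim vector (e.g.
/// vector<1x1xf32>) into a memref.store (or tensor.insert) of the extracted
/// scalar.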
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar and store it directly.
    Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
                                             xferOp.getVector());
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};

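/// Entry point of the transfer op flow optimization: run store-to-load
/// forwarding first (it can expose more dead stores), then dead store
/// elimination, over all transfer ops on memrefs nested under `rootOp`.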
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  LDBG() << "=== Starting transferOpflowOpt on root operation: "
         << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  LDBG() << "Phase 1: Store-to-load forwarding";
  int readCount = 0;
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType())) {
      LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
      opt.storeToLoadForwarding(read);
    }
  });
  LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
  opt.removeDeadOp();
  LDBG() << "Phase 2: Dead store elimination";
  int writeCount = 0;
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType())) {
      LDBG() << "Processing transfer_write #" << ++writeCount << ": " << *write;
      opt.deadStoreOp(write);
    }
  });
  LDBG() << "Phase 2 complete. Removing dead operations from dead store "
            "elimination";
  opt.removeDeadOp();
  LDBG() << "=== transferOpflowOpt complete";
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}