#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-transfer-opt"

using namespace mlir;
/// Return the ancestor of `op` that lies directly in `region`, or nullptr if
/// `region` does not enclose `op`.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  LDBG() << "  Finding ancestor of " << *op << " in region";
  while (op && op->getParentRegion() != region)
    op = op->getParentOp();
  if (op)
    LDBG() << "  -> Ancestor: " << *op;
  else
    LDBG() << "  -> Ancestor: nullptr";
  return op;
}
namespace {

/// Performs store-to-load forwarding and dead store elimination on
/// vector.transfer ops. Dominance and post-dominance information is computed
/// once for the root op and reused by both analyses.
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    LDBG() << "Removing " << opToErase.size() << " dead operations";
    for (Operation *op : opToErase) {
      LDBG() << " -> Erasing: " << *op;
      rewriter.eraseOp(op);
    }
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  /// Return true if there is a path from `start` to `dest`; both ops must be
  /// in the same region.
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  LDBG() << "  Checking reachability from " << *start << " to " << *dest;
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest)) {
    LDBG() << "  -> Start dominates dest, reachable";
    return true;
  }
  bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
  LDBG() << "  -> Block reachable: " << blockReachable;
  return blockReachable;
}
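
// Illustration (a hypothetical sketch, not from the original file): in the
// CFG below, isReachable(start, dest) is true because ^bb1 branches to ^bb2,
// while the reverse query is false since there is no back edge:
//
//   ^bb1:           // contains `start`
//     cf.br ^bb2
//   ^bb2:           // contains `dest`
//     ...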
/// Walk the users of the memref underlying `write` and decide whether `write`
/// is dead: it is dead if another transfer_write to the same view, covering at
/// least the same elements, post-dominates it and dominates every access that
/// could observe the stored data in between.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed, skip.
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user)) {
      LDBG() << " -> Memory effect free, skipping";
      continue;
    }
    if (user == write.getOperation()) {
      LDBG() << " -> Same as write operation, skipping";
      continue;
    }
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
      // Check whether this candidate fully overwrites the store.
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(nextWrite.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool sameValue = vector::checkSameValueWAW(nextWrite, write);
      bool postDominates = postDominators.postDominates(nextWrite, write);
      LDBG() << " -> Same view: " << sameView
             << ", Same value: " << sameValue
             << ", Post-dominates: " << postDominates;
      if (sameView && sameValue && postDominates) {
        LDBG() << " -> Valid overwrite candidate found";
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
          LDBG() << " -> New first overwrite candidate: " << *nextWrite;
          firstOverwriteCandidate = nextWrite;
        } else {
          LDBG() << " -> Keeping existing first overwrite candidate";
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        }
        continue;
      }
      LDBG() << " -> Not a valid overwrite candidate";
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      LDBG() << " -> Found vector transfer operation: " << *transferOp;
      // Accesses proven disjoint from the write cannot block the elimination.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(transferOp.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint access";
        continue;
      }
    }
    LDBG() << " -> Adding to blocking accesses: " << *user;
    blockingAccesses.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
         << " blocking accesses";
  if (firstOverwriteCandidate == nullptr) {
    LDBG() << "No overwrite candidate found, store is not dead";
    return;
  }
  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  LDBG() << "Write ancestor in top region: " << *writeAncestor;
  LDBG() << "Checking " << blockingAccesses.size()
         << " blocking accesses for reachability";
  for (Operation *access : blockingAccesses) {
    LDBG() << "Checking blocking access: " << *access;
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    if (accessAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }
    bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
    LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
    if (!isReachableFromWrite) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }
    bool overwriteDominatesAccess =
        dominators.dominates(firstOverwriteCandidate, accessAncestor);
    LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
    if (!overwriteDominatesAccess) {
      LDBG() << "Store may not be dead due to op: " << *accessAncestor;
      return;
    }
    LDBG() << " -> Access is dominated by overwrite, continuing";
  }
  LDBG() << "Found dead store: " << *write.getOperation()
         << " overwritten by: " << *firstOverwriteCandidate;
  opToErase.push_back(write.getOperation());
}
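
// Illustrative IR (a hypothetical sketch): the first transfer_write below is
// dead because the second one rewrites the exact same slice of %m before
// anyone reads it:
//
//   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
//   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
//
// deadStoreOp queues the first write in `opToErase`; a read of %m between the
// two writes would land in `blockingAccesses` and keep the store alive.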
/// Forward the stored value of a dominating transfer_write to a transfer_read
/// that reads back exactly the data that was written, provided no other write
/// can interpose between the two.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
         << *read.getOperation();
  if (read.hasOutOfBoundsDim()) {
    LDBG() << "Read has out-of-bounds dimensions, skipping";
    return;
  }
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed, skip.
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
      LDBG() << " -> Memory effect free or transfer_read, skipping";
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *write;
      // A write proven disjoint from the read can be ignored.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << " -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << " -> Skipping disjoint write";
        continue;
      }
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(read.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool dominates = dominators.dominates(write, read);
      bool sameValue = vector::checkSameValueRAW(write, read);
      LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
             << ", Same value: " << sameValue;
      if (sameView && dominates && sameValue) {
        LDBG() << " -> Valid forwarding candidate found";
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
          LDBG() << " -> New last write candidate: " << *write;
          lastwrite = write;
        } else {
          LDBG() << " -> Keeping existing last write candidate";
          assert(dominators.dominates(write, lastwrite));
        }
        continue;
      }
      LDBG() << " -> Not a valid forwarding candidate";
    }
    LDBG() << " -> Adding to blocking writes: " << *user;
    blockingWrites.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingWrites.size()
         << " blocking writes";
  if (lastwrite == nullptr) {
    LDBG() << "No last write candidate found, cannot forward";
    return;
  }
  LDBG() << "Last write candidate: " << *lastwrite;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");
  LDBG() << "Read ancestor in top region: " << *readAncestor;
  LDBG() << "Checking " << blockingWrites.size()
         << " blocking writes for post-dominance";
  for (Operation *write : blockingWrites) {
    LDBG() << "Checking blocking write: " << *write;
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor)
      LDBG() << " -> Write ancestor: " << *writeAncestor;
    else
      LDBG() << " -> Write ancestor: nullptr";
    if (writeAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
      continue;
    }
    bool isReachableToRead = isReachable(writeAncestor, readAncestor);
    LDBG() << " -> Is reachable to read: " << isReachableToRead;
    if (!isReachableToRead) {
      LDBG() << " -> Not reachable, skipping";
      continue;
    }
    bool lastWritePostDominates =
        postDominators.postDominates(lastwrite, write);
    LDBG() << " -> Last write post-dominates blocking write: "
           << lastWritePostDominates;
    if (!lastWritePostDominates) {
      LDBG() << "Fail to do write to read forwarding due to op: " << *write;
      return;
    }
    LDBG() << " -> Blocking write is post-dominated, continuing";
  }
  LDBG() << "Forward value from " << *lastwrite.getOperation()
         << " to: " << *read.getOperation();
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
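
// Illustrative IR (a hypothetical sketch): the transfer_read below can reuse
// %v directly because the dominating transfer_write stores exactly the data
// being read back (same base, same indices, same shape):
//
//   vector.transfer_write %v, %m[%c0] : vector<4xf32>, memref<4xf32>
//   %r = vector.transfer_read %m[%c0], %pad : memref<4xf32>, vector<4xf32>
//
// storeToLoadForwarding replaces all uses of %r with %v and queues the read
// for erasure.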
/// Converts OpFoldResult sizes to an int64_t shape with unit dims dropped;
/// dynamic sizes are kept as ShapedType::kDynamic.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}
/// Drops unit dimensions from the input MemRefType described by the given
/// offsets, sizes, and strides.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or returns the input unchanged if it has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
                                   sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
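
// For example (hypothetical values): getReducedRank({1, 4, 1, 8}) returns 2,
// since only the two non-unit dims (4 and 8) survive.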
/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type; scalable unit dims such as `[1]` are preserved.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
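
// For example (hypothetical types): vector<1x[1]x4xf32> becomes
// vector<[1]x4xf32>; the fixed unit dim is dropped, but the scalable unit dim
// `[1]` is kept because its runtime extent is a multiple of vscale, not
// necessarily 1.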
/// Drops non-scalable unit dims from a `vector.create_mask` or
/// `vector.constant_mask` op and rebuilds the mask with the reduced type.
/// Returns failure if there is no unit dim to drop.
template <typename MaskOp>
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  MaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
  SmallVector<ElemType> reduced;
  for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
    if (dim == 1 && !dimIsScalable) {
      // Dropped unit dims contribute no mask element in the reduced op.
      continue;
    }
    reduced.push_back(elem);
  }
  return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
}
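
// For example (a hypothetical instance): applied to
//   %m = vector.create_mask %c1, %d : vector<1x8xi1>
// the helper rebuilds the mask as
//   %m2 = vector.create_mask %d : vector<8xi1>
// dropping the mask element that corresponded to the unit dim.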
namespace {

/// Rewrites `vector.transfer_read` ops where the source memref has unit dims
/// to read from a rank-reducing `memref.subview` instead, dropping those unit
/// dims from both the source and the result vector.
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
      return failure();
    }
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // Check if the reduced vector shape matches the reduced source shape.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << " -> Non-zero indices found, skipping";
      return failure();
    }

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      LDBG() << " -> Processing mask operation";
      auto maskVectorType = cast<VectorType>(maskOp.getType());
      FailureOr<Value> rankReducedMaskOp = failure();
      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
        rankReducedMaskOp =
            createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      else if (auto constantMaskOp =
                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
        rankReducedMaskOp =
            createMaskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
      else
        return rewriter.notifyMatchFailure(
            transferReadOp,
            "unsupported mask op, only 'vector.create_mask' and "
            "'vector.constant_mask' are currently supported");
      if (succeeded(rankReducedMaskOp)) {
        maskOp = *rankReducedMaskOp;
        LDBG() << " -> Successfully reduced mask dimensions";
      } else if (maskVectorType.getRank() != reducedVectorType.getRank()) {
        return rewriter.notifyMatchFailure(
            transferReadOp, "Mask reduction required, but failed");
      }
    }

    LDBG() << " -> Creating rank-reduced subview and new transfer_read";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = vector::TransferReadOp::create(
        rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
        identityMap, transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;

    if (maskingOp) {
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
    LDBG() << " -> Pattern match successful, returning result";
    return shapeCast;
  }
};
/// Rewrites `vector.transfer_write` ops where the destination memref has unit
/// dims to write through a rank-reducing `memref.subview` instead.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
           << *transferWriteOp;
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
      return failure();
    }
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // Check if the reduced vector shape matches the reduced destination shape.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << " -> Non-zero indices found, skipping";
      return failure();
    }

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      LDBG() << " -> Processing mask operation";
      auto maskVectorType = cast<VectorType>(maskOp.getType());
      FailureOr<Value> rankReducedMask = failure();
      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
        rankReducedMask =
            createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      else if (auto constantMaskOp =
                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
        rankReducedMask =
            createMaskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
      else
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' and "
            "'vector.constant_mask' are currently supported");
      if (succeeded(rankReducedMask)) {
        maskOp = *rankReducedMask;
        LDBG() << " -> Successfully reduced mask dimensions";
      } else if (maskVectorType.getRank() != reducedVectorType.getRank()) {
        return rewriter.notifyMatchFailure(
            transferWriteOp, "Mask reduction required, but failed");
      }
    }

    LDBG() << " -> Creating rank-reduced subview and new transfer_write";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = vector::TransferWriteOp::create(
        rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
    LDBG() << " -> Created new transfer_write: " << *newXferWrite;

    if (maskingOp) {
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics()) {
      LDBG() << " -> Pattern match successful (tensor semantics), returning "
                "result";
      return newXferWrite->getResults()[0];
    }

    // With memref semantics there is no return value; an empty Value signals
    // success.
    LDBG() << " -> Pattern match successful (memref semantics)";
    return Value();
  }
};
/// Creates a memref.collapse_shape that collapses all dimensions of `input`
/// starting at `firstDimToCollapse` into a single trailing dimension.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
}
/// Returns the new indices for a transfer op whose source is collapsed with
/// `collapseInnerDims` starting at dim `firstDimToCollapse`: leading indices
/// are kept, trailing indices are linearized into a single offset.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  // If all the collapsed indices are zero, no extra arithmetic is needed.
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Otherwise linearize the collapsed indices, using the suffix product of
  // the collapsed dims as strides.
  OpFoldResult collapsedOffset =
      arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
        rewriter, loc, *getConstantIntValue(collapsedOffset)));
  }
  return indicesAfterCollapsing;
}
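
// Worked example (hypothetical values): collapsing the inner two dims of a
// memref<2x3x4xf32>, the suffix product of {3, 4} gives strides {4, 1}, so
// indices (%i, %j, %k) become (%i, %j * 4 + %k).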
/// Rewrites contiguous row-major vector.transfer_read ops by inserting a
/// memref.collapse_shape on the source so that the resulting read is 1-D.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
           << *transferReadOp;
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (!sourceType) {
      LDBG() << " -> Not a MemRefType, skipping";
      return failure();
    }
    if (vectorType.getRank() <= 1) {
      LDBG() << " -> Already 0D/1D, skipping";
      return failure();
    }
    if (!vectorType.getElementType().isSignlessIntOrFloat()) {
      LDBG() << " -> Not signless int or float, skipping";
      return failure();
    }
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
           << ", target: " << targetVectorBitwidth;
    if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
      LDBG() << " -> Trailing dim bitwidth >= target, skipping";
      return failure();
    }
    if (!vector::isContiguousSlice(sourceType, vectorType)) {
      LDBG() << " -> Not contiguous slice, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
      return failure();
    }
    if (transferReadOp.getMask()) {
      LDBG() << " -> Has mask, skipping";
      return failure();
    }

    // Determine the first memref dimension to collapse.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // A fully unit-dim vector would collapse to rank 0; keep one dim.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;

    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
    LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;

    // 1. Collapse the source memref.
    LDBG() << " -> Collapsing source memref";
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);
    LDBG() << " -> Collapsed source type: " << collapsedSourceType;

    // 2. Generate input args for the new transfer_read.
    // 2.1. New dim exprs + affine map.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2. New indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create the new vector.transfer_read reading from the collapsed
    // memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    LDBG() << " -> Creating flattened vector type: " << flatVectorType;
    vector::TransferReadOp flatRead = vector::TransferReadOp::create(
        rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    LDBG() << " -> Created flat transfer_read: " << *flatRead;

    // 4. Replace the old transfer_read with a shape_cast of the new one.
    LDBG() << " -> Replacing with shape cast";
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    LDBG() << " -> Pattern match successful";
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
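
// Illustrative rewrite (a hypothetical sketch):
//
//   %r = vector.transfer_read %m[%c0, %c0], %pad
//       : memref<4x4xi8>, vector<4x4xi8>
//
// becomes, roughly:
//
//   %cs = memref.collapse_shape %m [[0, 1]] : memref<4x4xi8> into memref<16xi8>
//   %flat = vector.transfer_read %cs[%c0], %pad : memref<16xi8>, vector<16xi8>
//   %r = vector.shape_cast %flat : vector<16xi8> to vector<4x4xi8>
//
// assuming the 4x4 slice is contiguous and the trailing-dim bitwidth (32 bits
// here) is below the target vector bitwidth.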
/// Rewrites contiguous row-major vector.transfer_write ops by inserting a
/// memref.collapse_shape on the destination so that the resulting write is
/// 1-D.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions: memref destination, rank > 1, signless element
    // type, narrow trailing dim, contiguous slice, minor identity map, no
    // out-of-bounds dims, no mask.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // Determine the first memref dimension to collapse.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // A fully unit-dim vector would collapse to rank 0; keep one dim.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;

    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;

    // 1. Collapse the destination memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for the new transfer_write.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create the new flattened transfer_write.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
        rewriter, loc, flatVector, collapsedSource, collapsedIndices,
        collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Erase the old transfer_write.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
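
// Illustrative rewrite (a hypothetical sketch): a transfer_write of
// vector<4x4xi8> into a contiguous memref<4x4xi8> becomes a shape_cast to
// vector<16xi8> plus a transfer_write into the collapsed memref<16xi8>, the
// store-side mirror of the read pattern above.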
/// Rewrite extract(transfer_read) to memref.load / tensor.extract. If
/// `allowMultipleUses` is false, the extract op must be the single user of the
/// transfer op; otherwise all users must be `vector.extract` ops, since
/// rewriting one vector load into multiple scalar loads may hurt performance.
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
public:
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     PatternBenefit benefit,
                                     bool allowMultipleUses)
      : OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check that the xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check that all users are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp>(use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    // Construct scalar load indices: compose each extract position into the
    // corresponding transfer index.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;

      // Compute affine expression `newIndices[idx] + pos`, where `pos` can be
      // either a constant or a value.
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(),
            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      } else {
        Value dynamicOffset = cast<Value>(pos);
        AffineExpr sym0, sym1;
        bindSymbols(rewriter.getContext(), sym0, sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      }

      // Update the corresponding index.
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = arith::ConstantIndexOp::create(
            rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }
    return success();
  }

private:
  bool allowMultipleUses;
};
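
// Illustrative rewrite (a hypothetical sketch):
//
//   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<4xf32>
//   %e = vector.extract %v[2] : f32 from vector<4xf32>
//
// becomes a scalar load at the composed index:
//
//   %e = memref.load %m[%i + 2] : memref<?xf32>
//
// with %i + 2 materialized via an affine apply or a folded constant.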
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) to
/// memref.store / tensor.insert.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write: all vector dims are unit dims.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar and construct a scalar store.
    Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
                                             xferOp.getVector());
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace
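
// Illustrative rewrite (a hypothetical sketch):
//
//   vector.transfer_write %v, %m[%i, %j] : vector<1x1xf32>, memref<?x?xf32>
//
// becomes:
//
//   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
//   memref.store %s, %m[%i, %j] : memref<?x?xf32>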
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  LDBG() << "=== Starting transferOpflowOpt on root operation: "
         << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  LDBG() << "Phase 1: Store-to-load forwarding";
  unsigned readCount = 0;
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType())) {
      LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
      opt.storeToLoadForwarding(read);
    }
  });
  LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
  opt.removeDeadOp();
  LDBG() << "Phase 2: Dead store elimination";
  unsigned writeCount = 0;
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType())) {
      LDBG() << "Processing transfer_write #" << ++writeCount << ": "
             << *write;
      opt.deadStoreOp(write);
    }
  });
  LDBG() << "Phase 2 complete. Removing dead operations from dead store "
            "elimination";
  opt.removeDeadOp();
  LDBG() << "=== transferOpflowOpt complete";
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
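
// Typical driver sketch (hypothetical pass body; applyPatternsGreedily is the
// usual greedy rewrite driver entry point):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128, /*benefit=*/1);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();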