#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "vector-transfer-opt"
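/// Return the ancestor of `op` that lives directly in `region`, or nullptr
/// if `region` does not enclose `op`.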
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  LDBG() << "  Finding ancestor of " << *op << " in region";
  while (op != nullptr && op->getParentRegion() != region)
    op = op->getParentOp();
  if (op)
    LDBG() << "  -> Ancestor: " << *op;
  else
    LDBG() << "  -> Ancestor: nullptr";
  return op;
}
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    LDBG() << "Removing " << opToErase.size() << " dead operations";
    for (Operation *op : opToErase) {
      LDBG() << "  -> Erasing: " << *op;
      rewriter.eraseOp(op);
    }
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};
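/// Return true if there may be an execution path from `start` to `dest`:
/// either `start` dominates `dest`, or `dest`'s block is reachable from
/// `start`'s block in the CFG. Both ops must live in the same region.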
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  LDBG() << "  Checking reachability from " << *start << " to " << *dest;
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest)) {
    LDBG() << "  -> Start dominates dest, reachable";
    return true;
  }
  bool blockReachable = start->getBlock()->isReachable(dest->getBlock());
  LDBG() << "  -> Block reachable: " << blockReachable;
  return blockReachable;
}
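/// Attempt to prove that `write` is dead: look for a transfer_write that
/// post-dominates it and overwrites the same view with the same value
/// extent, with no potentially-conflicting access reachable in between.
/// Dead writes are queued in `opToErase`.
///
/// Illustrative IR (SSA names invented for the example); the first write is
/// dead because the second fully overwrites it before any read:
///
///   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4xf32>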
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second) {
      LDBG() << "  -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << "  -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user)) {
      LDBG() << "  -> Memory effect free, skipping";
      continue;
    }
    if (user == write.getOperation()) {
      LDBG() << "  -> Same as write operation, skipping";
      continue;
    }
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << "  -> Found transfer_write candidate: " << *nextWrite;
      // Check whether the candidate fully overwrites the store.
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(nextWrite.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool sameValue = checkSameValueWAW(nextWrite, write);
      bool postDominates = postDominators.postDominates(nextWrite, write);
      LDBG() << "  -> Same view: " << sameView
             << ", Same value: " << sameValue
             << ", Post-dominates: " << postDominates;
      if (sameView && sameValue && postDominates) {
        LDBG() << "  -> Valid overwrite candidate found";
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
          LDBG() << "  -> New first overwrite candidate: " << *nextWrite;
          firstOverwriteCandidate = nextWrite;
        } else {
          LDBG() << "  -> Keeping existing first overwrite candidate";
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        }
        continue;
      }
      LDBG() << "  -> Not a valid overwrite candidate";
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      LDBG() << "  -> Found vector transfer operation: " << *transferOp;
      // Accesses proven disjoint from the write can never block elimination.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(transferOp.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << "  -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << "  -> Skipping disjoint access";
        continue;
      }
    }
    LDBG() << "  -> Adding to blocking accesses: " << *user;
    blockingAccesses.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
         << " blocking accesses";
  if (firstOverwriteCandidate == nullptr) {
    LDBG() << "No overwrite candidate found, store is not dead";
    return;
  }
  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor =
      findAncestorOpInRegion(topRegion, write.getOperation());
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  LDBG() << "Write ancestor in top region: " << *writeAncestor;
  LDBG() << "Checking " << blockingAccesses.size()
         << " blocking accesses for reachability";
  for (Operation *access : blockingAccesses) {
    LDBG() << "Checking blocking access: " << *access;
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and the write have the same ancestor we could
    // recurse in the region to know if the access is reachable with more
    // precision.
    if (accessAncestor == nullptr) {
      LDBG() << "  -> No ancestor in top region, skipping";
      continue;
    }
    bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
    LDBG() << "  -> Is reachable from write: " << isReachableFromWrite;
    if (!isReachableFromWrite) {
      LDBG() << "  -> Not reachable, skipping";
      continue;
    }
    bool overwriteDominatesAccess =
        dominators.dominates(firstOverwriteCandidate, accessAncestor);
    LDBG() << "  -> Overwrite dominates access: " << overwriteDominatesAccess;
    if (!overwriteDominatesAccess) {
      LDBG() << "Store may not be dead due to op: " << *accessAncestor;
      return;
    }
    LDBG() << "  -> Access is dominated by overwrite, continuing";
  }
  LDBG() << "Found dead store: " << *write.getOperation()
         << " overwritten by: " << *firstOverwriteCandidate;
  opToErase.push_back(write.getOperation());
}
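/// Attempt to replace `read` with the vector stored by a dominating
/// transfer_write that covers all elements read, provided no conflicting
/// write can execute between the two. On success, the read's uses are
/// replaced and the read is queued for erasure.
///
/// Illustrative IR (SSA names invented for the example); %r can be replaced
/// by %v:
///
///   vector.transfer_write %v, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4xf32>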
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
         << *read.getOperation();
  if (read.hasOutOfBoundsDim()) {
    LDBG() << "Read has out-of-bounds dimensions, skipping";
    return;
  }
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LDBG() << "Processing user: " << *user;
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second) {
      LDBG() << "  -> Already processed, skipping";
      continue;
    }
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << "  -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) {
      LDBG() << "  -> Memory effect free or transfer_read, skipping";
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << "  -> Found transfer_write candidate: " << *write;
      // If the write is proven disjoint from the read, it can be ignored.
      bool isDisjoint = vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
          /*testDynamicValueUsingBounds=*/true);
      LDBG() << "  -> Is disjoint: " << isDisjoint;
      if (isDisjoint) {
        LDBG() << "  -> Skipping disjoint write";
        continue;
      }
      bool sameView = memref::isSameViewOrTrivialAlias(
          cast<MemrefValue>(read.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool dominates = dominators.dominates(write, read);
      bool sameValue = checkSameValueRAW(write, read);
      LDBG() << "  -> Same view: " << sameView << ", Dominates: " << dominates
             << ", Same value: " << sameValue;
      if (sameView && dominates && sameValue) {
        LDBG() << "  -> Valid forwarding candidate found";
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
          LDBG() << "  -> New last write candidate: " << *write;
          lastwrite = write;
        } else {
          LDBG() << "  -> Keeping existing last write candidate";
          assert(dominators.dominates(write, lastwrite));
        }
        continue;
      }
      LDBG() << "  -> Not a valid forwarding candidate";
    }
    LDBG() << "  -> Adding to blocking writes: " << *user;
    blockingWrites.push_back(user);
  }
  LDBG() << "Finished processing users. Found " << blockingWrites.size()
         << " blocking writes";
  if (lastwrite == nullptr) {
    LDBG() << "No last write candidate found, cannot forward";
    return;
  }
  LDBG() << "Last write candidate: " << *lastwrite;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor =
      findAncestorOpInRegion(topRegion, read.getOperation());
  assert(readAncestor &&
         "read op should be recursively part of the top region");
  LDBG() << "Read ancestor in top region: " << *readAncestor;
  LDBG() << "Checking " << blockingWrites.size()
         << " blocking writes for post-dominance";
  for (Operation *write : blockingWrites) {
    LDBG() << "Checking blocking write: " << *write;
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor)
      LDBG() << "  -> Write ancestor: " << *writeAncestor;
    else
      LDBG() << "  -> Write ancestor: nullptr";
    // TODO: if the store and the read have the same ancestor we could
    // recurse in the region to know if the read is reachable with more
    // precision.
    if (writeAncestor == nullptr) {
      LDBG() << "  -> No ancestor in top region, skipping";
      continue;
    }
    bool isReachableToRead = isReachable(writeAncestor, readAncestor);
    LDBG() << "  -> Is reachable to read: " << isReachableToRead;
    if (!isReachableToRead) {
      LDBG() << "  -> Not reachable, skipping";
      continue;
    }
    bool lastWritePostDominates =
        postDominators.postDominates(lastwrite, write);
    LDBG() << "  -> Last write post-dominates blocking write: "
           << lastWritePostDominates;
    if (!lastWritePostDominates) {
      LDBG() << "Failed to forward write to read due to op: " << *write;
      return;
    }
    LDBG() << "  -> Blocking write is post-dominated, continuing";
  }
  LDBG() << "Forward value from " << *lastwrite.getOperation()
         << " to: " << *read.getOperation();
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
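/// Return the shape described by `mixedSizes` with all static unit dims
/// dropped. Dynamic sizes are kept (as ShapedType::kDynamic) because they
/// cannot be proven to be unit.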
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Infer the rank-reduced subview type with the unit dims of `inputType`
/// dropped and its strided layout canonicalized.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}
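/// Create a rank-reducing memref.subview that drops the unit dims of
/// `input`, or return `input` unchanged when the rank-reduced layout would
/// be identical.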
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
                                   sizes, strides);
}

/// Returns the number of dims that are not unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
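/// Trim all non-scalable unit dimensions from `oldType`; scalable unit dims
/// (e.g. the `[1]` in `vector<4x[1]xf32>`) are kept.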
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
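/// Rebuild a `vector.create_mask` or `vector.constant_mask` op with its
/// non-scalable unit dims dropped, returning failure when there is nothing
/// to drop. For example (illustrative), a mask of type vector<1x[8]xi1>
/// would be rebuilt with type vector<[8]xi1>.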
template <typename MaskOp>
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  MaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  using ElemType = std::decay_t<decltype(*op.getMaskDimSizes().begin())>;
  SmallVector<ElemType> reduced;
  for (auto [dim, dimIsScalable, elem] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getMaskDimSizes())) {
    if (dim == 1 && !dimIsScalable) {
      // Drop this non-scalable unit dim from the mask.
      continue;
    }
    reduced.push_back(elem);
  }
  return MaskOp::create(rewriter, loc, reducedType, reduced).getResult();
}
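/// Rewrite `vector.transfer_read` ops whose memref source has unit dims:
/// read through a rank-reduced memref.subview instead, then shape_cast the
/// result back to the original vector type.
///
/// Illustrative rewrite (types invented for the example):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       : memref<1x4x1xf32>, vector<1x4x1xf32>
///
/// becomes:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>
///   %r = vector.transfer_read %sv[%c0], %pad : memref<4xf32>, vector<4xf32>
///   %v = vector.shape_cast %r : vector<4xf32> to vector<1x4x1xf32>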
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
           << *transferReadOp.getOperation();
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType) {
      LDBG() << "  -> Not a MemRefType, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << "  -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << "  -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << "  -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << "  -> No unit dimensions to drop, skipping";
      return failure();
    }
    // A 0-d transfer cannot be masked with the current vector.mask support.
    if (reducedRank == 0 && maskingOp) {
      LDBG() << "  -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // Check if the reduced vector shape matches the reduced source shape;
    // otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << "  -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << "  -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << "  -> Non-zero indices found, skipping";
      return failure();
    }

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      LDBG() << "  -> Processing mask operation";
      FailureOr<Value> rankReducedMaskOp = failure();
      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
        rankReducedMaskOp =
            createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      else if (auto constantMaskOp =
                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
        rankReducedMaskOp =
            createMaskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
      if (failed(rankReducedMaskOp)) {
        LDBG() << "  -> Failed to reduce mask dimensions";
        return rewriter.notifyMatchFailure(
            transferReadOp,
            "unsupported mask op, only 'vector.create_mask' and "
            "'vector.constant_mask' are currently supported");
      }
      maskOp = *rankReducedMaskOp;
      LDBG() << "  -> Successfully reduced mask dimensions";
    }

    LDBG() << "  -> Creating rank-reduced subview and new transfer_read";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = vector::TransferReadOp::create(
        rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
        identityMap, transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    LDBG() << "  -> Created new transfer_read: " << *newTransferReadOp;

    if (maskingOp) {
      LDBG() << "  -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    LDBG() << "  -> Created shape cast: " << *shapeCast.getDefiningOp();
    LDBG() << "  -> Pattern match successful, returning result";
    return shapeCast;
  }
};
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
           << *transferWriteOp.getOperation();
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType) {
      LDBG() << "  -> Not a MemRefType, skipping";
      return failure();
    }
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim()) {
      LDBG() << "  -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << "  -> Not minor identity permutation map, skipping";
      return failure();
    }
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    LDBG() << "  -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << "  -> No unit dimensions to drop, skipping";
      return failure();
    }
    // A 0-d transfer cannot be masked with the current vector.mask support.
    if (reducedRank == 0 && maskingOp) {
      LDBG() << "  -> 0-d vector with masking not supported, skipping";
      return failure();
    }
    // Check if the reduced vector shape matches the reduced destination
    // shape; otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    LDBG() << "  -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << "  -> Reduced ranks don't match, skipping";
      return failure();
    }
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        })) {
      LDBG() << "  -> Non-zero indices found, skipping";
      return failure();
    }

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      LDBG() << "  -> Processing mask operation";
      FailureOr<Value> rankReducedMask = failure();
      if (auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>())
        rankReducedMask =
            createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      else if (auto constantMaskOp =
                   maskOp.getDefiningOp<vector::ConstantMaskOp>())
        rankReducedMask =
            createMaskDropNonScalableUnitDims(rewriter, loc, constantMaskOp);
      if (failed(rankReducedMask)) {
        LDBG() << "  -> Failed to reduce mask dimensions";
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' and "
            "'vector.constant_mask' are currently supported");
      }
      maskOp = *rankReducedMask;
      LDBG() << "  -> Successfully reduced mask dimensions";
    }

    LDBG() << "  -> Creating rank-reduced subview and new transfer_write";
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = vector::TransferWriteOp::create(
        rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
    LDBG() << "  -> Created new transfer_write: " << *newXferWrite;

    if (maskingOp) {
      LDBG() << "  -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics()) {
      LDBG() << "  -> Pattern match successful (tensor semantics), returning "
                "result";
      return newXferWrite->getResults()[0];
    }

    // With memref semantics, there is no return value; use an empty Value to
    // signal success.
    LDBG() << "  -> Pattern match successful (memref semantics)";
    return Value();
  }
};
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
}

/// Returns the indices to use when accessing a memref whose trailing
/// dimensions (starting at `firstDimToCollapse`) have been collapsed into
/// one: the leading indices are kept, and the collapsed indices are
/// linearized into a single trailing offset.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // Keep the leading indices; the trailing ones are linearized below.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  // If all the collapsed indices are zero, the trailing index is zero too.
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Linearize the collapsed indices, using the suffix product of the
  // collapsed dims as strides.
  OpFoldResult collapsedOffset =
      arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
        rewriter, loc, *getConstantIntValue(collapsedOffset)));
  }
  return indicesAfterCollapsing;
}
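/// Flatten contiguous row-major `vector.transfer_read` ops into 1-D reads of
/// a collapsed memref, guarded by `targetVectorBitwidth`.
///
/// Illustrative rewrite (types invented for the example):
///
///   %v = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<2x4xf32>, vector<2x4xf32>
///
/// becomes:
///
///   %cm = memref.collapse_shape %m [[0, 1]]
///       : memref<2x4xf32> into memref<8xf32>
///   %r = vector.transfer_read %cm[%c0], %pad : memref<8xf32>, vector<8xf32>
///   %v = vector.shape_cast %r : vector<8xf32> to vector<2x4xf32>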
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
           << *transferReadOp.getOperation();
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (!sourceType) {
      LDBG() << "  -> Not a MemRefType, skipping";
      return failure();
    }
    if (vectorType.getRank() <= 1) {
      LDBG() << "  -> Already 0D/1D, skipping";
      return failure();
    }
    if (!vectorType.getElementType().isSignlessIntOrFloat()) {
      LDBG() << "  -> Not signless int or float, skipping";
      return failure();
    }
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    LDBG() << "  -> Trailing vector dim bitwidth: "
           << trailingVectorDimBitwidth
           << ", target: " << targetVectorBitwidth;
    if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
      LDBG() << "  -> Trailing dim bitwidth >= target, skipping";
      return failure();
    }
    if (!vector::isContiguousSlice(sourceType, vectorType)) {
      LDBG() << "  -> Not contiguous slice, skipping";
      return failure();
    }
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << "  -> Has out-of-bounds dimensions, skipping";
      return failure();
    }
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << "  -> Not minor identity permutation map, skipping";
      return failure();
    }
    if (transferReadOp.getMask()) {
      LDBG() << "  -> Has mask, skipping";
      return failure();
    }

    // 1. Determine the first memref dimension to collapse - just enough so
    // that the vector being read becomes flat.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // An all-unit-dims vector still reads one element.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
    LDBG() << "  -> First dimension to collapse: " << firstDimToCollapse;

    // 2. Collapse the source memref.
    LDBG() << "  -> Collapsing source memref";
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);
    LDBG() << "  -> Collapsed source type: " << collapsedSourceType;

    // 3. New affine map and indices for the collapsed transfer_read.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 4. Create the new, flat vector.transfer_read.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    LDBG() << "  -> Creating flattened vector type: " << flatVectorType;
    vector::TransferReadOp flatRead = vector::TransferReadOp::create(
        rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    LDBG() << "  -> Created flat transfer_read: " << *flatRead;

    // 5. Cast the flat result back to the original vector type.
    LDBG() << "  -> Replacing with shape cast";
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    LDBG() << "  -> Pattern match successful";
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
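/// The transfer_write counterpart of the pattern above: shape_cast the
/// source vector to 1-D and write it into the collapsed memref.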
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // 1. Determine the first memref dimension to collapse - just enough so
    // that the vector being written becomes flat.
    ArrayRef<int64_t> collapsedVectorShape =
        vectorType.getShape().drop_while([](auto v) { return v == 1; });
    size_t collapsedVecRank = collapsedVectorShape.size();
    // An all-unit-dims vector still writes one element.
    if (collapsedVecRank == 0)
      collapsedVecRank = 1;
    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;

    // 2. Collapse the destination memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 3. New affine map and indices for the collapsed transfer_write.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 4. Create the new, flat vector.transfer_write.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
        rewriter, loc, flatVector, collapsedSource, collapsedIndices,
        collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 5. The old transfer_write is now dead.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
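/// Rewrite `vector.extract(vector.transfer_read)` into a scalar load
/// (memref.load or tensor.extract), folding the extract position into the
/// read indices.
///
/// Illustrative rewrite (types invented for the example):
///
///   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<4xf32>
///   %e = vector.extract %v[2] : f32 from vector<4xf32>
///
/// becomes:
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%i]
///   %e = memref.load %m[%idx] : memref<?xf32>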
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
public:
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     PatternBenefit benefit,
                                     bool allowMultipleUses)
      : OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Match phase.
    auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if the xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check that all of them are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp>(use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    // Rewrite phase: construct the scalar load indices.
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      // Compute the affine expression `newIndices[idx] + pos`, where `pos`
      // can be either a constant or a value.
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(),
            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      } else {
        Value dynamicOffset = cast<Value>(pos);
        AffineExpr sym0, sym1;
        bindSymbols(rewriter.getContext(), sym0, sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      }
      // Update the index with the folded result.
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = arith::ConstantIndexOp::create(
            rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }
    return success();
  }

private:
  bool allowMultipleUses;
};
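/// Rewrite a `vector.transfer_write` of an all-unit-dim vector (e.g.
/// vector<1x1xf32>) into a scalar memref.store or tensor.insert of the
/// extracted element.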
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar and construct a scalar store.
    Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
                                             xferOp.getVector());
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};
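/// Entry point of the transfer-op flow optimization. Store-to-load
/// forwarding runs first because it can expose additional dead stores; dead
/// ops are erased after each phase.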
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  LDBG() << "=== Starting transferOpflowOpt on root operation: "
         << OpWithFlags(rootOp, OpPrintingFlags().skipRegions());
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  LDBG() << "Phase 1: Store-to-load forwarding";
  unsigned readCount = 0;
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType())) {
      LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
      opt.storeToLoadForwarding(read);
    }
  });
  LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
  opt.removeDeadOp();

  LDBG() << "Phase 2: Dead store elimination";
  unsigned writeCount = 0;
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType())) {
      LDBG() << "Processing transfer_write #" << ++writeCount << ": "
             << *write;
      opt.deadStoreOp(write);
    }
  });
  LDBG() << "Phase 2 complete. Removing dead operations from dead store "
            "elimination";
  opt.removeDeadOp();
  LDBG() << "=== transferOpflowOpt complete";
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}