#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "vector-transfer-opt"
  LDBG() << " Finding ancestor of " << *op << " in region";
  LDBG() << " -> Ancestor: " << *op;
  LDBG() << " -> Ancestor: nullptr";
class TransferOptimization {
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
    LDBG() << "Removing " << opToErase.size() << " dead operations";
      LDBG() << " -> Erasing: " << *op;
  std::vector<Operation *> opToErase;
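
// Returns true when control flow may reach `dest` from `start`. Both ops must
// belong to the same region; reachability holds if `start` dominates `dest`,
// or if `dest`'s block is reachable from `start`'s block.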
  LDBG() << " Checking reachability from " << *start << " to " << *dest;
         "This function only works for ops in the same region");
  if (dominators.dominates(start, dest)) {
    LDBG() << " -> Start dominates dest, reachable";
  LDBG() << " -> Block reachable: " << blockReachable;
  return blockReachable;
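
// Dead-store elimination for vector.transfer_write: a write is dead when a
// later transfer_write to the same view fully overwrites the same data
// (same view, same value, and post-dominating the candidate write) and no
// other non-disjoint access to the memref can observe the value in between.
// Minimal illustrative IR (a sketch, not taken from the test suite):
//
//   vector.transfer_write %v0, %m[%i] : vector<4xf32>, memref<?xf32>  // dead
//   vector.transfer_write %v1, %m[%i] : vector<4xf32>, memref<?xf32>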
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LDBG() << "=== Starting deadStoreOp analysis for: " << *write.getOperation();
  Operation *firstOverwriteCandidate = nullptr;
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    LDBG() << "Processing user: " << *user;
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      LDBG() << " -> Memory effect free, skipping";
    if (user == write.getOperation()) {
      LDBG() << " -> Same as write operation, skipping";
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *nextWrite;
          cast<MemrefValue>(nextWrite.getBase()),
          cast<MemrefValue>(write.getBase()));
      bool postDominates = postDominators.postDominates(nextWrite, write);
      LDBG() << " -> Same view: " << sameView << ", Same value: " << sameValue
             << ", Post-dominates: " << postDominates;
      if (sameView && sameValue && postDominates) {
        LDBG() << " -> Valid overwrite candidate found";
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite)) {
          LDBG() << " -> New first overwrite candidate: " << *nextWrite;
          firstOverwriteCandidate = nextWrite;
          LDBG() << " -> Keeping existing first overwrite candidate";
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        LDBG() << " -> Not a valid overwrite candidate";
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      LDBG() << " -> Found vector transfer operation: " << *transferOp;
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(transferOp.getOperation()),
      LDBG() << " -> Is disjoint: " << isDisjoint;
        LDBG() << " -> Skipping disjoint access";
    LDBG() << " -> Adding to blocking accesses: " << *user;
    blockingAccesses.push_back(user);
  LDBG() << "Finished processing users. Found " << blockingAccesses.size()
         << " blocking accesses";
  if (firstOverwriteCandidate == nullptr) {
    LDBG() << "No overwrite candidate found, store is not dead";
  LDBG() << "First overwrite candidate: " << *firstOverwriteCandidate;
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  LDBG() << "Write ancestor in top region: " << *writeAncestor;
  LDBG() << "Checking " << blockingAccesses.size()
         << " blocking accesses for reachability";
  for (Operation *access : blockingAccesses) {
    LDBG() << "Checking blocking access: " << *access;
    if (accessAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
    bool isReachableFromWrite = isReachable(writeAncestor, accessAncestor);
    LDBG() << " -> Is reachable from write: " << isReachableFromWrite;
    if (!isReachableFromWrite) {
      LDBG() << " -> Not reachable, skipping";
    bool overwriteDominatesAccess =
        dominators.dominates(firstOverwriteCandidate, accessAncestor);
    LDBG() << " -> Overwrite dominates access: " << overwriteDominatesAccess;
    if (!overwriteDominatesAccess) {
      LDBG() << "Store may not be dead due to op: " << *accessAncestor;
    LDBG() << " -> Access is dominated by overwrite, continuing";
  LDBG() << "Found dead store: " << *write.getOperation()
         << " overwritten by: " << *firstOverwriteCandidate;
  opToErase.push_back(write.getOperation());
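
// Store-to-load forwarding: a transfer_read can reuse the vector written by
// the closest dominating transfer_write to the same view, provided the write
// fully covers the data being read and no other write to the memref can
// interpose between the two. Minimal illustrative IR (a sketch):
//
//   vector.transfer_write %v, %m[%i] : vector<4xf32>, memref<?xf32>
//   %r = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<4xf32>
//
// Here %r can be replaced by %v and the read erased.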
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  LDBG() << "=== Starting storeToLoadForwarding analysis for: "
         << *read.getOperation();
  if (read.hasOutOfBoundsDim()) {
    LDBG() << "Read has out-of-bounds dimensions, skipping";
  vector::TransferWriteOp lastwrite = nullptr;
  LDBG() << "Source memref (after skipping view-like ops): " << source;
  LDBG() << "Found " << users.size() << " users of source memref";
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    LDBG() << "Processing user: " << *user;
    if (!processed.insert(user).second) {
      LDBG() << " -> Already processed, skipping";
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
      LDBG() << " -> View-like operation, following to destination";
      Value viewDest = viewLike.getViewDest();
      LDBG() << " -> Memory effect free or transfer_read, skipping";
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      LDBG() << " -> Found transfer_write candidate: " << *write;
          cast<VectorTransferOpInterface>(write.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
      LDBG() << " -> Is disjoint: " << isDisjoint;
        LDBG() << " -> Skipping disjoint write";
          cast<MemrefValue>(write.getBase()));
      bool dominates = dominators.dominates(write, read);
      LDBG() << " -> Same view: " << sameView << ", Dominates: " << dominates
             << ", Same value: " << sameValue;
      if (sameView && dominates && sameValue) {
        LDBG() << " -> Valid forwarding candidate found";
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) {
          LDBG() << " -> New last write candidate: " << *write;
          LDBG() << " -> Keeping existing last write candidate";
          assert(dominators.dominates(write, lastwrite));
        LDBG() << " -> Not a valid forwarding candidate";
    LDBG() << " -> Adding to blocking writes: " << *user;
    blockingWrites.push_back(user);
  LDBG() << "Finished processing users. Found " << blockingWrites.size()
         << " blocking writes";
  if (lastwrite == nullptr) {
    LDBG() << "No last write candidate found, cannot forward";
  LDBG() << "Last write candidate: " << *lastwrite;
  assert(readAncestor &&
         "read op should be recursively part of the top region");
  LDBG() << "Read ancestor in top region: " << *readAncestor;
  LDBG() << "Checking " << blockingWrites.size()
         << " blocking writes for post-dominance";
  for (Operation *write : blockingWrites) {
    LDBG() << "Checking blocking write: " << *write;
    LDBG() << " -> Write ancestor: " << *writeAncestor;
    LDBG() << " -> Write ancestor: nullptr";
    if (writeAncestor == nullptr) {
      LDBG() << " -> No ancestor in top region, skipping";
    bool isReachableToRead = isReachable(writeAncestor, readAncestor);
    LDBG() << " -> Is reachable to read: " << isReachableToRead;
    if (!isReachableToRead) {
      LDBG() << " -> Not reachable, skipping";
    bool lastWritePostDominates =
        postDominators.postDominates(lastwrite, write);
    LDBG() << " -> Last write post-dominates blocking write: "
           << lastWritePostDominates;
    if (!lastWritePostDominates) {
      LDBG() << "Cannot forward write to read due to op: " << *write;
    LDBG() << " -> Blocking write is post-dominated, continuing";
  LDBG() << "Forward value from " << *lastwrite.getOperation()
         << " to: " << *read.getOperation();
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    reducedShape.push_back(value.getSExtValue());

  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();

  MemRefType inputType = cast<MemRefType>(input.getType());
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
  return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,

  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });

    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
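
// Drops the non-scalable unit dimensions of a vector.create_mask op. A unit
// dimension may only be dropped when its corresponding mask operand is the
// constant 1, otherwise the mask contents would change. Illustrative sketch:
//
//   %m = vector.create_mask %c1, %dim : vector<1x[4]xi1>
//   // may be reduced to
//   %m = vector.create_mask %dim : vector<[4]xi1>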
static FailureOr<Value>
                                    vector::CreateMaskOp op) {
  auto type = op.getType();
  if (reducedType.getRank() == type.getRank())
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
    reducedOperands.push_back(operand);
  return vector::CreateMaskOp::create(rewriter, loc, reducedType,
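
// Drops unit dimensions from the source memref of a vector.transfer_read by
// reading through a rank-reducing memref.subview and shape-casting the result
// back to the original vector type. Illustrative sketch (types simplified):
//
//   %v = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<1x8xf32>, vector<1x8xf32>
//   // may become
//   %sv = memref.subview %src[0, 0] [1, 8] [1, 1]
//           : memref<1x8xf32> to memref<8xf32>
//   %r  = vector.transfer_read %sv[%c0], %pad : memref<8xf32>, vector<8xf32>
//   %v  = vector.shape_cast %r : vector<8xf32> to vector<1x8xf32>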
class TransferReadDropUnitDimsPattern
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
    LDBG() << "=== TransferReadDropUnitDimsPattern: Analyzing "
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
      LDBG() << " -> Not a MemRefType, skipping";
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
      LDBG() << " -> Non-zero indices found, skipping";
    Value maskOp = transferReadOp.getMask();
      LDBG() << " -> Processing mask operation";
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
            << " -> Unsupported mask op, only 'vector.create_mask' supported";
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
      if (failed(rankReducedCreateMask)) {
        LDBG() << " -> Failed to reduce mask dimensions";
      maskOp = *rankReducedCreateMask;
      LDBG() << " -> Successfully reduced mask dimensions";
    LDBG() << " -> Creating rank-reduced subview and new transfer_read";
    Value reducedShapeSource =
    Operation *newTransferReadOp = vector::TransferReadOp::create(
        rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
        identityMap, transferReadOp.getPadding(), maskOp,
    LDBG() << " -> Created new transfer_read: " << *newTransferReadOp;
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
          rewriter, newTransferReadOp, shapeCastMask);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    LDBG() << " -> Created shape cast: " << *shapeCast.getDefiningOp();
    LDBG() << " -> Pattern match successful, returning result";
class TransferWriteDropUnitDimsPattern
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
    LDBG() << "=== TransferWriteDropUnitDimsPattern: Analyzing "
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
      LDBG() << " -> Not a MemRefType, skipping";
    if (transferWriteOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
    if (!transferWriteOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
    LDBG() << " -> Source rank: " << sourceType.getRank()
           << ", Reduced rank: " << reducedRank;
    if (reducedRank == sourceType.getRank()) {
      LDBG() << " -> No unit dimensions to drop, skipping";
    if (reducedRank == 0 && maskingOp) {
      LDBG() << " -> 0-d vector with masking not supported, skipping";
    LDBG() << " -> Vector type: " << vectorType
           << ", Reduced vector type: " << reducedVectorType;
    if (reducedRank != reducedVectorType.getRank()) {
      LDBG() << " -> Reduced ranks don't match, skipping";
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
      LDBG() << " -> Non-zero indices found, skipping";
    Value maskOp = transferWriteOp.getMask();
      LDBG() << " -> Processing mask operation";
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
            << " -> Unsupported mask op, only 'vector.create_mask' supported";
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
      if (failed(rankReducedCreateMask)) {
        LDBG() << " -> Failed to reduce mask dimensions";
      maskOp = *rankReducedCreateMask;
      LDBG() << " -> Successfully reduced mask dimensions";
    LDBG() << " -> Creating rank-reduced subview and new transfer_write";
    Value reducedShapeSource =
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = vector::TransferWriteOp::create(
        rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
    LDBG() << " -> Created new transfer_write: " << *newXferWrite;
      LDBG() << " -> Applying masking operation";
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
    if (transferWriteOp.hasPureTensorSemantics()) {
      LDBG() << " -> Pattern match successful (tensor semantics), returning "
      return newXferWrite->getResults()[0];
    LDBG() << " -> Pattern match successful (memref semantics)";
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);

                               int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
      indices.begin(), indices.begin() + firstDimToCollapse);
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;

  auto &&[collapsedExpr, collapsedVals] =
      rewriter, loc, collapsedExpr, collapsedVals);
  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  return indicesAfterCollapsing;
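
// Flattens a contiguous, row-major vector.transfer_read into a 1-D read from
// the collapsed memref, guarded by a bitwidth threshold on the trailing
// vector dimension. Illustrative sketch (types simplified):
//
//   %v = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<4x8xf32>, vector<4x8xf32>
//   // may become
//   %c = memref.collapse_shape %src [[0, 1]]
//          : memref<4x8xf32> into memref<32xf32>
//   %r = vector.transfer_read %c[%c0], %pad : memref<32xf32>, vector<32xf32>
//   %v = vector.shape_cast %r : vector<32xf32> to vector<4x8xf32>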
class FlattenContiguousRowMajorTransferReadPattern
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
    LDBG() << "=== FlattenContiguousRowMajorTransferReadPattern: Analyzing "
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
      LDBG() << " -> Not a MemRefType, skipping";
    if (vectorType.getRank() <= 1) {
      LDBG() << " -> Already 0D/1D, skipping";
    if (!vectorType.getElementType().isSignlessIntOrFloat()) {
      LDBG() << " -> Not signless int or float, skipping";
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    LDBG() << " -> Trailing vector dim bitwidth: " << trailingVectorDimBitwidth
           << ", target: " << targetVectorBitwidth;
    if (trailingVectorDimBitwidth >= targetVectorBitwidth) {
      LDBG() << " -> Trailing dim bitwidth >= target, skipping";
      LDBG() << " -> Not contiguous slice, skipping";
    if (transferReadOp.hasOutOfBoundsDim()) {
      LDBG() << " -> Has out-of-bounds dimensions, skipping";
    if (!transferReadOp.getPermutationMap().isMinorIdentity()) {
      LDBG() << " -> Not minor identity permutation map, skipping";
    if (transferReadOp.getMask()) {
      LDBG() << " -> Has mask, skipping";
    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
    LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;
    LDBG() << " -> Collapsing source memref";
    Value collapsedSource =
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);
    LDBG() << " -> Collapsed source type: " << collapsedSourceType;
        transferReadOp.getIndices(), firstDimToCollapse);
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    LDBG() << " -> Creating flattened vector type: " << flatVectorType;
    vector::TransferReadOp flatRead = vector::TransferReadOp::create(
        rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    LDBG() << " -> Created flat transfer_read: " << *flatRead;
    LDBG() << " -> Replacing with shape cast";
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    LDBG() << " -> Pattern match successful";

  unsigned targetVectorBitwidth;
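
// The transfer_write counterpart of the flattening pattern: the vector is
// shape-cast to 1-D and written into the collapsed memref. Illustrative
// sketch (types simplified):
//
//   vector.transfer_write %v, %dst[%c0, %c0]
//     : vector<4x8xf32>, memref<4x8xf32>
//   // may become
//   %c = memref.collapse_shape %dst [[0, 1]]
//          : memref<4x8xf32> into memref<32xf32>
//   %f = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
//   vector.transfer_write %f, %c[%c0] : vector<32xf32>, memref<32xf32>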
class FlattenContiguousRowMajorTransferWritePattern
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (vectorType.getRank() <= 1)
    if (!vectorType.getElementType().isSignlessIntOrFloat())
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
    if (transferWriteOp.hasOutOfBoundsDim())
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
    if (transferWriteOp.getMask())
    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
    Value collapsedSource =
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);
        transferWriteOp.getIndices(), firstDimToCollapse);
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
        vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
        rewriter, loc, flatVector, collapsedSource, collapsedIndices,
    rewriter.eraseOp(transferWriteOp);

  unsigned targetVectorBitwidth;
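
// Rewrites a scalar vector.extract of a vector.transfer_read into a direct
// load at the composed indices (memref.load shown; tensors use the analogous
// extract). Illustrative sketch, with a made-up affine map:
//
//   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<4xf32>
//   %s = vector.extract %v[2] : f32 from vector<4xf32>
//   // may become
//   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%i]
//   %s   = memref.load %m[%idx] : memref<?xf32>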
class RewriteScalarExtractOfTransferRead
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     bool allowMultipleUses)
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    if (isa<VectorType>(extractOp.getResult().getType()))
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp>(use.getOwner());
    if (xferOp.getMask())
    if (!xferOp.getPermutationMap().isMinorIdentity())
    if (xferOp.hasOutOfBoundsDim())
                                    xferOp.getIndices().end());
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
            rewriter, extractOp.getLoc(),
        Value dynamicOffset = cast<Value>(pos);
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
    if (isa<MemRefType>(xferOp.getBase().getType())) {
          extractOp, xferOp.getBase(), newIndices);

  bool allowMultipleUses;
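
// Rewrites a vector.transfer_write of an all-unit-dims vector into a scalar
// vector.extract followed by a scalar store (memref.store shown; tensors use
// the analogous insert). Illustrative sketch:
//
//   vector.transfer_write %v, %m[%i] : vector<1xf32>, memref<?xf32>
//   // may become
//   %s = vector.extract %v[0] : f32 from vector<1xf32>
//   memref.store %s, %m[%i] : memref<?xf32>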
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
    if (xferOp.getMask())
    if (!xferOp.getPermutationMap().isMinorIdentity())
    Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
                                             xferOp.getVector());
    if (isa<MemRefType>(xferOp.getBase().getType())) {
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
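
// Entry point: phase 1 runs store-to-load forwarding over all transfer_reads
// on memrefs, phase 2 runs dead-store elimination over all transfer_writes,
// erasing the dead operations after each phase.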
  LDBG() << "=== Starting transferOpflowOpt on root operation: "
  TransferOptimization opt(rewriter, rootOp);
  LDBG() << "Phase 1: Store-to-load forwarding";
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType())) {
      LDBG() << "Processing transfer_read #" << ++readCount << ": " << *read;
      opt.storeToLoadForwarding(read);
  LDBG() << "Phase 1 complete. Removing dead operations from forwarding";
  LDBG() << "Phase 2: Dead store elimination";
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType())) {
      LDBG() << "Processing transfer_write #" << ++writeCount << ": " << *write;
      opt.deadStoreOp(write);
  LDBG() << "Phase 2 complete. Removing dead operations from dead store "
  LDBG() << "=== transferOpflowOpt complete";
                                   bool allowMultipleUses) {
      benefit, allowMultipleUses);

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(

void mlir::vector::populateFlattenVectorTransferPatterns(
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);