#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);

private:
  RewriterBase &rewriter;
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case: `start` dominates `dest`.
  if (dominators.dominates(start, dest))
    return true;
  return start->getBlock()->isReachable(dest->getBlock());
}
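// Dead-store elimination for vector.transfer_write: a store is dead if a
// later transfer_write to the same memref (same indices and vector type)
// fully overwrites it and post-dominates it, and every other access that may
// observe the stored data and is reachable from the dead candidate is
// dominated by that overwriting transfer_write.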
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that have already been processed.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check whether this candidate fully overwrites `write`.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Disjoint accesses cannot block the elimination.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;

  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
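// Store-to-load forwarding: a transfer_read can reuse the vector of a
// transfer_write that accesses the same memref with the same indices and
// vector type and dominates the read, provided every other potentially
// aliasing write reachable from that transfer_write is post-dominated by it.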
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that have already been processed.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // A write that is provably disjoint from the read can be ignored.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    // Drop static unit dims; keep everything else.
    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}
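// Helper that wraps `dropUnitDims` in a rank-reducing memref.subview: it
// returns the input unchanged when dropping unit dims would not change the
// canonical strided layout, and otherwise materializes the subview.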
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
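// Rewrites a vector.create_mask op so that non-scalable unit dims are
// dropped: each dropped dim must be masked by the constant 1, otherwise the
// rewrite bails out with failure.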
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // The unit dim can only be dropped if its mask operand is the constant 1.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}
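// Drops unit dims from the source of a vector.transfer_read by reading
// through a rank-reducing memref.subview and shape-casting the result back.
// Roughly (illustrative IR, not taken from a test):
//
//   %v = vector.transfer_read %src[...] : memref<1x4x1x8xf32>, vector<1x4x1x8xf32>
//
// becomes a transfer_read of vector<4x8xf32> from a unit-dim-free subview of
// %src, followed by a vector.shape_cast back to vector<1x4x1x8xf32>.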
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check whether the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // 0-d masked reads are not supported; bail out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // The reduced vector shape must match the reduced source shape.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    return shapeCast;
  }
};
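// The transfer_write counterpart of the pattern above: the vector being
// stored is shape-cast to drop unit dims and written through a rank-reducing
// memref.subview of the destination.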
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check whether the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // 0-d masked writes are not supported; bail out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // The reduced vector shape must match the reduced destination shape.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With memref semantics there is no result; return an empty Value to
    // signal success.
    return Value();
  }
};
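// Helper that collapses all dims of `input` starting at `firstDimToCollapse`
// into a single trailing dim via memref.collapse_shape; a 1-D input is
// returned unchanged.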
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  // If all the collapsed indices are zero, the trailing index can be reused
  // directly.
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Otherwise, linearize the collapsed indices into a single trailing offset.
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}
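// Flattens a contiguous row-major vector.transfer_read so that it reads a
// 1-D vector from a collapsed memref, e.g. (illustrative IR only):
//
//   %v = vector.transfer_read %src[%c0, %c0] ... : memref<4x8xf32>, vector<4x8xf32>
//
// becomes roughly:
//
//   %flat_src = memref.collapse_shape %src [[0, 1]] : memref<4x8xf32> into memref<32xf32>
//   %flat = vector.transfer_read %flat_src[%c0] ... : memref<32xf32>, vector<32xf32>
//   %v = vector.shape_cast %flat : vector<32xf32> to vector<4x8xf32>
//
// The rewrite only fires when the trailing vector dim is narrower than
// `targetVectorBitwidth`.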
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Compute the new affine map and indices into the collapsed memref.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create a new 1-D vector.transfer_read from the collapsed memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);

    // 4. Shape-cast the flat read back to the original vector type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
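// The transfer_write counterpart of the pattern above: the vector is
// shape-cast to 1-D and written to the collapsed memref.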
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

    // 1. Collapse the destination memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Compute the new affine map and indices into the collapsed memref.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Shape-cast the vector to 1-D and write it to the collapsed memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);

    // 4. Erase the original transfer_write.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
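// Shared match logic for rewriting vector.extract / vector.extract_element of
// a vector.transfer_read into a scalar load: the extract must produce a
// scalar, the read must be unmasked, minor-identity and in-bounds, and either
// the read has a single use or all of its uses are such extract ops.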
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Only scalar extracts are handled; sub-vector extracts are not.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct the scalar load, folding the extract position into the last
    // transfer index.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }
    return success();
  }
};
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct the scalar load, folding each extract position into the
    // corresponding transfer index.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(),
            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      } else {
        Value dynamicOffset = cast<Value>(pos);
        AffineExpr sym0, sym1;
        bindSymbols(rewriter.getContext(), sym0, sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            rewriter, extractOp.getLoc(), sym0 + sym1,
            {newIndices[idx], dynamicOffset});
      }
      if (auto value = dyn_cast<Value>(composedIdx)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(composedIdx));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }
    return success();
  }
};
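// Rewrites a transfer_write of an all-unit-dim vector (e.g. vector<1x1xf32>)
// into a scalar memref.store / tensor.insert of the extracted element.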
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write: all vector dims are 1.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Masked writes are not supported.
    if (xferOp.getMask())
      return failure();
    // Non-minor-identity maps are not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Construct a scalar store.
    Value scalar =
        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};
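// Driver: transferOpflowOpt runs store-to-load forwarding over all
// memref-based transfer_reads first (it can expose more dead stores), then
// dead-store elimination over the transfer_writes.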
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}