#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);

private:
  bool isReachable(Operation *start, Operation *dest);
  RewriterBase &rewriter;
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation; both operations must be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  return start->getBlock()->isReachable(dest->getBlock());
}
/// Walk all users of the written memref and erase the transfer_write if it is
/// fully overwritten by a post-dominating write before any other access can
/// observe it.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can fully overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Accesses that are provably disjoint cannot block the optimization.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
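// Illustrative example (not from the original file): the first of these two
// writes is dead because the second post-dominates it, fully overwrites the
// same memref, and no other access can observe the intermediate value:
//
//   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x4xf32>, memref<4x4xf32>
//   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x4xf32>, memref<4x4xf32>
//
// deadStoreOp queues the first write for erasure, keeping only the second.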
/// Forward the vector stored by a dominating transfer_write directly to a
/// transfer_read of the same memref region when no conflicting write can
/// happen in between.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // A write that is provably disjoint from the read can be ignored.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
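// Illustrative example (not from the original file): the read below is
// dominated by a write of the same value range with no conflicting write in
// between, so it can simply reuse %v:
//
//   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x4xf32>, memref<4x4xf32>
//   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<4x4xf32>, vector<4x4xf32>
//
// storeToLoadForwarding replaces all uses of %r with %v and queues the read
// for erasure.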
/// Returns `mixedSizes` as a static shape with unit dims dropped; dynamic
/// sizes are mapped to ShapedType::kDynamic.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Computes the rank-reduced memref type obtained by dropping unit dims.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or returns the input unchanged if it has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType`.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
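// Illustrative example (not from the original file): only unit dims that are
// not scalable are trimmed, e.g. vector<1x8x[1]x1xf32> becomes
// vector<8x[1]xf32>; the scalable [1] dimension is preserved.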
/// Rewrites a vector.create_mask op to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask operand for a dropped unit dim is not the constant 1,
      // bail out.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}
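// Illustrative example (not from the original file): assuming the operands
// for the dropped unit dims are the constant 1,
//
//   %m = vector.create_mask %c1, %dim, %c1 : vector<1x[4]x1xi1>
//
// is rewritten to
//
//   %m = vector.create_mask %dim : vector<[4]xi1>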
/// Rewrites vector.transfer_read ops whose memref source has unit dims by
/// reading through a rank-reducing memref.subview and shape_cast'ing the
/// result back to the original vector type.
class TransferReadDropUnitDimsPattern
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);
    return shapeCast;
  }
};
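// Illustrative example (not from the original file): a read of a vector with
// unit dims from a matching memref, such as
//
//   %v = vector.transfer_read %mem[%c0, %c0, %c0], %pad
//       {in_bounds = [true, true, true]}
//       : memref<1x8x1xf32>, vector<1x8x1xf32>
//
// is rewritten to read a vector<8xf32> through a rank-reducing memref.subview
// of type memref<8xf32>, followed by a vector.shape_cast back to
// vector<1x8x1xf32>.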
/// Rewrites vector.transfer_write ops whose memref destination has unit dims
/// by writing through a rank-reducing memref.subview.
class TransferWriteDropUnitDimsPattern
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With memref semantics there is no return value; an empty Value signals
    // success.
    return Value();
  }
};
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
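// Illustrative example (not from the original file): with
// firstDimToCollapse = 2, a memref<2x3x4x5xf32> input is collapsed with the
// reassociation [[0], [1], [2, 3]], producing a memref<2x3x20xf32>.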
/// Returns the new indices for a transfer op whose trailing dimensions,
/// starting at `firstDimToCollapse`, have been collapsed into one.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // The leading indices are unaffected by the collapse.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  // If all the collapsed indices are zero, no extra arithmetic is needed.
  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Otherwise linearize the collapsed trailing indices into a single offset,
  // using row-major (suffix-product) strides of the collapsed dimensions.
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
  SmallVector<OpFoldResult> collapsedStrides = getAsIndexOpFoldResult(
      rewriter.getContext(),
      computeSuffixProduct(shape.drop_front(firstDimToCollapse)));
  auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
      collapsedOffset, collapsedStrides, getAsOpFoldResult(indicesToCollapse));
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }
  return indicesAfterCollapsing;
}
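// Illustrative worked example (not from the original file): for
// shape = [4, 8, 16], indices = [%i, %j, %k] and firstDimToCollapse = 1, the
// leading index %i is kept and the trailing indices are linearized into a
// single offset %j * 16 + %k (built with an affine.apply), giving the
// collapsed indices [%i, %j * 16 + %k].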
/// Rewrites contiguous row-major vector.transfer_read ops by inserting a
/// memref.collapse_shape on the source so that the resulting read is 1-D.
/// The flattening is only applied when the trailing vector dimension is
/// narrower than `targetVectorBitwidth`.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (!sourceType)
      return failure();
    // If this is already 0-D/1-D, there is nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for the new 1-D vector.transfer_read.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create the new vector.transfer_read reading from the collapsed
    // memref and cast its result back to the original vector type.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
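// Illustrative example (not from the original file): a contiguous 2-D read
//
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<4x8xf32>, vector<4x8xf32>
//
// is rewritten into a 1-D read of vector<32xf32> from the collapsed
// memref<32xf32> (created with memref.collapse_shape), followed by a
// vector.shape_cast back to vector<4x8xf32>.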
/// Rewrites contiguous row-major vector.transfer_write ops by inserting a
/// memref.collapse_shape on the destination so that the resulting write is
/// 1-D. Only applied when the trailing vector dimension is narrower than
/// `targetVectorBitwidth`.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    if (!sourceType)
      return failure();
    // If this is already 0-D/1-D, there is nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the destination memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for the new 1-D vector.transfer_write.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Cast the vector to its flattened form and create the new 1-D
    // vector.transfer_write.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
/// Base class for patterns that rewrite a scalar
/// vector.extract/vector.extract_element of a vector.transfer_read into a
/// scalar load. The `match` logic is shared by both derived patterns.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, the xfer must have a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, all of them must be extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Masks and non-minor-identity maps are not supported, and the indices
    // must be in bounds.
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};
/// Rewrites a scalar vector.extractelement of a vector.transfer_read into a
/// scalar load (memref.load or tensor.extract).
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
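// Illustrative example (not from the original file):
//
//   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
//       : memref<?xf32>, vector<4xf32>
//   %e = vector.extractelement %v[%j : index] : vector<4xf32>
//
// becomes a single scalar load whose index folds the extract position into
// the transfer index:
//
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
//   %e = memref.load %mem[%idx] : memref<?xf32>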
/// Rewrites a scalar vector.extract of a vector.transfer_read with constant
/// positions into a scalar load (memref.load or tensor.extract).
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
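// Illustrative example (not from the original file): the same rewrite with a
// constant extraction position,
//
//   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
//       : memref<?xf32>, vector<4xf32>
//   %e = vector.extract %v[2] : f32 from vector<4xf32>
//
// becomes a memref.load at index %i + 2 (computed with an affine.apply).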
/// Rewrites transfer_writes of single-element vectors (e.g. vector<1x1xf32>)
/// into scalar stores (memref.store or tensor.insert).
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar to store.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support 0-d vectors, so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};
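// Illustrative example (not from the original file): a transfer_write of a
// single-element vector,
//
//   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true, true]}
//       : vector<1x1xf32>, memref<?x?xf32>
//
// becomes an extraction of the scalar followed by a plain store:
//
//   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
//   memref.store %s, %mem[%i, %j] : memref<?x?xf32>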
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
}
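// Minimal usage sketch (assumption: called from inside a pass; the pass class
// name below is hypothetical and not part of this file):
//
//   void MyTransferOptPass::runOnOperation() {
//     IRRewriter rewriter(&getContext());
//     vector::transferOpflowOpt(rewriter, getOperation());
//   }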
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
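// Minimal usage sketch (assumption: standard greedy pattern driver; the
// surrounding pass is hypothetical and not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));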