#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
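using namespace mlir;

// The listing below uses a small file-local helper that the excerpt dropped;
// this is a minimal sketch, reconstructed to be consistent with its uses in
// deadStoreOp() and storeToLoadForwarding(): walk up the parent chain of `op`
// until reaching an op directly nested in `region`, or return nullptr if
// `region` is not an ancestor of `op`.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  while (op != nullptr && op->getParentRegion() != region)
    op = op->getParentOp();
  return op;
}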
namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. Both operations must be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the dest.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  llvm::SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}
/// For a transfer_write to be dead, another transfer_write must overwrite it
/// fully. The overwriting candidate must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the transfer_write being analyzed.
/// If several candidates are available, only the one post-dominated by all
/// the others needs to be considered: if it kills the store, so would the
/// rest.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that may overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse
    // into the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
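// Illustrative IR for the dead-store case handled above (hand-written
// example, not taken from a test file): the first transfer_write stores to
// the same indices with the same vector type as the second and is
// post-dominated by it, with no intervening read, so it can be erased.
//
//   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<8x8xf32>
//   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<8x8xf32>
//
// After the rewrite only the second transfer_write remains.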
/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, the one dominated by all the others
/// is picked: it is the closest write that reaches the read.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse
    // into the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
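// Illustrative IR for store-to-load forwarding (hand-written example): the
// transfer_read reads back exactly what the dominating transfer_write stored,
// so the read can be replaced by %v directly.
//
//   vector.transfer_write %v, %m[%c0, %c0] : vector<4xf32>, memref<8x8xf32>
//   %r = vector.transfer_read %m[%c0, %c0], %pad
//        : memref<8x8xf32>, vector<4xf32>
//
// After the rewrite, all uses of %r use %v and the read is erased.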
/// Converts OpFoldResults to int64_t shape entries, dropping unit dims.
static SmallVector<int64_t>
getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}
/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or returns the input unchanged if it has no unit dims to drop.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
/// Trims non-scalable unit dimensions from `oldType` and returns the result.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
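// For example (using the usual scalable-dims notation), this turns
// vector<1x[4]x1xf32> into vector<[4]xf32>: both non-scalable unit dims are
// dropped, while a scalable unit dim such as [1] would be kept.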
/// Rewrites a vector.create_mask op to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask operand for a dropped unit dim is not the constant 1,
      // bail out.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}
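// For example (hand-written), a mask like
//   %mask = vector.create_mask %c1, %dim, %c1 : vector<1x[4]x1xi1>
// becomes
//   %mask = vector.create_mask %dim : vector<[4]xi1>
// provided the operands for the dropped unit dims are the constant 1.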
namespace {

/// Rewrites a vector.transfer_read whose memref source has unit dims by
/// reading through a rank-reducing memref.subview and shape_cast'ing the
/// result back to the original vector type.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape;
    // otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);
    return success();
  }
};
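// Illustrative before/after for this pattern (hand-written example):
//
//   %r = vector.transfer_read %m[%c0, %c0, %c0], %pad
//        {in_bounds = [true, true, true]}
//        : memref<1x8x1xf32>, vector<1x4x1xf32>
//
// becomes a rank-reduced read through a subview:
//
//   %sv = memref.subview %m[0, 0, 0] [1, 8, 1] [1, 1, 1]
//         : memref<1x8x1xf32> to memref<8xf32>
//   %r0 = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
//         : memref<8xf32>, vector<4xf32>
//   %r  = vector.shape_cast %r0 : vector<4xf32> to vector<1x4x1xf32>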
/// Rewrites a vector.transfer_write whose memref destination has unit dims by
/// shape_cast'ing the vector and writing through a rank-reducing
/// memref.subview.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination
    // shape; otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
    return success();
  }
};

} // namespace
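// The write pattern is symmetric (hand-written example):
//
//   vector.transfer_write %v, %m[%c0, %c0, %c0]
//       {in_bounds = [true, true, true]}
//       : vector<1x4x1xf32>, memref<1x8x1xf32>
//
// becomes
//
//   %sv = memref.subview %m[0, 0, 0] [1, 8, 1] [1, 1, 1]
//         : memref<1x8x1xf32> to memref<8xf32>
//   %v0 = vector.shape_cast %v : vector<1x4x1xf32> to vector<4xf32>
//   vector.transfer_write %v0, %sv[%c0] {in_bounds = [true]}
//       : vector<4xf32>, memref<8xf32>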
/// Creates a memref.collapse_shape op that collapses all inner dimensions of
/// `input` starting at `firstDimToCollapse` into a single dimension.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
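// For example, with firstDimToCollapse = 1 this builds the reassociation
// [[0], [1, 2]], i.e.:
//   memref.collapse_shape %m [[0], [1, 2]]
//       : memref<4x8x2xf32> into memref<4x16xf32>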
/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices` the
/// truncated indices where the collapsed dimensions are replaced by a single
/// zero index.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices.assign(indices.begin(), indices.end());
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}
namespace {

/// Rewrites a contiguous row-major vector.transfer_read by collapsing the
/// inner dims of the source memref with memref.collapse_shape so that the
/// resulting read is 1-D. The source shape must already be free of unit dims.
/// Flattening only happens if the bitwidth of the trailing vector dimension
/// is smaller than `targetVectorBitwidth`.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check preconditions. The contiguity check only applies to memrefs.
    if (!sourceType)
      return failure();
    // If this is already 0-D or 1-D, there is nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2.1. A minor-identity map for the collapsed read.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2. New indices. If all the collapsed indices are zero, no extra
    // logic is needed; otherwise compute the linearized offset.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstDimToCollapse,
                                                collapsedIndices))) {
      // Copy all the leading indices.
      SmallVector<Value> indices(transferReadOp.getIndices().begin(),
                                 transferReadOp.getIndices().end());
      collapsedIndices.append(indices.begin(),
                              indices.begin() + firstDimToCollapse);

      // Compute the remaining trailing index/offset required for reading
      // from the collapsed memref.
      OpFoldResult collapsedOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

      auto sourceShape = sourceType.getShape();
      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));

      // Compute the collapsed offset.
      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
          collapsedOffset, collapsedStrides, indicesToCollapse);
      collapsedOffset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, collapsedExpr, collapsedVals);

      if (collapsedOffset.is<Value>()) {
        collapsedIndices.push_back(collapsedOffset.get<Value>());
      } else {
        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(collapsedOffset)));
      }
    }

    // 3. Create the new 1-D read from the collapsed memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Shape-cast the flat result back to the original vector type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
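// Illustrative before/after for the flattening rewrite (hand-written
// example, all-zero indices case):
//
//   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
//        : memref<4x8xf32>, vector<2x8xf32>
//
// becomes
//
//   %cm = memref.collapse_shape %m [[0, 1]]
//         : memref<4x8xf32> into memref<32xf32>
//   %r0 = vector.transfer_read %cm[%c0], %pad {in_bounds = [true]}
//         : memref<32xf32>, vector<16xf32>
//   %r  = vector.shape_cast %r0 : vector<16xf32> to vector<2x8xf32>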
/// Rewrites a contiguous row-major vector.transfer_write by collapsing the
/// inner dims of the destination memref so that the resulting write is 1-D.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check only applies to memrefs.
    if (!sourceType)
      return failure();
    // If this is already 0-D or 1-D, there is nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();

    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
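// The write pattern mirrors the read pattern above: the memref is collapsed
// with memref.collapse_shape, the vector is flattened with vector.shape_cast,
// and a 1-D in-bounds vector.transfer_write replaces the original op.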
/// Base class for `vector.extract/extract_element(vector.transfer_read)` to
/// `memref.load` patterns. The `match` logic is shared by both op kinds.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if the xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check that they are all extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};
/// Rewrites `vector.extractelement(vector.transfer_read)` to `memref.load`
/// (or `tensor.extract` for tensor sources).
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp,
                                                  xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
/// Rewrites `vector.extract(vector.transfer_read)` to `memref.load` (or
/// `tensor.extract` for tensor sources).
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load, folding each constant extract position into
    // the corresponding transfer index.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp,
                                                  xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
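// Illustrative before/after (hand-written example):
//
//   %r = vector.transfer_read %m[%i, %j], %pad {in_bounds = [true]}
//        : memref<?x?xf32>, vector<4xf32>
//   %e = vector.extract %r[1] : f32 from vector<4xf32>
//
// becomes a scalar load at the offset index (%j + 1):
//
//   %j1 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%j]
//   %e  = memref.load %m[%i, %j1] : memref<?x?xf32>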
/// Rewrites a transfer_write of a vector with all-unit shape (e.g.
/// vector<1x1xf32>) to a scalar memref.store (or tensor.insert).
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Extract the scalar to store.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support 0-D vectors, so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace
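// Illustrative before/after (hand-written example):
//
//   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
//       : vector<1x1xf32>, memref<?x?xf32>
//
// becomes
//
//   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
//   memref.store %s, %m[%i, %j] : memref<?x?xf32>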
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead-store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}
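// A minimal usage sketch for this entry point (hypothetical driver code,
// mirroring how a pass would call it; `rootOp` is assumed to be e.g. a
// func::FuncOp):
//
//   IRRewriter rewriter(rootOp->getContext());
//   vector::transferOpflowOpt(rewriter, rootOp);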
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}
void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
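// A minimal usage sketch for the populate functions above (hypothetical pass
// body; applyPatternsAndFoldGreedily is MLIR's standard greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));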