#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp(); // erases the ops collected in `opToErase`

private:
  std::vector<Operation *> opToErase;
};
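/// Return true if there is a path in the CFG from `start` to `dest`; both
/// operations must be in the same region.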
72 "This function only works for ops i the same region");
74 if (dominators.dominates(start, dest))
81 while (!worklist.empty()) {
82 Block *bb = worklist.pop_back_val();
83 if (!visited.insert(bb).second)
85 if (dominators.dominates(bb, destBlock))
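/// Dead-store elimination: a transfer_write is dead if it is post-dominated by
/// another transfer_write that overwrites the same memref slice, with no read
/// or conflicting write of that memory in between. Such writes are collected
/// in `opToErase`.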
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  MemrefValue source =
      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that have already been processed.
    if (!processed.insert(user).second)
      continue;
    // Look through aliasing views of the same buffer.
    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // A candidate overwrite must target the same view, fully rewrite the
      // same values, and post-dominate the write being considered.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Provably disjoint accesses cannot block the elimination.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  // `findAncestorOpInRegion` is a file-local helper (elided from this excerpt)
  // returning the ancestor of an op within the given region.
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // A blocking access only matters if it is reachable from the write and is
    // not dominated by the overwrite.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
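/// Store-to-load forwarding: if a transfer_read is dominated by a
/// transfer_write that writes exactly the values being read, and no
/// potentially aliasing write can occur in between, replace the read by the
/// written vector value.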
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  MemrefValue source =
      memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that have already been processed.
    if (!processed.insert(user).second)
      continue;
    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // A write that is provably disjoint from the read can be ignored.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      // Otherwise it is a forwarding candidate if it targets the same view,
      // dominates the read, and writes exactly the values being read.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }
  if (lastwrite == nullptr)
    return;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // A blocking write only matters if the read is reachable from it and the
    // forwarded write does not post-dominate it.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
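// Helpers for the unit-dim dropping patterns below: compute a rank-reduced
// shape/type and materialize a rank-reducing memref.subview of the source.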
  // Build the rank-reduced shape: dynamic sizes stay dynamic, static unit dims
  // are dropped, all other static sizes are kept.
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  MemRefType inputType = cast<MemRefType>(input.getType());
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  // Create the rank-reducing subview with the reduced result type.
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // A unit dim can only be dropped if its mask operand is a constant 1.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}
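/// Rewrites vector.transfer_read ops whose memref source has unit dims: the
/// unit dims are dropped with a rank-reducing memref.subview, the read is
/// recreated on the reduced types, and a vector.shape_cast restores the
/// original vector type.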
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType)
      return failure();
    // Only in-bounds, minor-identity reads with all-zero indices are handled.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Bail if there is no unit dim to drop or the reduced ranks do not match.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    // A mask is only supported if it comes from vector.create_mask, so that
    // its unit dims can be dropped as well.
    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    // Read from a rank-reducing subview and shape_cast back to the original
    // vector type.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);
    return success();
  }
};
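/// The transfer_write counterpart: the vector is shape_cast to the reduced
/// type and written through a rank-reducing memref.subview.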
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType)
      return failure();
    // Same pre-conditions as the read pattern above.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    // Shape_cast the vector to the reduced type and write it through a
    // rank-reducing subview.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
    return success();
  }
};
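// Helpers for the flattening patterns below: collapse the trailing dims of a
// memref with memref.collapse_shape and linearize the corresponding transfer
// indices.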
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  // Leading dims keep their own reassociation group; all dims from
  // `firstDimToCollapse` on are collapsed into a single trailing group.
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  SmallVector<int64_t, 2> collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
  // Indices in front of `firstDimToCollapse` are kept as they are.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  // If all collapsed indices are zero, the collapsed index is zero as well.
  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }
  // Otherwise linearize the collapsed indices, using the suffix product of
  // the collapsed sizes as strides.
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);
  if (collapsedOffset.is<Value>()) {
    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }
  return indicesAfterCollapsing;
}
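/// Flattens an n-D vector.transfer_read of a contiguous, row-major memref into
/// a 1-D read of the collapsed memref, followed by a vector.shape_cast back to
/// the original n-D vector type. Only reads whose trailing vector dimension is
/// narrower than `targetVectorBitwidth` are rewritten.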
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType)
      return failure();
    // 0-D and 1-D reads are already flat.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // vector bitwidth.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // Collapse the trailing dims of the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // Identity map over the single collapsed (trailing) dimension.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // Linearize the indices of the collapsed dimensions.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // Read a flat 1-D vector and shape_cast it back to the original type.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth of the trailing vector dimension after flattening.
  unsigned targetVectorBitwidth;
};
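/// The transfer_write counterpart: the vector is shape_cast to 1-D and written
/// into the collapsed memref.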
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    if (!sourceType)
      return failure();
    // Same pre-conditions as the read pattern above.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // Collapse the trailing dims of the destination memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // Shape_cast the vector to 1-D and write it into the collapsed memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth of the trailing vector dimension after flattening.
  unsigned targetVectorBitwidth;
};
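/// Shared matching logic for rewriting a scalar vector.extract /
/// vector.extractelement of a vector.transfer_read into a direct scalar load.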
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : OpRewritePattern<VectorExtractOp>(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Only scalar results are rewritten.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // With multiple uses allowed, every use must still be a scalar extract.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Masked, permuted, or potentially out-of-bounds reads are not supported.
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};
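/// Rewrites vector.extractelement(vector.transfer_read) into memref.load (or
/// tensor.extract when the source is a tensor).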
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load from the transfer_read indices.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      // Fold the dynamic extract position into the innermost index.
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
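/// Rewrites vector.extract(vector.transfer_read) into memref.load (or
/// tensor.extract when the source is a tensor).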
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load, folding the static extract position into the
    // corresponding transfer_read indices.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
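/// Rewrites a vector.transfer_write of a single-element vector (e.g.
/// vector<1x1xf32>) into memref.store (or tensor.insert for tensor sources).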
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Only writes of single-element vectors, without mask or permutation.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    Value scalar;
    if (vecType.getRank() == 0) {
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};
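/// Entry point: run store-to-load forwarding first (it can expose additional
/// dead stores), then dead-store elimination, over all transfer ops on memrefs
/// nested under `rootOp`.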
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead-store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
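// A minimal usage sketch (not part of the original file): how a hypothetical
// test pass could combine the populate functions and transferOpflowOpt above.
// The pass name and registration boilerplate are illustrative assumptions; the
// sketch assumes `using namespace mlir;` plus the usual Pass and
// GreedyPatternRewriteDriver includes.
namespace {
struct TestVectorTransferOptPass
    : public PassWrapper<TestVectorTransferOptPass,
                         OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Rank-reduce and flatten transfers first.
    vector::populateVectorTransferDropUnitDimsPatterns(patterns);
    vector::populateFlattenVectorTransferPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      return signalPassFailure();
    // Then forward stores to loads and erase dead transfer writes.
    IRRewriter rewriter(ctx);
    vector::transferOpflowOpt(rewriter, getOperation());
  }
};
} // namespace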