45 auto extractStridedMetadataOp =
46 memref::ExtractStridedMetadataOp::create(rewriter, loc, srcMemRef);
48 extractStridedMetadataOp.getConstifiedMixedSizes();
55 for (
auto [srcSize,
index] : llvm::zip_equal(srcSizes, mixedIndices)) {
57 rewriter, loc, s0 - s1, {srcSize,
index}));
72 memref::IndexedAccessOpInterface op) {
74 assert(srcMemRef &&
"expected indexed access with a memref");
76 MemRefType srcType = srcMemRef.getType();
77 int64_t srcRank = srcType.getRank();
80 assert(accessedRank <= srcRank &&
81 "can't access more dimensions than a memref has");
84 int64_t firstAccessedDim = srcRank - accessedRank;
93 auto ensureSrcSizes = [&]() {
94 if (srcSizes.empty()) {
95 auto extractStridedMetadataOp =
96 memref::ExtractStridedMetadataOp::create(rewriter, loc, srcMemRef);
97 srcSizes = extractStridedMetadataOp.getConstifiedMixedSizes();
101 for (
int64_t accessedDim : llvm::seq<int64_t>(0, accessedRank)) {
102 int64_t accessedSize = accessedShape[accessedDim];
103 int64_t dim = firstAccessedDim + accessedDim;
104 if (!ShapedType::isDynamic(accessedSize)) {
105 int64_t srcDimSize = srcType.getDimSize(dim);
106 if (!ShapedType::isDynamic(srcDimSize) || accessedSize == 1) {
119 rewriter, loc, s0 - s1, {srcSizes[dim],
indices[dim]});
125static memref::SubViewOp createSubviewForAccess(
RewriterBase &rewriter,
133 return memref::SubViewOp::create(rewriter, loc, srcMemRef,
152struct IndexedAccessOpRewriter final
156 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
157 PatternRewriter &rewriter)
const override {
162 int64_t rank = srcMemRef.getType().getRank();
165 "0-D accesses don't need rewriting");
167 if (
static_cast<int64_t
>(op.getAccessedShape().size()) > rank)
169 op,
"can't access more dimensions than a memref has");
171 if (!op.hasInboundsIndices())
178 op,
"no computation to extract: offsets are 0s");
180 SmallVector<OpFoldResult> subviewSizes =
181 getIndexedAccessViewSizes(rewriter, op);
183 Location loc = op.getLoc();
184 auto subview = createSubviewForAccess(rewriter, loc, srcMemRef,
185 op.getIndices(), subviewSizes);
186 SmallVector<Value> zeros = getZeroIndices(rewriter, loc, rank);
188 std::optional<SmallVector<Value>> newValues =
189 op.updateMemrefAndIndices(rewriter, subview.getResult(), zeros);
197struct TransferOpRewriter final
201 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
202 PatternRewriter &rewriter)
const override {
203 Value srcMemRef = op.getBase();
204 auto srcType = dyn_cast<MemRefType>(srcMemRef.
getType());
208 int64_t rank = srcType.getRank();
212 "0-D accesses don't need rewriting");
216 op,
"no computation to extract: offsets are 0s");
218 Location loc = op.getLoc();
220 SmallVector<OpFoldResult> strides(rank, rewriter.
getIndexAttr(1));
225 SmallVector<OpFoldResult> approximateSizes(
227 MemRefType subviewType = memref::SubViewOp::inferResultType(
228 srcType, offsets, approximateSizes, strides);
232 AffineMap permutationMap = op.getPermutationMap();
233 if (
failed(op.mayUpdateStartingPosition(subviewType, permutationMap)))
235 "failed op-specific preconditions");
237 SmallVector<OpFoldResult> sizes =
238 getRemainingSizes(rewriter, loc, srcMemRef, op.getIndices());
239 auto subview = createSubviewForAccess(rewriter, loc, srcMemRef,
240 op.getIndices(), sizes);
241 SmallVector<Value> zeros = getZeroIndices(rewriter, loc, rank);
243 op.updateStartingPosition(rewriter, subview.getResult(), zeros,
244 AffineMapAttr::get(permutationMap));
252 patterns.
add<IndexedAccessOpRewriter, TransferOpRewriter>(
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from memory access operations such that these ac...
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...