23#include "llvm/ADT/Repeated.h"
36static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
37 return loadOp.getMemRef();
42static memref::LoadOp rebuildLoadOp(
RewriterBase &rewriter,
43 memref::LoadOp loadOp,
Value srcMemRef,
46 return memref::LoadOp::create(rewriter, loc, srcMemRef,
indices,
47 loadOp.getNontemporal());
53getLoadOpViewSizeForEachDim(
RewriterBase &rewriter, memref::LoadOp loadOp) {
54 MemRefType ldTy = loadOp.getMemRefType();
55 unsigned loadRank = ldTy.getRank();
66static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
67 return storeOp.getMemRef();
72static memref::StoreOp rebuildStoreOp(
RewriterBase &rewriter,
73 memref::StoreOp storeOp,
Value srcMemRef,
76 return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
77 srcMemRef,
indices, storeOp.getNontemporal());
83getStoreOpViewSizeForEachDim(
RewriterBase &rewriter, memref::StoreOp storeOp) {
84 MemRefType ldTy = storeOp.getMemRefType();
85 unsigned loadRank = ldTy.getRank();
96static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97 return ldMatrixOp.getSrcMemref();
102static nvgpu::LdMatrixOp rebuildLdMatrixOp(
RewriterBase &rewriter,
103 nvgpu::LdMatrixOp ldMatrixOp,
107 return nvgpu::LdMatrixOp::create(
108 rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef,
indices,
109 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
119template <
typename TransferLikeOp>
120static FailureOr<Value>
121getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122 Value src = transferLikeOp.getBase();
123 if (isa<MemRefType>(src.
getType()))
130static vector::TransferReadOp
132 vector::TransferReadOp transferReadOp,
Value srcMemRef,
134 Location loc = transferReadOp.getLoc();
135 return vector::TransferReadOp::create(
136 rewriter, loc, transferReadOp.getResult().getType(), srcMemRef,
indices,
137 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
148static vector::TransferWriteOp
150 vector::TransferWriteOp transferWriteOp,
Value srcMemRef,
152 Location loc = transferWriteOp.getLoc();
153 return vector::TransferWriteOp::create(
154 rewriter, loc, transferWriteOp.getValue(), srcMemRef,
indices,
155 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156 transferWriteOp.getInBoundsAttr());
167template <
typename LoadStoreLikeOp,
168 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171 assert(!
failed(failureOrSrcMemRef) &&
"Generic getSrcMemRef cannot be used");
172 return *failureOrSrcMemRef;
180template <
typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
183 LoadStoreLikeOp loadStoreLikeOp) {
184 Location loc = loadStoreLikeOp.getLoc();
185 auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
186 rewriter, loc, getSrcMemRef(loadStoreLikeOp));
188 extractStridedMetadataOp.getConstifiedMixedSizes();
196 for (
auto [srcSize, indice] : llvm::zip(srcSizes,
indices)) {
198 rewriter, loc, s0 - s1, {srcSize, indice}));
221template <
typename LoadStoreLikeOp,
222 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
223 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
228 getGenericOpViewSizeForEachDim<
230 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
234 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
236 FailureOr<Value> failureOrSrcMemRef =
237 getFailureOrSrcMemRef(loadStoreLikeOp);
238 if (
failed(failureOrSrcMemRef))
240 "source is not a memref");
241 Value srcMemRef = *failureOrSrcMemRef;
242 auto ldStTy = cast<MemRefType>(srcMemRef.
getType());
243 unsigned loadStoreRank = ldStTy.getRank();
245 if (loadStoreRank == 0)
247 "0-D accesses don't need rewriting");
255 loadStoreLikeOp,
"no computation to extract: offsets are 0s");
261 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
262 assert(sizes.size() == loadStoreRank &&
263 "Expected one size per load dimension");
264 Location loc = loadStoreLikeOp.getLoc();
269 memref::SubViewOp::create(rewriter, loc, srcMemRef,
275 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
276 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
277 rewriter.
replaceOp(loadStoreLikeOp, newLoadStore->getResults());
286 LoadStoreLikeOpRewriter<
290 getLoadOpViewSizeForEachDim>,
291 LoadStoreLikeOpRewriter<
295 getStoreOpViewSizeForEachDim>,
296 LoadStoreLikeOpRewriter<
298 getLdMatrixOpSrcMemRef,
300 LoadStoreLikeOpRewriter<
301 vector::TransferReadOp,
302 getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
303 rebuildTransferReadOp>,
304 LoadStoreLikeOpRewriter<
305 vector::TransferWriteOp,
306 getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
307 rebuildTransferWriteOp>>(
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands, then immediately attempts to fold it.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such that these memory accesses use only a base pointer.
Include the generated interface declarations.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.