/// Returns the memref accessed by `loadOp`. This getter always returns the
/// memref operand and never produces a failure; the FailureOr return type
/// presumably exists so it matches the `getFailureOrSrcMemRef` hook of
/// LoadStoreLikeOpRewriter.
35static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36 return loadOp.getMemRef();
/// Builds a new memref.load that reads from `srcMemRef` at `indices`, keeping
/// the nontemporal flag of the original `loadOp`. Its signature matches the
/// `rebuildOpFromAddressAndIndices` hook of LoadStoreLikeOpRewriter.
41static memref::LoadOp rebuildLoadOp(
RewriterBase &rewriter,
42 memref::LoadOp loadOp,
Value srcMemRef,
45 return memref::LoadOp::create(rewriter, loc, srcMemRef,
indices,
46 loadOp.getNontemporal());
/// Computes, for each dimension of the memref read by `loadOp`, the size of
/// the subview that LoadStoreLikeOpRewriter creates for the access. The
/// pattern asserts that exactly one size is returned per memref dimension.
/// NOTE(review): the remainder of this body is not visible in this excerpt;
/// confirm how the per-dimension sizes are derived.
52getLoadOpViewSizeForEachDim(
RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
/// Returns the memref written by `storeOp`. This getter always returns the
/// memref operand and never produces a failure; it returns FailureOr so it
/// fits the `getFailureOrSrcMemRef` hook of LoadStoreLikeOpRewriter.
65static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66 return storeOp.getMemRef();
/// Builds a new memref.store that writes the value stored by `storeOp` into
/// `srcMemRef` at `indices`, keeping the original nontemporal flag.
71static memref::StoreOp rebuildStoreOp(
RewriterBase &rewriter,
72 memref::StoreOp storeOp,
Value srcMemRef,
75 return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
76 srcMemRef,
indices, storeOp.getNontemporal());
/// Store counterpart of getLoadOpViewSizeForEachDim: computes one subview size
/// per dimension of the memref written by `storeOp`.
/// NOTE(review): the locals below are named `ldTy`/`loadRank` although this
/// handles a store, presumably copied from the load variant. The rest of the
/// body is not visible in this excerpt, so they are left unrenamed.
82getStoreOpViewSizeForEachDim(
RewriterBase &rewriter, memref::StoreOp storeOp) {
83 MemRefType ldTy = storeOp.getMemRefType();
84 unsigned loadRank = ldTy.getRank();
/// Returns the source memref of an nvgpu.ldmatrix. The operand is always
/// returned, so this getter never produces a failure.
95static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
96 return ldMatrixOp.getSrcMemref();
/// Builds a new nvgpu.ldmatrix that reads from `srcMemRef` at `indices`,
/// keeping the original result type, transpose flag and number of tiles.
101static nvgpu::LdMatrixOp rebuildLdMatrixOp(
RewriterBase &rewriter,
102 nvgpu::LdMatrixOp ldMatrixOp,
106 return nvgpu::LdMatrixOp::create(
107 rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef,
indices,
108 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
/// Returns the base of a vector transfer op when that base has a MemRefType.
/// The isa<MemRefType> check means the base is not guaranteed to be a memref.
/// NOTE(review): the non-memref branch lies outside this excerpt, presumably
/// returning failure. LoadStoreLikeOpRewriter reports "source is not a memref"
/// when this getter fails.
118template <
typename TransferLikeOp>
119static FailureOr<Value>
120getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
121 Value src = transferLikeOp.getBase();
122 if (isa<MemRefType>(src.
getType()))
/// Builds a new vector.transfer_read from `srcMemRef` at `indices`. It carries
/// over the result type, permutation map, padding value, mask and in_bounds
/// attribute of `transferReadOp`.
129static vector::TransferReadOp
131 vector::TransferReadOp transferReadOp,
Value srcMemRef,
133 Location loc = transferReadOp.getLoc();
134 return vector::TransferReadOp::create(
135 rewriter, loc, transferReadOp.getResult().getType(), srcMemRef,
indices,
136 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
137 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
/// Builds a new vector.transfer_write that writes the original value into
/// `srcMemRef` at `indices`. It carries over the permutation map attribute,
/// mask and in_bounds attribute of `transferWriteOp`.
147static vector::TransferWriteOp
149 vector::TransferWriteOp transferWriteOp,
Value srcMemRef,
151 Location loc = transferWriteOp.getLoc();
152 return vector::TransferWriteOp::create(
153 rewriter, loc, transferWriteOp.getValue(), srcMemRef,
indices,
154 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
155 transferWriteOp.getInBoundsAttr());
/// Adapts a FailureOr-returning memref getter into one that returns a plain
/// Value. It is used as the `getSrcMemRef` argument of
/// getGenericOpViewSizeForEachDim.
/// The success of the getter is only checked by an assert, which is compiled
/// out under NDEBUG. Callers must therefore validate the memref first:
/// LoadStoreLikeOpRewriter checks `getFailureOrSrcMemRef` before it computes
/// view sizes.
166template <
typename LoadStoreLikeOp,
167 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
168static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
169 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
170 assert(!
failed(failureOrSrcMemRef) &&
"Generic getSrcMemRef cannot be used");
171 return *failureOrSrcMemRef;
/// Default per-dimension view-size computation used by
/// LoadStoreLikeOpRewriter. It materializes the sizes of the accessed memref
/// with memref.extract_strided_metadata (taking its constified mixed sizes).
/// For each dimension it then produces `srcSize - index` through
/// makeComposedFoldedAffineApply.
/// NOTE(review): `indices`, `srcSizes` and the symbols `s0`/`s1` are bound on
/// lines outside this excerpt; `indices` is presumably the op's access indices.
179template <
typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
182 LoadStoreLikeOp loadStoreLikeOp) {
183 Location loc = loadStoreLikeOp.getLoc();
184 auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
185 rewriter, loc, getSrcMemRef(loadStoreLikeOp));
187 extractStridedMetadataOp.getConstifiedMixedSizes();
// Remaining extent of each dimension past the access index: srcSize - index.
195 for (
auto [srcSize, indice] : llvm::zip(srcSizes,
indices)) {
197 rewriter, loc, s0 - s1, {srcSize, indice}));
/// Generic rewrite pattern that pulls the address computation out of a memory
/// access op. For an op that accesses a memref of rank > 0 with non-zero
/// indices, it first creates a memref.subview of the source memref, sized by
/// `getViewSizeForEachDim`. It then recreates the op on that subview, with the
/// `zeros` indices, through `rebuildOpFromAddressAndIndices`.
///
/// Template hooks:
///  - getFailureOrSrcMemRef: returns the accessed memref, or failure, in which
///    case the pattern reports "source is not a memref".
///  - rebuildOpFromAddressAndIndices: recreates the op on a new memref and
///    indices.
///  - getViewSizeForEachDim: returns one subview size per memref dimension.
///    It defaults to getGenericOpViewSizeForEachDim.
///
/// Match failures are reported for non-memref sources, 0-D accesses, and
/// accesses whose offsets are all zero.
220template <
typename LoadStoreLikeOp,
221 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
222 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
227 getGenericOpViewSizeForEachDim<
229 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
233 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
235 FailureOr<Value> failureOrSrcMemRef =
236 getFailureOrSrcMemRef(loadStoreLikeOp);
237 if (
failed(failureOrSrcMemRef))
239 "source is not a memref");
240 Value srcMemRef = *failureOrSrcMemRef;
241 auto ldStTy = cast<MemRefType>(srcMemRef.
getType());
242 unsigned loadStoreRank = ldStTy.getRank();
// A rank-0 memref access has no indices, so there is nothing to extract.
244 if (loadStoreRank == 0)
246 "0-D accesses don't need rewriting");
254 loadStoreLikeOp,
"no computation to extract: offsets are 0s");
260 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
261 assert(sizes.size() == loadStoreRank &&
262 "Expected one size per load dimension");
263 Location loc = loadStoreLikeOp.getLoc();
268 memref::SubViewOp::create(rewriter, loc, srcMemRef,
274 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
275 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
// Forward the rebuilt op's results to the users of the original op.
276 rewriter.
replaceOp(loadStoreLikeOp, newLoadStore->getResults());
// One LoadStoreLikeOpRewriter instantiation per supported access op:
// memref.load, memref.store, nvgpu.ldmatrix, vector.transfer_read and
// vector.transfer_write.
285 LoadStoreLikeOpRewriter<
289 getLoadOpViewSizeForEachDim>,
290 LoadStoreLikeOpRewriter<
294 getStoreOpViewSizeForEachDim>,
295 LoadStoreLikeOpRewriter<
297 getLdMatrixOpSrcMemRef,
299 LoadStoreLikeOpRewriter<
300 vector::TransferReadOp,
301 getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
302 rebuildTransferReadOp>,
303 LoadStoreLikeOpRewriter<
304 vector::TransferWriteOp,
305 getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
306 rebuildTransferWriteOp>>(
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...