35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36 return loadOp.getMemRef();
41 static memref::LoadOp rebuildLoadOp(
RewriterBase &rewriter,
42 memref::LoadOp loadOp,
Value srcMemRef,
45 return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
46 loadOp.getNontemporal());
52 getLoadOpViewSizeForEachDim(
RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66 return storeOp.getMemRef();
71 static memref::StoreOp rebuildStoreOp(
RewriterBase &rewriter,
72 memref::StoreOp storeOp,
Value srcMemRef,
75 return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
76 srcMemRef, indices, storeOp.getNontemporal());
82 getStoreOpViewSizeForEachDim(
RewriterBase &rewriter, memref::StoreOp storeOp) {
83 MemRefType ldTy = storeOp.getMemRefType();
84 unsigned loadRank = ldTy.getRank();
95 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
96 return ldMatrixOp.getSrcMemref();
101 static nvgpu::LdMatrixOp rebuildLdMatrixOp(
RewriterBase &rewriter,
102 nvgpu::LdMatrixOp ldMatrixOp,
106 return nvgpu::LdMatrixOp::create(
107 rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
108 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
118 template <
typename TransferLikeOp>
119 static FailureOr<Value>
120 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
121 Value src = transferLikeOp.getBase();
122 if (isa<MemRefType>(src.
getType()))
129 static vector::TransferReadOp
131 vector::TransferReadOp transferReadOp,
Value srcMemRef,
133 Location loc = transferReadOp.getLoc();
134 return vector::TransferReadOp::create(
135 rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices,
136 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
137 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
147 static vector::TransferWriteOp
149 vector::TransferWriteOp transferWriteOp,
Value srcMemRef,
151 Location loc = transferWriteOp.getLoc();
152 return vector::TransferWriteOp::create(
153 rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices,
154 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
155 transferWriteOp.getInBoundsAttr());
166 template <
typename LoadStoreLikeOp,
167 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
168 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
169 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
170 assert(!
failed(failureOrSrcMemRef) &&
"Generic getSrcMemRef cannot be used");
171 return *failureOrSrcMemRef;
179 template <
typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
182 LoadStoreLikeOp loadStoreLikeOp) {
183 Location loc = loadStoreLikeOp.getLoc();
184 auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
185 rewriter, loc, getSrcMemRef(loadStoreLikeOp));
187 extractStridedMetadataOp.getConstifiedMixedSizes();
195 for (
auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
197 rewriter, loc, s0 - s1, {srcSize, indice}));
220 template <
typename LoadStoreLikeOp,
221 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
222 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
227 getGenericOpViewSizeForEachDim<
229 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
233 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
235 FailureOr<Value> failureOrSrcMemRef =
236 getFailureOrSrcMemRef(loadStoreLikeOp);
237 if (
failed(failureOrSrcMemRef))
239 "source is not a memref");
240 Value srcMemRef = *failureOrSrcMemRef;
241 auto ldStTy = cast<MemRefType>(srcMemRef.
getType());
242 unsigned loadStoreRank = ldStTy.getRank();
244 if (loadStoreRank == 0)
246 "0-D accesses don't need rewriting");
254 loadStoreLikeOp,
"no computation to extract: offsets are 0s");
260 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
261 assert(sizes.size() == loadStoreRank &&
262 "Expected one size per load dimension");
263 Location loc = loadStoreLikeOp.getLoc();
268 memref::SubViewOp::create(rewriter, loc, srcMemRef,
274 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
275 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
276 rewriter.
replaceOp(loadStoreLikeOp, newLoadStore->getResults());
285 LoadStoreLikeOpRewriter<
289 getLoadOpViewSizeForEachDim>,
290 LoadStoreLikeOpRewriter<
294 getStoreOpViewSizeForEachDim>,
295 LoadStoreLikeOpRewriter<
297 getLdMatrixOpSrcMemRef,
299 LoadStoreLikeOpRewriter<
300 vector::TransferReadOp,
301 getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
302 rebuildTransferReadOp>,
303 LoadStoreLikeOpRewriter<
304 vector::TransferWriteOp,
305 getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
306 rebuildTransferWriteOp>>(
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...