// Source-memref accessor for memref::LoadOp: hands back the value of
// `loadOp.getMemRef()` through the FailureOr<Value> return type.
35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36 return loadOp.getMemRef();
// Builds a replacement memref::LoadOp through `rewriter`: the new op reads
// from `srcMemRef` at `indices` and copies the nontemporal setting of `loadOp`.
// NOTE(review): `loc` and `indices` are introduced on lines that are missing
// from this excerpt (tail of the parameter list / start of the body).
41 static memref::LoadOp rebuildLoadOp(
RewriterBase &rewriter,
42 memref::LoadOp loadOp,
Value srcMemRef,
45 return rewriter.
create<memref::LoadOp>(loc, srcMemRef, indices,
46 loadOp.getNontemporal());
// View-size helper for memref::LoadOp. The visible part only reads the memref
// type of `loadOp` and its rank.
// NOTE(review): the return type and the rest of the body are outside this
// excerpt; presumably one size per dimension is produced (from the name) —
// confirm against the full file.
52 getLoadOpViewSizeForEachDim(
RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
// Source-memref accessor for memref::StoreOp: hands back the value of
// `storeOp.getMemRef()` through the FailureOr<Value> return type.
65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66 return storeOp.getMemRef();
// Builds a replacement memref::StoreOp through `rewriter`, forwarding the
// stored value and the nontemporal setting of `storeOp`.
// NOTE(review): the argument line between those two (and the definition of
// `loc`) is missing from this excerpt; presumably it passes `srcMemRef` and the
// new indices — confirm against the full file.
71 static memref::StoreOp rebuildStoreOp(
RewriterBase &rewriter,
72 memref::StoreOp storeOp,
Value srcMemRef,
75 return rewriter.
create<memref::StoreOp>(loc, storeOp.getValueToStore(),
77 storeOp.getNontemporal());
// View-size helper for memref::StoreOp. The visible part only reads the memref
// type of `storeOp` and its rank.
// NOTE(review): the locals are called `ldTy` / `loadRank` although this is the
// store variant; the names look carried over from the load helper. The return
// type and the rest of the body are outside this excerpt.
83 getStoreOpViewSizeForEachDim(
RewriterBase &rewriter, memref::StoreOp storeOp) {
84 MemRefType ldTy = storeOp.getMemRefType();
85 unsigned loadRank = ldTy.getRank();
// Source-memref accessor for nvgpu::LdMatrixOp: hands back the value of
// `ldMatrixOp.getSrcMemref()` through the FailureOr<Value> return type.
96 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97 return ldMatrixOp.getSrcMemref();
// Builds a replacement nvgpu::LdMatrixOp through `rewriter`. The new op keeps
// the result type, transpose setting and tile count of `ldMatrixOp`, and reads
// from `srcMemRef` at `indices`.
// NOTE(review): the remaining parameters and the definition of `loc` are on
// lines missing from this excerpt.
102 static nvgpu::LdMatrixOp rebuildLdMatrixOp(
RewriterBase &rewriter,
103 nvgpu::LdMatrixOp ldMatrixOp,
107 return rewriter.
create<nvgpu::LdMatrixOp>(
108 loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
// Source-memref accessor shared by transfer-like ops (instantiated elsewhere in
// this file for vector::TransferReadOp and vector::TransferWriteOp). It takes
// the op's base operand via `getBase()` and tests whether that value's type is
// a MemRefType.
// NOTE(review): what each branch returns is outside this excerpt; presumably
// the base on success and failure otherwise — confirm against the full file.
119 template <
typename TransferLikeOp>
120 static FailureOr<Value>
121 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122 Value src = transferLikeOp.getBase();
123 if (isa<MemRefType>(src.
getType()))
// Builds a replacement vector::TransferReadOp at the location of
// `transferReadOp`. The new op reads from `srcMemRef` at `indices`; result
// type, permutation map, padding, mask and in-bounds attribute are all taken
// from the original op.
// NOTE(review): the line carrying the function name is missing from this
// excerpt; presumably this is `rebuildTransferReadOp`, which the pattern list
// later in the file references.
130 static vector::TransferReadOp
132 vector::TransferReadOp transferReadOp,
Value srcMemRef,
134 Location loc = transferReadOp.getLoc();
135 return rewriter.
create<vector::TransferReadOp>(
136 loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
// Builds a replacement vector::TransferWriteOp at the location of
// `transferWriteOp`. The new op writes the original value into `srcMemRef` at
// `indices`; permutation map attribute, mask and in-bounds attribute are all
// taken from the original op.
// NOTE(review): the line carrying the function name is missing from this
// excerpt; presumably this is `rebuildTransferWriteOp`, which the pattern list
// later in the file references.
148 static vector::TransferWriteOp
150 vector::TransferWriteOp transferWriteOp,
Value srcMemRef,
152 Location loc = transferWriteOp.getLoc();
153 return rewriter.
create<vector::TransferWriteOp>(
154 loc, transferWriteOp.getValue(), srcMemRef, indices,
155 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156 transferWriteOp.getInBoundsAttr());
// Adapter that turns a fallible accessor (`getFailureOrSrcMemRef`, supplied as
// a template argument) into one that returns a plain Value. The accessor's
// result is dereferenced directly; success is only checked by the assert, so
// this must be used solely with ops for which the accessor cannot fail.
167 template <
typename LoadStoreLikeOp,
168 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171 assert(!failed(failureOrSrcMemRef) &&
"Generic getSrcMemRef cannot be used");
172 return *failureOrSrcMemRef;
// Generic view-size helper, parameterized on the op type and on a function
// that returns the op's source memref.
// NOTE(review): the line with the function name and return type is missing
// from this excerpt; the rewriter template later in the file refers to it as
// `getGenericOpViewSizeForEachDim`.
180 template <
typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
183 LoadStoreLikeOp loadStoreLikeOp) {
// Materialize the strided metadata of the source memref so that its sizes are
// available.
184 Location loc = loadStoreLikeOp.getLoc();
185 auto extractStridedMetadataOp =
186 rewriter.
create<memref::ExtractStridedMetadataOp>(
187 loc, getSrcMemRef(loadStoreLikeOp));
189 extractStridedMetadataOp.getConstifiedMixedSizes();
// Walk source sizes and access indices in lockstep; each pair is fed to the
// expression `s0 - s1` with operands {srcSize, indice}.
// NOTE(review): `s0`, `s1`, `srcSizes`, `indices` and the callee of the call
// below are defined on lines missing from this excerpt; presumably s0/s1 are
// affine symbols, giving "source size minus index" per dimension — confirm.
197 for (
auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
199 rewriter, loc, s0 - s1, {srcSize, indice}));
// Rewrite pattern template for load/store-like ops. Template arguments visible
// here: the op type, a fallible source-memref accessor, a function that
// rebuilds the op from a new address and indices, and a view-size function
// whose default is getGenericOpViewSizeForEachDim instantiated with
// getSrcMemRef.
//
// Visible flow of matchAndRewrite:
//   1. Fetch the source memref; the "source is not a memref" message belongs to
//      the failed-accessor branch.
//   2. Rank-0 sources are rejected ("0-D accesses don't need rewriting").
//   3. A further rejection carries "no computation to extract: offsets are 0s".
//   4. Sizes come from getViewSizeForEachDim and must match the rank (assert).
//   5. A memref::SubViewOp is created on the source memref, index constants of
//      value 0 are created, the op is rebuilt on the subview result with
//      `zeros` as indices, and the original op is replaced by the new results.
// NOTE(review): the struct header, several conditions and the subview operands
// are on lines missing from this excerpt.
222 template <
typename LoadStoreLikeOp,
223 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
224 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
229 getGenericOpViewSizeForEachDim<
231 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
235 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
237 FailureOr<Value> failureOrSrcMemRef =
238 getFailureOrSrcMemRef(loadStoreLikeOp);
239 if (failed(failureOrSrcMemRef))
241 "source is not a memref");
242 Value srcMemRef = *failureOrSrcMemRef;
243 auto ldStTy = cast<MemRefType>(srcMemRef.
getType());
244 unsigned loadStoreRank = ldStTy.getRank();
// Rank-0 accesses are not rewritten.
246 if (loadStoreRank == 0)
248 "0-D accesses don't need rewriting");
256 loadStoreLikeOp,
"no computation to extract: offsets are 0s");
262 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
263 assert(sizes.size() == loadStoreRank &&
264 "Expected one size per load dimension");
265 Location loc = loadStoreLikeOp.getLoc();
270 rewriter.
create<memref::SubViewOp>(loc, srcMemRef,
275 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
// Re-create the access on top of the subview, using `zeros` as the indices,
// then swap it in for the original op.
276 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
277 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
278 rewriter.
replaceOp(loadStoreLikeOp, newLoadStore->getResults());
// List of LoadStoreLikeOpRewriter instantiations, one per supported op kind.
// Fully visible entries: vector::TransferReadOp and vector::TransferWriteOp,
// each paired with getTransferLikeOpSrcMemRef and its rebuild function. The
// first two entries end with the load / store view-size helpers and the third
// uses getLdMatrixOpSrcMemRef.
// NOTE(review): the enclosing call and the op-type / rebuild arguments of the
// first three entries are on lines missing from this excerpt.
287 LoadStoreLikeOpRewriter<
291 getLoadOpViewSizeForEachDim>,
292 LoadStoreLikeOpRewriter<
296 getStoreOpViewSizeForEachDim>,
297 LoadStoreLikeOpRewriter<
299 getLdMatrixOpSrcMemRef,
301 LoadStoreLikeOpRewriter<
302 vector::TransferReadOp,
303 getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
304 rebuildTransferReadOp>,
305 LoadStoreLikeOpRewriter<
306 vector::TransferWriteOp,
307 getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
308 rebuildTransferWriteOp>>(
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...