35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36 return loadOp.getMemRef();
41 static memref::LoadOp rebuildLoadOp(
RewriterBase &rewriter,
42 memref::LoadOp loadOp,
Value srcMemRef,
45 return rewriter.
create<memref::LoadOp>(loc, srcMemRef, indices,
46 loadOp.getNontemporal());
52 getLoadOpViewSizeForEachDim(
RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66 return storeOp.getMemRef();
71 static memref::StoreOp rebuildStoreOp(
RewriterBase &rewriter,
72 memref::StoreOp storeOp,
Value srcMemRef,
75 return rewriter.
create<memref::StoreOp>(loc, storeOp.getValueToStore(),
77 storeOp.getNontemporal());
83 getStoreOpViewSizeForEachDim(
RewriterBase &rewriter, memref::StoreOp storeOp) {
84 MemRefType ldTy = storeOp.getMemRefType();
85 unsigned loadRank = ldTy.getRank();
96 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97 return ldMatrixOp.getSrcMemref();
102 static nvgpu::LdMatrixOp rebuildLdMatrixOp(
RewriterBase &rewriter,
103 nvgpu::LdMatrixOp ldMatrixOp,
107 return rewriter.
create<nvgpu::LdMatrixOp>(
108 loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
119 template <
typename TransferLikeOp>
120 static FailureOr<Value>
121 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122 Value src = transferLikeOp.getSource();
123 if (isa<MemRefType>(src.
getType()))
130 static vector::TransferReadOp
132 vector::TransferReadOp transferReadOp,
Value srcMemRef,
134 Location loc = transferReadOp.getLoc();
135 return rewriter.
create<vector::TransferReadOp>(
136 loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
148 static vector::TransferWriteOp
150 vector::TransferWriteOp transferWriteOp,
Value srcMemRef,
152 Location loc = transferWriteOp.getLoc();
153 return rewriter.
create<vector::TransferWriteOp>(
154 loc, transferWriteOp.getValue(), srcMemRef, indices,
155 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156 transferWriteOp.getInBoundsAttr());
167 template <
typename LoadStoreLikeOp,
168 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171 assert(!failed(failureOrSrcMemRef) &&
"Generic getSrcMemRef cannot be used");
172 return *failureOrSrcMemRef;
180 template <
typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
183 LoadStoreLikeOp loadStoreLikeOp) {
184 Location loc = loadStoreLikeOp.getLoc();
185 auto extractStridedMetadataOp =
186 rewriter.
create<memref::ExtractStridedMetadataOp>(
187 loc, getSrcMemRef(loadStoreLikeOp));
189 extractStridedMetadataOp.getConstifiedMixedSizes();
197 for (
auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
199 rewriter, loc, s0 - s1, {srcSize, indice}));
222 template <
typename LoadStoreLikeOp,
223 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
224 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
229 getGenericOpViewSizeForEachDim<
231 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
235 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
237 FailureOr<Value> failureOrSrcMemRef =
238 getFailureOrSrcMemRef(loadStoreLikeOp);
239 if (failed(failureOrSrcMemRef))
241 "source is not a memref");
242 Value srcMemRef = *failureOrSrcMemRef;
243 auto ldStTy = cast<MemRefType>(srcMemRef.
getType());
244 unsigned loadStoreRank = ldStTy.getRank();
246 if (loadStoreRank == 0)
248 "0-D accesses don't need rewriting");
254 if (std::all_of(indices.begin(), indices.end(),
256 return isConstantIntValue(opFold, 0);
259 loadStoreLikeOp,
"no computation to extract: offsets are 0s");
265 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
266 assert(sizes.size() == loadStoreRank &&
267 "Expected one size per load dimension");
268 Location loc = loadStoreLikeOp.getLoc();
273 rewriter.
create<memref::SubViewOp>(loc, srcMemRef,
278 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
279 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
280 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
281 rewriter.
replaceOp(loadStoreLikeOp, newLoadStore->getResults());
290 LoadStoreLikeOpRewriter<
294 getLoadOpViewSizeForEachDim>,
295 LoadStoreLikeOpRewriter<
299 getStoreOpViewSizeForEachDim>,
300 LoadStoreLikeOpRewriter<
302 getLdMatrixOpSrcMemRef,
304 LoadStoreLikeOpRewriter<
305 vector::TransferReadOp,
306 getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
307 rebuildTransferReadOp>,
308 LoadStoreLikeOpRewriter<
309 vector::TransferWriteOp,
310 getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
311 rebuildTransferWriteOp>>(
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
Include the generated interface declarations.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...