36 return loadOp.getMemRef();
41 static memref::LoadOp rebuildLoadOp(
RewriterBase &rewriter,
42 memref::LoadOp loadOp,
Value srcMemRef,
45 return rewriter.
create<memref::LoadOp>(loc, srcMemRef, indices,
46 loadOp.getNontemporal());
52 getLoadOpViewSizeForEachDim(
RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
66 return storeOp.getMemRef();
71 static memref::StoreOp rebuildStoreOp(
RewriterBase &rewriter,
72 memref::StoreOp storeOp,
Value srcMemRef,
75 return rewriter.
create<memref::StoreOp>(loc, storeOp.getValueToStore(),
77 storeOp.getNontemporal());
83 getStoreOpViewSizeForEachDim(
RewriterBase &rewriter, memref::StoreOp storeOp) {
84 MemRefType ldTy = storeOp.getMemRefType();
85 unsigned loadRank = ldTy.getRank();
96 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97 return ldMatrixOp.getSrcMemref();
102 static nvgpu::LdMatrixOp rebuildLdMatrixOp(
RewriterBase &rewriter,
103 nvgpu::LdMatrixOp ldMatrixOp,
107 return rewriter.
create<nvgpu::LdMatrixOp>(
108 loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
119 template <
typename TransferLikeOp>
121 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122 Value src = transferLikeOp.getSource();
123 if (isa<MemRefType>(src.
getType()))
130 static vector::TransferReadOp
132 vector::TransferReadOp transferReadOp,
Value srcMemRef,
134 Location loc = transferReadOp.getLoc();
135 return rewriter.
create<vector::TransferReadOp>(
136 loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
148 static vector::TransferWriteOp
150 vector::TransferWriteOp transferWriteOp,
Value srcMemRef,
152 Location loc = transferWriteOp.getLoc();
153 return rewriter.
create<vector::TransferWriteOp>(
154 loc, transferWriteOp.getValue(), srcMemRef, indices,
155 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156 transferWriteOp.getInBoundsAttr());
167 template <
typename LoadStoreLikeOp,
169 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171 assert(!
failed(failureOrSrcMemRef) &&
"Generic getSrcMemRef cannot be used");
172 return *failureOrSrcMemRef;
180 template <
typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
183 LoadStoreLikeOp loadStoreLikeOp) {
184 Location loc = loadStoreLikeOp.getLoc();
185 auto extractStridedMetadataOp =
186 rewriter.
create<memref::ExtractStridedMetadataOp>(
187 loc, getSrcMemRef(loadStoreLikeOp));
189 extractStridedMetadataOp.getConstifiedMixedSizes();
197 for (
auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
199 rewriter, loc, s0 - s1, {srcSize, indice}));
222 template <
typename LoadStoreLikeOp,
224 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
229 getGenericOpViewSizeForEachDim<
231 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
235 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
238 getFailureOrSrcMemRef(loadStoreLikeOp);
239 if (
failed(failureOrSrcMemRef))
241 "source is not a memref");
242 Value srcMemRef = *failureOrSrcMemRef;
243 auto ldStTy = cast<MemRefType>(srcMemRef.
getType());
244 unsigned loadStoreRank = ldStTy.getRank();
246 if (loadStoreRank == 0)
248 "0-D accesses don't need rewriting");
254 if (std::all_of(indices.begin(), indices.end(),
256 return isConstantIntValue(opFold, 0);
259 loadStoreLikeOp,
"no computation to extract: offsets are 0s");
265 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
266 assert(sizes.size() == loadStoreRank &&
267 "Expected one size per load dimension");
268 Location loc = loadStoreLikeOp.getLoc();
273 rewriter.
create<memref::SubViewOp>(loc, srcMemRef,
278 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
279 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
280 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
281 rewriter.
replaceOp(loadStoreLikeOp, newLoadStore->getResults());
290 LoadStoreLikeOpRewriter<
294 getLoadOpViewSizeForEachDim>,
295 LoadStoreLikeOpRewriter<
299 getStoreOpViewSizeForEachDim>,
300 LoadStoreLikeOpRewriter<
302 getLdMatrixOpSrcMemRef,
304 LoadStoreLikeOpRewriter<
305 vector::TransferReadOp,
306 getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
307 rebuildTransferReadOp>,
308 LoadStoreLikeOpRewriter<
309 vector::TransferWriteOp,
310 getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
311 rebuildTransferWriteOp>>(
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands, then immediately attempts to fold it.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such that these memory accesses use only a base pointer.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.