24#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
25#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
31 auto sourceType = cast<BaseMemRefType>(source.
getType());
36 StridedLayoutAttr::get(source.
getContext(), staticOffsets.front(), {});
37 return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
38 sourceType.getMemorySpace());
53static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
57 auto sourceType = cast<MemRefType>(source.
getType());
58 auto sourceRank =
static_cast<unsigned>(sourceType.getRank());
60 memref::ExtractStridedMetadataOp newExtractStridedMetadata;
64 newExtractStridedMetadata =
65 memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
68 auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
72 : rewriter.getIndexAttr(dim);
76 getDim(sourceOffset, newExtractStridedMetadata.getOffset());
77 ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
80 origStrides.reserve(sourceRank);
83 strides.reserve(sourceRank);
87 for (
auto i : llvm::seq(0u, sourceRank)) {
88 OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
90 if (!subStrides.empty()) {
92 rewriter, loc, s0 * s1, {subStrides[i], origStride}));
95 origStrides.emplace_back(origStride);
98 auto &&[expr, values] =
102 return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
108 auto &&[base, offset, ignore] =
111 return memref::ReinterpretCastOp::create(rewriter, loc, retType, base, offset,
118 return type.getRank() != 0;
123 return type.getLayout().isIdentity() ||
124 isa<StridedLayoutAttr>(type.getLayout());
131 LogicalResult matchAndRewrite(memref::LoadOp op,
153 LogicalResult matchAndRewrite(memref::StoreOp op,
167 Value value = op.getValue();
176 LogicalResult matchAndRewrite(memref::SubViewOp op,
177 PatternRewriter &rewriter)
const override {
181 Value memref = op.getSource();
188 Location loc = op.getLoc();
189 SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
190 SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
191 SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
192 auto &&[base, finalOffset, strides] =
195 auto srcType = cast<MemRefType>(memref.
getType());
196 auto resultType = cast<MemRefType>(op.getType());
197 unsigned subRank =
static_cast<unsigned>(resultType.getRank());
199 llvm::SmallBitVector droppedDims = op.getDroppedDims();
201 SmallVector<OpFoldResult> finalSizes;
202 finalSizes.reserve(subRank);
204 SmallVector<OpFoldResult> finalStrides;
205 finalStrides.reserve(subRank);
207 for (
auto i : llvm::seq(0u,
static_cast<unsigned>(srcType.getRank()))) {
208 if (droppedDims.test(i))
211 finalSizes.push_back(subSizes[i]);
212 finalStrides.push_back(strides[i]);
216 op, resultType, base, finalOffset, finalSizes, finalStrides);
221struct GpuDecomposeMemrefsPass
224 void runOnOperation()
override {
230 return signalPassFailure();
237 patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
static bool isInsideLaunch(Operation *op)
static bool needFlatten(Value val)
static MemRefType inferCastResultType(Value source, OpFoldResult offset)
static bool checkLayout(Value val)
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, ValueRange offsets)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static std::tuple< Value, OpFoldResult, SmallVector< OpFoldResult > > getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, ArrayRef< OpFoldResult > subOffsets, ArrayRef< OpFoldResult > subStrides={})
Base type for affine expression.
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Include the generated interface declarations.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns)
Collect a set of patterns to decompose memrefs ops.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...