#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
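// The GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS guard makes the generated
// Passes.h.inc emit the pass base class (impl::GpuDecomposeMemrefsPassBase)
// that the GpuDecomposeMemrefsPass struct at the bottom of this file derives
// from.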
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  // ... (build a strided layout attribute `stridedLayout` carrying `offset`)
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}
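// The inferred cast type is a rank-0 memref of the source element type whose
// strided layout carries the computed offset; getFlatMemref below uses it as
// the result type of the memref.reinterpret_cast it creates.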
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = {}) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  // ... (insertion point moved next to where `source` is defined)
  newExtractStridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
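  // memref.extract_strided_metadata exposes the base buffer plus the offset,
  // sizes, and strides of `source`; the flattened offset and strides are
  // computed from these values.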
  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };
  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);
  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);
  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }
    origStrides.emplace_back(origStride);
  }
  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}
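// Worked sketch (illustrative, following computeLinearIndex's strided-layout
// assumption): for a source with offset O and strides [S0, S1], flattening an
// access at (i, j) produces
//   finalOffset = O + i * S0 + j * S1
// and, when sub-strides [t0, t1] are supplied (the subview case), the composed
// strides become [t0 * S0, t1 * S1].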
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, getAsOpFoldResult(offsets));
  MemRefType retType = inferCastResultType(base, offset);
  return memref::ReinterpretCastOp::create(rewriter, loc, retType, base, offset,
                                           /*sizes=*/{}, /*strides=*/{});
}
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}
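// Flattening only applies to non-0-d memrefs whose layout is either the
// identity or an explicit strided layout. The patterns below (FlattenLoad,
// FlattenStore, FlattenSubview) check isInsideLaunch, needFlatten, and
// checkLayout before rewriting, so only ops inside a gpu.launch are touched.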
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    Value memref = op.getMemref();
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    Value memref = op.getMemref();
    // ...
    Value value = op.getValue();
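// Illustrative rewrite (assumed IR, not copied from a test): inside a
// gpu.launch region,
//   %v = memref.load %m[%i, %j] : memref<?x?xf32>
// becomes, roughly,
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//   %flatOff = <affine apply of %off + %i * %strides#0 + %j * %strides#1>
//   %flat = memref.reinterpret_cast %base to offset: [%flatOff], sizes: [],
//       strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//   %v = memref.load %flat[] : memref<f32, strided<[], offset: ?>>
// memref.store is handled the same way, writing to the flattened 0-d view.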
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    Value memref = op.getSource();
    // ... (gpu.launch / needFlatten / checkLayout checks)
    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());
    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);
    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);
    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};
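// For subviews the pattern keeps the result rank: it materializes a
// memref.reinterpret_cast with the composed offset and per-dimension sizes
// and strides, skipping dimensions dropped by rank reduction.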
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateGpuDecomposeMemrefsPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}
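// Usage sketch (the pass flag is assumed from the pass definition; verify
// against your build of mlir-opt):
//   mlir-opt --gpu-decompose-memrefs input.mlir
// or, from C++, via the public entry point:
//   RewritePatternSet patterns(ctx);
//   populateGpuDecomposeMemrefsPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     /* handle failure */;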