26 #define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
27 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
44 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
48 auto sourceType = cast<MemRefType>(source.
getType());
49 auto sourceRank =
static_cast<unsigned>(sourceType.getRank());
51 memref::ExtractStridedMetadataOp newExtractStridedMetadata;
55 newExtractStridedMetadata =
56 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, source);
67 getDim(sourceOffset, newExtractStridedMetadata.getOffset());
68 ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
71 origStrides.reserve(sourceRank);
74 strides.reserve(sourceRank);
78 for (
auto i : llvm::seq(0u, sourceRank)) {
79 OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
81 if (!subStrides.empty()) {
83 rewriter, loc, s0 * s1, {subStrides[i], origStride}));
86 origStrides.emplace_back(origStride);
89 auto &&[expr, values] =
93 return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
99 auto &&[base, offset, ignore] =
101 auto retType = cast<MemRefType>(base.getType());
102 return rewriter.
create<memref::ReinterpretCastOp>(loc, retType, base, offset,
103 std::nullopt, std::nullopt);
107 auto type = cast<MemRefType>(val.
getType());
108 return type.getRank() != 0;
112 auto type = cast<MemRefType>(val.
getType());
113 return type.getLayout().isIdentity() ||
114 isa<StridedLayoutAttr>(type.getLayout());
121 LogicalResult matchAndRewrite(memref::LoadOp op,
126 Value memref = op.getMemref();
143 LogicalResult matchAndRewrite(memref::StoreOp op,
148 Value memref = op.getMemref();
157 Value value = op.getValue();
166 LogicalResult matchAndRewrite(memref::SubViewOp op,
171 Value memref = op.getSource();
182 auto &&[base, finalOffset, strides] =
185 auto srcType = cast<MemRefType>(memref.
getType());
186 auto resultType = cast<MemRefType>(op.getType());
187 unsigned subRank =
static_cast<unsigned>(resultType.getRank());
189 llvm::SmallBitVector droppedDims = op.getDroppedDims();
192 finalSizes.reserve(subRank);
195 finalStrides.reserve(subRank);
197 for (
auto i : llvm::seq(0u,
static_cast<unsigned>(srcType.getRank()))) {
198 if (droppedDims.test(i))
201 finalSizes.push_back(subSizes[i]);
202 finalStrides.push_back(strides[i]);
206 op, resultType, base, finalOffset, finalSizes, finalStrides);
211 struct GpuDecomposeMemrefsPass
212 :
public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
214 void runOnOperation()
override {
221 return signalPassFailure();
228 patterns.
insert<FlattenLoad, FlattenStore, FlattenSubview>(
233 return std::make_unique<GpuDecomposeMemrefsPass>();
static bool isInsideLaunch(Operation *op)
static bool needFlatten(Value val)
static bool checkLayout(Value val)
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, ValueRange offsets)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static std::tuple< Value, OpFoldResult, SmallVector< OpFoldResult > > getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, ArrayRef< OpFoldResult > subOffsets, ArrayRef< OpFoldResult > subStrides=std::nullopt)
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying those operands.
Include the generated interface declarations.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
std::unique_ptr< Pass > createGpuDecomposeMemrefsPass()
Pass decomposes memref ops inside gpu.launch body.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns)
Collect a set of patterns to decompose memrefs ops.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.