26 #define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
27 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
33 auto sourceType = cast<BaseMemRefType>(source.
getType());
40 sourceType.getMemorySpace());
55 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
59 auto sourceType = cast<MemRefType>(source.
getType());
60 auto sourceRank =
static_cast<unsigned>(sourceType.getRank());
62 memref::ExtractStridedMetadataOp newExtractStridedMetadata;
66 newExtractStridedMetadata =
67 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, source);
78 getDim(sourceOffset, newExtractStridedMetadata.getOffset());
79 ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
82 origStrides.reserve(sourceRank);
85 strides.reserve(sourceRank);
89 for (
auto i : llvm::seq(0u, sourceRank)) {
90 OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
92 if (!subStrides.empty()) {
94 rewriter, loc, s0 * s1, {subStrides[i], origStride}));
97 origStrides.emplace_back(origStride);
100 auto &&[expr, values] =
104 return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
110 auto &&[base, offset, ignore] =
113 return rewriter.
create<memref::ReinterpretCastOp>(loc, retType, base, offset,
114 std::nullopt, std::nullopt);
118 auto type = cast<MemRefType>(val.
getType());
119 return type.getRank() != 0;
123 auto type = cast<MemRefType>(val.
getType());
124 return type.getLayout().isIdentity() ||
125 isa<StridedLayoutAttr>(type.getLayout());
132 LogicalResult matchAndRewrite(memref::LoadOp op,
137 Value memref = op.getMemref();
154 LogicalResult matchAndRewrite(memref::StoreOp op,
159 Value memref = op.getMemref();
168 Value value = op.getValue();
177 LogicalResult matchAndRewrite(memref::SubViewOp op,
182 Value memref = op.getSource();
193 auto &&[base, finalOffset, strides] =
196 auto srcType = cast<MemRefType>(memref.
getType());
197 auto resultType = cast<MemRefType>(op.getType());
198 unsigned subRank =
static_cast<unsigned>(resultType.getRank());
200 llvm::SmallBitVector droppedDims = op.getDroppedDims();
203 finalSizes.reserve(subRank);
206 finalStrides.reserve(subRank);
208 for (
auto i : llvm::seq(0u,
static_cast<unsigned>(srcType.getRank()))) {
209 if (droppedDims.test(i))
212 finalSizes.push_back(subSizes[i]);
213 finalStrides.push_back(strides[i]);
217 op, resultType, base, finalOffset, finalSizes, finalStrides);
222 struct GpuDecomposeMemrefsPass
223 :
public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
225 void runOnOperation()
override {
231 return signalPassFailure();
238 patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
243 return std::make_unique<GpuDecomposeMemrefsPass>();
static bool isInsideLaunch(Operation *op)
static bool needFlatten(Value val)
static MemRefType inferCastResultType(Value source, OpFoldResult offset)
static bool checkLayout(Value val)
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, ValueRange offsets)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static std::tuple< Value, OpFoldResult, SmallVector< OpFoldResult > > getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, ArrayRef< OpFoldResult > subOffsets, ArrayRef< OpFoldResult > subStrides=std::nullopt)
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Include the generated interface declarations.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
std::unique_ptr< Pass > createGpuDecomposeMemrefsPass()
Pass that decomposes memref ops inside a gpu.launch body.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns)
Collect a set of patterns to decompose memref ops.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...