26 #define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
27 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
33 auto sourceType = cast<BaseMemRefType>(source.
getType());
40 sourceType.getMemorySpace());
55 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
59 auto sourceType = cast<MemRefType>(source.
getType());
60 auto sourceRank =
static_cast<unsigned>(sourceType.getRank());
62 memref::ExtractStridedMetadataOp newExtractStridedMetadata;
66 newExtractStridedMetadata =
67 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, source);
70 auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
74 : rewriter.getIndexAttr(dim);
78 getDim(sourceOffset, newExtractStridedMetadata.getOffset());
79 ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
82 origStrides.reserve(sourceRank);
85 strides.reserve(sourceRank);
89 for (
auto i : llvm::seq(0u, sourceRank)) {
90 OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
92 if (!subStrides.empty()) {
94 rewriter, loc, s0 * s1, {subStrides[i], origStride}));
97 origStrides.emplace_back(origStride);
100 auto &&[expr, values] =
104 return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
110 auto &&[base, offset, ignore] =
113 return rewriter.
create<memref::ReinterpretCastOp>(loc, retType, base, offset,
119 auto type = cast<MemRefType>(val.
getType());
120 return type.getRank() != 0;
124 auto type = cast<MemRefType>(val.
getType());
125 return type.getLayout().isIdentity() ||
126 isa<StridedLayoutAttr>(type.getLayout());
133 LogicalResult matchAndRewrite(memref::LoadOp op,
138 Value memref = op.getMemref();
155 LogicalResult matchAndRewrite(memref::StoreOp op,
160 Value memref = op.getMemref();
169 Value value = op.getValue();
178 LogicalResult matchAndRewrite(memref::SubViewOp op,
183 Value memref = op.getSource();
194 auto &&[base, finalOffset, strides] =
197 auto srcType = cast<MemRefType>(memref.
getType());
198 auto resultType = cast<MemRefType>(op.getType());
199 unsigned subRank =
static_cast<unsigned>(resultType.getRank());
201 llvm::SmallBitVector droppedDims = op.getDroppedDims();
204 finalSizes.reserve(subRank);
207 finalStrides.reserve(subRank);
209 for (
auto i : llvm::seq(0u,
static_cast<unsigned>(srcType.getRank()))) {
210 if (droppedDims.test(i))
213 finalSizes.push_back(subSizes[i]);
214 finalStrides.push_back(strides[i]);
218 op, resultType, base, finalOffset, finalSizes, finalStrides);
223 struct GpuDecomposeMemrefsPass
224 :
public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
226 void runOnOperation()
override {
232 return signalPassFailure();
239 patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
static bool isInsideLaunch(Operation *op)
static bool needFlatten(Value val)
static MemRefType inferCastResultType(Value source, OpFoldResult offset)
static bool checkLayout(Value val)
static std::tuple< Value, OpFoldResult, SmallVector< OpFoldResult > > getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, ArrayRef< OpFoldResult > subOffsets, ArrayRef< OpFoldResult > subStrides={})
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, ValueRange offsets)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
AffineExpr getAffineSymbolExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Include the generated interface declarations.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns)
Collect a set of patterns to decompose memrefs ops.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...