16#include "llvm/ADT/TypeSwitch.h"
19#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
20#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
34 Value &memrefBase, StringRef role) {
40 .Case([&](memref::SubViewOp subviewOp) {
42 rewriter, loc, subviewOp.getMixedOffsets(),
43 subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
indices,
45 memrefBase = subviewOp.getSource();
48 .Case([&](memref::ExpandShapeOp expandShapeOp) {
50 loc, rewriter, expandShapeOp,
indices, resolvedIndices,
false);
51 memrefBase = expandShapeOp.getViewSource();
54 .Case([&](memref::CollapseShapeOp collapseShapeOp) {
56 loc, rewriter, collapseShapeOp,
indices, resolvedIndices);
57 memrefBase = collapseShapeOp.getViewSource();
62 op, (role +
" producer is not one of SubViewOp, ExpandShapeOp, or "
75 Value memrefSource, memrefDest;
79 sourceIndices, memrefSource,
"source");
81 if (failed(foldSrcResult)) {
82 memrefSource = op.getSrc();
83 sourceIndices = op.getSrcIndices();
88 destIndices, memrefDest,
"destination");
90 if (failed(foldDstResult)) {
91 memrefDest = op.getDst();
92 destIndices = op.getDstIndices();
95 if (failed(foldSrcResult) && failed(foldDstResult))
99 op, memrefSource, sourceIndices, memrefDest, destIndices,
100 op.getTransferType(), op.getAsync());
106template <
typename OpTy>
114 Value globalBase, ldsBase;
116 LogicalResult didFoldGlobal =
118 globalIndices, globalBase,
"global");
119 if (failed(didFoldGlobal)) {
120 globalBase = op.getGlobal();
121 globalIndices = op.getGlobalIndices();
124 LogicalResult didFoldLds =
126 ldsIndices, ldsBase,
"lds");
127 if (failed(didFoldLds)) {
128 ldsBase = op.getLds();
129 ldsIndices = op.getLdsIndices();
132 if (failed(didFoldGlobal) && failed(didFoldLds))
136 globalIndices, ldsBase, ldsIndices);
150 op.getSrcIndices(), sourceIndices, memrefSource,
155 memrefSource, sourceIndices);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class implements the operand iterators for the Operation class.
OpT getOperation()
Return the current operation being transformed.
Operation is the basic unit of execution within MLIR.
MLIRContext & getContext()
Return the MLIR context for the current operation being transformed.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
AmdgpuFoldMemRefOpsPassBase Base
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl< Value > &resolvedIndices, Value &memrefBase, StringRef role)
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(TransposeLoadOp op, PatternRewriter &rewriter) const override