21 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
22 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
42 bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
43 if (!xferOp.getMask())
50 if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
54 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
58 Attribute addrSpace = memRefType.getMemorySpace();
59 if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
62 if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
63 amdgpu::AddressSpace::FatRawBuffer)
67 if (!memRefType.isLastDimUnitStride())
74 for (
unsigned i : broadcastedDims)
75 unbroadcastedVectorShape[i] = 1;
76 unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
77 unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
78 requiresBroadcasting = !broadcastedDims.empty();
82 auto memrefElTy = memRefType.getElementType();
83 if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
87 if (!isa<VectorType>(memrefElTy) &&
88 memrefElTy != xferOp.getVectorType().getElementType())
92 if (xferOp.hasOutOfBoundsDim())
95 if (xferOp.getVectorType().getRank() != 1)
98 xferOp,
"vector type is not rank 1, can't create masked load, needs "
106 struct TransferReadLowering final :
OpRewritePattern<vector::TransferReadOp> {
109 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
112 bool requiresBroadcasting =
false;
113 VectorType unbroadcastedVectorType;
115 unbroadcastedVectorType))) {
120 Value fill = rewriter.
create<vector::SplatOp>(loc, unbroadcastedVectorType,
121 readOp.getPadding());
123 loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
124 Value res = rewriter.
create<arith::SelectOp>(loc, unbroadcastedVectorType,
125 readOp.getMask(), load, fill);
128 if (requiresBroadcasting) {
129 res = rewriter.
create<vector::BroadcastOp>(loc, readOp.getVectorType(),
147 : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
148 AmdgpuTransferReadToLoadPass> {
static MLIRContext * getContext(OpFoldResult val)
static std::optional< VectorShape > vectorShape(Type type)
static LogicalResult transferPreconditions(PatternRewriter &rewriter, VectorTransferOpInterface xferOp, bool &requiresBroadcasting, VectorType &unbroadcastedVectorType)
This pattern supports lowering of: vector.transfer_read to a combination of vector.load, arith.select and vector.broadcast.
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
void runOnOperation() override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.