#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;
/// Checks that a masked vector.transfer_read from a fat raw buffer memref can
/// be lowered to an unmasked vector.load. Diagnostic strings not present in
/// this excerpt are elided as "...".
static LogicalResult
transferPreconditions(PatternRewriter &rewriter,
                      VectorTransferOpInterface xferOp,
                      bool &requiresBroadcasting,
                      VectorType &unbroadcastedVectorType) {
  // Only masked transfers are rewritten.
  if (!xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "...");
  // The permutation map may broadcast but must otherwise be a minor identity.
  SmallVector<unsigned> broadcastedDims;
  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
          &broadcastedDims))
    return rewriter.notifyMatchFailure(xferOp, "...");
  // The source must be a memref in the amdgpu fat raw buffer address space.
  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(xferOp, "...");
  Attribute addrSpace = memRefType.getMemorySpace();
  if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
    return rewriter.notifyMatchFailure(xferOp, "...");
  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
      amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(xferOp, "...");
  // The most minor dimension must have unit stride; sub-byte element types
  // are not supported.
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "...");
  if (memRefType.getElementTypeBitWidth() < 8)
    return rewriter.notifyMatchFailure(xferOp, "...");

  // Broadcasted dimensions are loaded with extent 1 and expanded afterwards.
  SmallVector<int64_t> unbroadcastedVectorShape(
      xferOp.getVectorType().getShape());
  for (unsigned i : broadcastedDims)
    unbroadcastedVectorShape[i] = 1;
  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
  requiresBroadcasting = !broadcastedDims.empty();

  // If the memref element type is itself a vector it must match the loaded
  // vector type; otherwise the scalar element types must match.
  auto memrefElTy = memRefType.getElementType();
  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
    return rewriter.notifyMatchFailure(xferOp, "...");
  if (!isa<VectorType>(memrefElTy) &&
      memrefElTy != xferOp.getVectorType().getElementType())
    return rewriter.notifyMatchFailure(xferOp, "...");

  // Out-of-bounds dimensions and ranks other than 1 are left to other
  // lowerings.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "...");
  if (xferOp.getVectorType().getRank() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "vector type is not rank 1, can't create masked load, needs "
                "...");

  return success();
}
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
                                           vector::TransferReadOp readOp,
                                           bool requiresBroadcasting,
                                           VectorType unbroadcastedVectorType) {
  // Splat the padding value, load the unbroadcasted vector, and select
  // between the two per lane using the transfer mask.
  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                               readOp.getPadding());
  Value load = builder.create<vector::LoadOp>(
      loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                              readOp.getMask(), load, fill);
  // Expand to the original vector type if the permutation map broadcasts.
  if (requiresBroadcasting) {
    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
  }
  return res;
}
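// Illustration only (not part of the lowering): a scalar model of what the
// splat/load/select sequence above computes -- each masked-off lane receives
// the padding value instead of the loaded value. The helper name is
// hypothetical and exists purely for this sketch.
template <typename T>
void maskedLoadModel(const bool *mask, const T *loaded, T padding, T *result,
                     int64_t numElements) {
  for (int64_t i = 0; i < numElements; ++i)
    result[i] = mask[i] ? loaded[i] : padding; // per-lane arith.select
}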
132 "amdgpu.buffer_transfer_read_needs_mask";
/// Lowers a masked vector.transfer_read from a fat raw buffer into a plain
/// vector.load guarded by an scf.if: the unmasked fast path is used whenever
/// it is safe, and the original masked transfer is kept as the fallback.
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // Ops tagged as needing a mask are the fallback clones created below;
    // skip them so the pattern does not fire on its own output.
    if (readOp->hasAttr(kTransferReadNeedsMask))
      return failure();

    bool requiresBroadcasting = false;
    VectorType unbroadcastedVectorType;
    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                     unbroadcastedVectorType))) {
      return failure();
    }

    Location loc = readOp.getLoc();
    Value src = readOp.getSource();

    VectorType vectorType = readOp.getVectorType();
    int64_t vectorSize = vectorType.getNumElements();
    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
    SmallVector<OpFoldResult> indices = getAsOpFoldResult(readOp.getIndices());

    // Linearize the access using the memref's strided metadata.
    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
    SmallVector<OpFoldResult> strides =
        stridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
                                                 elementBitWidth, offset, sizes,
                                                 strides, indices);

    // The buffer size is the maximum over all dimensions of size * stride,
    // built as an affine.max over per-dimension products.
    SmallVector<AffineExpr> productExpressions;
    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();

    SmallVector<AffineExpr> symbols(2 * sourceRank);
    SmallVector<Value> offsetValues;
    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});

    size_t symbolIndex = 0;
    for (size_t i = 0; i < sourceRank; ++i) {
      AffineExpr strideExpr, sizeExpr;
      OpFoldResult stride = strides[i];
      OpFoldResult size = sizes[i];
      if (auto constantStride = getConstantIntValue(stride)) {
        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
      } else {
        strideExpr = symbols[symbolIndex++];
        offsetValues.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, stride));
      }
      if (auto constantSize = getConstantIntValue(size)) {
        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
      } else {
        sizeExpr = symbols[symbolIndex++];
        offsetValues.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, size));
      }
      productExpressions.push_back(strideExpr * sizeExpr);
    }

    AffineMap maxMap = AffineMap::get(
        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
        rewriter.getContext());
    Value totalSize =
        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);

    // delta = bufferSize - linearizedOffset, i.e. the elements remaining
    // between the first accessed element and the end of the buffer.
    Value vectorSizeOffset =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
    Value linearIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);

    // 1) The read is out of bounds if fewer than vectorSize elements remain.
    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

    // 2) The remainder is misaligned if it is not a whole number of 32-bit
    //    words.
    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::divideCeil(32, elementBitWidth));
    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne,
        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));

    // Fall back to the masked transfer_read only when the read is both
    // out of bounds and not word aligned.
    Value ifCondition =
        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);

    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
      // Keep the original masked transfer, tagged so it is not re-matched.
      Operation *read = builder.clone(*readOp.getOperation());
      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
      Value readResult = read->getResult(0);
      builder.create<scf::YieldOp>(loc, readResult);
    };

    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
      Value res = createVectorLoadForMaskedLoad(
          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
      rewriter.create<scf::YieldOp>(loc, res);
    };

    auto ifOp =
        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);

    rewriter.replaceOp(readOp, ifOp);
    return success();
  }
};
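// Illustration only: the guard computed in matchAndRewrite, modelled on plain
// integers. The masked fallback is taken only when the read runs past the end
// of the buffer *and* the remaining element count is not a whole number of
// 32-bit words; otherwise the unmasked vector.load is used. The helper below
// is hypothetical and exists purely for this sketch.
constexpr bool takesMaskedFallback(int64_t delta, int64_t vectorSize,
                                   int64_t elementBitWidth) {
  const int64_t elementsPerWord = (32 + elementBitWidth - 1) / elementBitWidth;
  const bool isOutOfBounds = delta < vectorSize;
  const bool isNotWordAligned = (delta % elementsPerWord) != 0;
  return isOutOfBounds && isNotWordAligned;
}
// E.g. reading vector<4xf16> with 3 elements left: 3 < 4 and 3 % 2 == 1, so
// the cloned masked transfer_read runs; with 2 elements left the remainder is
// word aligned and the fast path is used.
static_assert(takesMaskedFallback(/*delta=*/3, /*vectorSize=*/4,
                                  /*elementBitWidth=*/16),
              "out of bounds and misaligned -> masked fallback");
static_assert(!takesMaskedFallback(/*delta=*/2, /*vectorSize=*/4,
                                   /*elementBitWidth=*/16),
              "word-aligned remainder -> unmasked fast path");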
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}

struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};
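// A minimal usage sketch (not part of this file): the patterns can also be
// applied directly from another pass or driver, assuming only the entry
// points shown above (populateAmdgpuTransferReadToLoadPatterns and
// applyPatternsGreedily). The function name is hypothetical.
static LogicalResult lowerBufferTransferReads(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  amdgpu::populateAmdgpuTransferReadToLoadPatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}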