24 #include "llvm/Support/MathExtras.h"
27 #define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
28 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
37 vector::MaskedLoadOp maskedOp) {
38 auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
42 Attribute addrSpace = memRefType.getMemorySpace();
43 if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
46 if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
47 amdgpu::AddressSpace::FatRawBuffer)
54 vector::MaskedLoadOp maskedOp,
56 VectorType vectorType = maskedOp.getVectorType();
57 Value load = vector::LoadOp::create(
58 builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
60 load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
61 load, maskedOp.getPassThru());
70 if (isa<VectorType>(broadcastOp.getSourceType()))
72 return broadcastOp.getSource();
76 "amdgpu.buffer_maskedload_needs_mask";
83 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
103 Value src = maskedOp.getBase();
105 VectorType vectorType = maskedOp.getVectorType();
106 int64_t vectorSize = vectorType.getNumElements();
107 int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
110 auto stridedMetadata =
111 memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
113 stridedMetadata.getConstifiedMixedStrides();
115 OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
118 std::tie(linearizedInfo, linearizedIndices) =
120 elementBitWidth, offset, sizes,
124 Value vectorSizeOffset =
130 Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
133 Value isOutofBounds = arith::CmpIOp::create(
134 rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
139 Value isNotWordAligned = arith::CmpIOp::create(
140 rewriter, loc, arith::CmpIPredicate::ne,
141 arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
149 arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
155 scf::YieldOp::create(builder, loc, readResult);
161 scf::YieldOp::create(rewriter, loc, res);
165 scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
173 struct FullMaskedLoadToConditionalLoad
177 LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
179 FailureOr<Value> maybeCond =
matchFullMask(rewriter, loadOp.getMask());
184 Value cond = maybeCond.value();
188 scf::YieldOp::create(rewriter, loc, res);
191 scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
193 auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
200 struct FullMaskedStoreToConditionalStore
204 LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
206 FailureOr<Value> maybeCond =
matchFullMask(rewriter, storeOp.getMask());
210 Value cond = maybeCond.value();
213 vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
214 storeOp.getBase(), storeOp.getIndices());
215 scf::YieldOp::create(rewriter, loc);
218 scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
228 patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
229 FullMaskedStoreToConditionalStore>(
patterns.getContext(),
234 : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
239 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::MaskedLoadOp maskedOp, bool passthru)
static constexpr char kMaskedloadNeedsMask[]
static FailureOr< Value > matchFullMask(OpBuilder &b, Value val)
Check if the given value comes from a broadcasted i1 condition.
static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, vector::MaskedLoadOp maskedOp)
This pattern supports lowering of: vector.maskedload to vector.load and arith.select if the memref is...
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
void runOnOperation() override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
OpFoldResult linearizedSize