24#include "llvm/Support/MathExtras.h"
27#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
28#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
37 vector::MaskedLoadOp maskedOp) {
38 auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
42 Attribute addrSpace = memRefType.getMemorySpace();
43 if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
46 if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
47 amdgpu::AddressSpace::FatRawBuffer)
54 vector::MaskedLoadOp maskedOp,
56 VectorType vectorType = maskedOp.getVectorType();
58 builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
60 load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
61 load, maskedOp.getPassThru());
70 if (isa<VectorType>(broadcastOp.getSourceType()))
72 return broadcastOp.getSource();
76 "amdgpu.buffer_maskedload_needs_mask";
83 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
103 Value src = maskedOp.getBase();
105 VectorType vectorType = maskedOp.getVectorType();
106 int64_t vectorSize = vectorType.getNumElements();
107 int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
110 auto stridedMetadata =
111 memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
113 stridedMetadata.getConstifiedMixedStrides();
115 OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
118 std::tie(linearizedInfo, linearizedIndices) =
120 elementBitWidth, offset, sizes,
124 Value vectorSizeOffset =
130 Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
133 Value isOutofBounds = arith::CmpIOp::create(
134 rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
138 rewriter, loc, llvm::divideCeil(32, elementBitWidth));
139 Value isNotWordAligned = arith::CmpIOp::create(
140 rewriter, loc, arith::CmpIPredicate::ne,
141 arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
149 arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
155 scf::YieldOp::create(builder, loc, readResult);
161 scf::YieldOp::create(rewriter, loc, res);
165 scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
173struct FullMaskedLoadToConditionalLoad
177 LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
179 FailureOr<Value> maybeCond =
matchFullMask(rewriter, loadOp.getMask());
180 if (failed(maybeCond)) {
184 Value cond = maybeCond.value();
188 scf::YieldOp::create(rewriter, loc, res);
191 scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
193 auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
200struct FullMaskedStoreToConditionalStore
204 LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
207 if (failed(maybeCond)) {
210 Value cond = maybeCond.value();
213 vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
214 storeOp.getBase(), storeOp.getIndices());
215 scf::YieldOp::create(rewriter, loc);
218 scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
228 patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
229 FullMaskedStoreToConditionalStore>(
patterns.getContext(),
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::MaskedLoadOp maskedOp, bool passthru)
static constexpr char kMaskedloadNeedsMask[]
static FailureOr< Value > matchFullMask(OpBuilder &b, Value val)
Check if the given value comes from a broadcasted i1 condition.
static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, vector::MaskedLoadOp maskedOp)
This pattern supports lowering of: vector.maskedload to vector.load and arith.select if the memref is...
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents a single result from folding an operation.
OpT getOperation()
Return the current operation being transformed.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext & getContext()
Return the MLIR context for the current operation being transformed.
void signalPassFailure()
Signal that some invariant was broken when running.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
OpFoldResult linearizedSize