24#include "llvm/Support/MathExtras.h"
27#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
28#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
37 auto memRefType = dyn_cast<MemRefType>(type);
41 Attribute addrSpace = memRefType.getMemorySpace();
42 if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
45 if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
46 amdgpu::AddressSpace::FatRawBuffer)
53 vector::MaskedLoadOp maskedOp,
55 VectorType vectorType = maskedOp.getVectorType();
57 builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
59 load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
60 load, maskedOp.getPassThru());
69 if (isa<VectorType>(broadcastOp.getSourceType()))
71 return broadcastOp.getSource();
75 "amdgpu.buffer_maskedload_needs_mask";
82 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
89 maskedOp,
"isn't a load from a fat buffer resource");
103 Value src = maskedOp.getBase();
105 VectorType vectorType = maskedOp.getVectorType();
106 int64_t vectorSize = vectorType.getNumElements();
107 int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
110 auto stridedMetadata =
111 memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
113 stridedMetadata.getConstifiedMixedStrides();
115 OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
118 std::tie(linearizedInfo, linearizedIndices) =
120 elementBitWidth, offset, sizes,
124 Value vectorSizeOffset =
130 Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
133 Value isOutofBounds = arith::CmpIOp::create(
134 rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
138 rewriter, loc, llvm::divideCeil(32, elementBitWidth));
139 Value isNotWordAligned = arith::CmpIOp::create(
140 rewriter, loc, arith::CmpIPredicate::ne,
141 arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
149 arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
155 scf::YieldOp::create(builder, loc, readResult);
161 scf::YieldOp::create(rewriter, loc, res);
165 scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
173struct FullMaskedLoadToConditionalLoad
177 LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
181 loadOp,
"buffer loads are handled by a more specialized pattern");
183 FailureOr<Value> maybeCond =
matchFullMask(rewriter, loadOp.getMask());
184 if (failed(maybeCond)) {
186 "isn't loading a broadcasted scalar");
189 Value cond = maybeCond.value();
193 scf::YieldOp::create(rewriter, loc, res);
196 scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
198 auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
205struct FullMaskedStoreToConditionalStore
209 LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
221 if (failed(maybeCond)) {
224 Value cond = maybeCond.value();
227 vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
228 storeOp.getBase(), storeOp.getIndices());
229 scf::YieldOp::create(rewriter, loc);
232 scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
242 patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
243 FullMaskedStoreToConditionalStore>(
patterns.getContext(),
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::MaskedLoadOp maskedOp, bool passthru)
static constexpr char kMaskedloadNeedsMask[]
static FailureOr< Value > matchFullMask(OpBuilder &b, Value val)
Check if the given value comes from a broadcasted i1 condition.
static LogicalResult hasBufferAddressSpace(Type type)
This pattern supports lowering of: vector.maskedload to vector.load and arith.select if the memref is...
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents a single result from folding an operation.
OpT getOperation()
Return the current operation being transformed.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext & getContext()
Return the MLIR context for the current operation being transformed.
void signalPassFailure()
Signal that some invariant was broken when running.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
OpFoldResult linearizedSize