MaskedloadToLoad.cpp
//===- MaskedloadToLoad.cpp - Lowers maskedload to load ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of `vector.maskedload` to `vector.load`
/// and `arith.select` if the memref is in the buffer address space.
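///
/// A sketch of the intended rewrite, with illustrative types and value names
/// (not taken from a test):
///
///   %v = vector.maskedload %base[%i], %mask, %pass
///     : memref<?xf32, #amdgpu.address_space<fat_raw_buffer>>,
///       vector<4xi1>, vector<4xf32> into vector<4xf32>
///
/// becomes, on the fast path,
///
///   %l = vector.load %base[%i]
///     : memref<?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///   %v = arith.select %mask, %l, %pass : vector<4xi1>, vector<4xf32>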
static LogicalResult hasBufferAddressSpace(Type type) {
  auto memRefType = dyn_cast<MemRefType>(type);
  if (!memRefType)
    return failure();

  auto addrSpace = dyn_cast_if_present<amdgpu::AddressSpaceAttr>(
      memRefType.getMemorySpace());
  if (!addrSpace || addrSpace.getValue() != amdgpu::AddressSpace::FatRawBuffer)
    return failure();

  return success();
}

/// Create an unconditional `vector.load`, optionally followed by an
/// `arith.select` against the mask when `passthru` lanes must be preserved.
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
                                           vector::MaskedLoadOp maskedOp,
                                           bool passthru) {
  VectorType vectorType = maskedOp.getVectorType();
  Value load = vector::LoadOp::create(
      builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
  if (passthru)
    load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
                                   load, maskedOp.getPassThru());
  return load;
}

/// Check if the given value comes from a broadcasted i1 condition.
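/// For example (illustrative IR), given
///   %mask = vector.broadcast %cond : i1 to vector<4xi1>
/// this returns `%cond`; masks built any other way (including broadcasts of
/// vectors) are rejected.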
static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
  if (!broadcastOp)
    return failure();
  if (isa<VectorType>(broadcastOp.getSourceType()))
    return failure();
  return broadcastOp.getSource();
}

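/// Marker set on a cloned `vector.maskedload` so that MaskedLoadLowering
/// (below) recognizes the fallback clone it created and does not rewrite it
/// again.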
static constexpr char kMaskedloadNeedsMask[] =
    "amdgpu.buffer_maskedload_needs_mask";

namespace {

struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                PatternRewriter &rewriter) const override {
    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");

    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
      return rewriter.notifyMatchFailure(
          maskedOp, "isn't a load from a fat buffer resource");
    }

    // Check if this is either a full inbounds load or an empty, oob load. If
    // so, take the fast path and don't generate an if condition, because we
    // know doing the oob load is always safe.
    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                 maskedOp, /*passthru=*/true);
      rewriter.replaceOp(maskedOp, load);
      return success();
    }

    Location loc = maskedOp.getLoc();
    Value src = maskedOp.getBase();

    VectorType vectorType = maskedOp.getVectorType();
    int64_t vectorSize = vectorType.getNumElements();
    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
    SmallVector<OpFoldResult> indices = maskedOp.getIndices();

    // Linearize the accessed position so it can be compared against the
    // buffer's linearized size below.
    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    SmallVector<OpFoldResult> strides =
        stridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
    memref::LinearizedMemRefInfo linearizedInfo;
    OpFoldResult linearizedIndices;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
                                                 elementBitWidth, offset, sizes,
                                                 strides, indices);

    // delta = bufferSize - linearizedOffset
    Value vectorSizeOffset =
        arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
    Value linearIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value totalSize = getValueOrCreateConstantIndexOp(
        rewriter, loc, linearizedInfo.linearizedSize);
    Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);

    // 1) check if delta < vectorSize
    Value isOutofBounds = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

    // 2) check if (delta % elements_per_word != 0)
    Value elementsPerWord = arith::ConstantIndexOp::create(
        rewriter, loc, llvm::divideCeil(32, elementBitWidth));
    Value isNotWordAligned = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::ne,
        arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
        arith::ConstantIndexOp::create(rewriter, loc, 0));

    // We take the fallback of the default maskedload lowering only if the
    // access is both out-of-bounds and not word aligned. The fallback ensures
    // correct results when loading at the boundary of the buffer, since a
    // buffer load returns inconsistent zeros for the whole word when the
    // boundary is crossed.
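    // Worked example, assuming a vector<4xf16> load (16-bit elements) from a
    // buffer of 7 halfwords at linear index 4: elementsPerWord
    // = ceil(32 / 16) = 2 and delta = 7 - 4 = 3, so the access is
    // out-of-bounds (3 < 4) and not word aligned (3 % 2 == 1), and the
    // masked-load fallback below is taken.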
    Value ifCondition =
        arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);

    // Then: keep the masked load, but mark it so this pattern skips it on the
    // next iteration.
    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
      Operation *read = builder.clone(*maskedOp.getOperation());
      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
      Value readResult = read->getResult(0);
      scf::YieldOp::create(builder, loc, readResult);
    };

    // Else: the unconditional load is safe; select the passthru lanes
    // afterwards.
    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
                                                /*passthru=*/true);
      scf::YieldOp::create(rewriter, loc, res);
    };

    auto ifOp =
        scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);

    rewriter.replaceOp(maskedOp, ifOp);

    return success();
  }
};

struct FullMaskedLoadToConditionalLoad
    : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
      return rewriter.notifyMatchFailure(
          loadOp, "buffer loads are handled by a more specialized pattern");

    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
    if (failed(maybeCond)) {
      return rewriter.notifyMatchFailure(loadOp,
                                         "isn't loading a broadcasted scalar");
    }

    Value cond = maybeCond.value();
    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
                                                /*passthru=*/false);
      scf::YieldOp::create(rewriter, loc, res);
    };
    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
      scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
    };
    auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
                                  falseBuilder);
    rewriter.replaceOp(loadOp, ifOp);
    return success();
  }
};

struct FullMaskedStoreToConditionalStore
    : OpRewritePattern<vector::MaskedStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    // A condition-free implementation of fully masked stores requires
    // 1) an accessor for the num_records field on buffer resources/fat
    //    pointers, and
    // 2) knowledge that said field will always be set accurately - that is,
    //    that writes at any offset x < num_records wouldn't trap, which is
    //    something a pattern user would need to assert or we'd need to prove.
    //
    // Therefore, masked stores to buffers still go down this conditional path
    // at present.

    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
    if (failed(maybeCond)) {
      return failure();
    }
    Value cond = maybeCond.value();

    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
      vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
                              storeOp.getBase(), storeOp.getIndices());
      scf::YieldOp::create(rewriter, loc);
    };
    auto ifOp =
        scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
    rewriter.replaceOp(storeOp, ifOp);
    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
               FullMaskedStoreToConditionalStore>(patterns.getContext(),
                                                  benefit);
}

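// A sketch of the pass definition that GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
// above sets up, reconstructed from the usual generated-pass-base idiom; the
// exact body in the original file may differ.
struct AmdgpuMaskedloadToLoadPass final
    : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
  void runOnOperation() override {
    // Collect the patterns defined above and apply them greedily to the
    // current operation.
    RewritePatternSet patterns(&getContext());
    populateAmdgpuMaskedloadToLoadPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};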