//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of `vector.maskedload` to `vector.load`
/// and `arith.select` when the memref is in the buffer address space.
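///
/// A sketch of the fast-path rewrite (illustrative names and types):
///
///   %v = vector.maskedload %base[%i], %mask, %pass
///       : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
///         vector<4xi1>, vector<4xf32> into vector<4xf32>
///
/// becomes
///
///   %l = vector.load %base[%i]
///       : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///   %v = arith.select %mask, %l, %pass : vector<4xi1>, vector<4xf32>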
static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
                                           vector::MaskedLoadOp maskedOp) {
  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");

  Attribute addrSpace = memRefType.getMemorySpace();
  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
    return rewriter.notifyMatchFailure(maskedOp, "no address space");

  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
      amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");

  return success();
}

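/// Build an unmasked `vector.load` from the masked load's base and indices.
/// When `passthru` is true, blend the loaded vector with the original
/// pass-through operand under the mask via `arith.select`.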
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
                                           vector::MaskedLoadOp maskedOp,
                                           bool passthru) {
  VectorType vectorType = maskedOp.getVectorType();
  Value load = vector::LoadOp::create(
      builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
  if (passthru)
    load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
                                   load, maskedOp.getPassThru());
  return load;
}

/// Check if the given value comes from a broadcasted i1 condition.
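/// For example, `%mask = vector.broadcast %c : i1 to vector<4xi1>` matches
/// and yields the scalar condition `%c`.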
static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
  if (!broadcastOp)
    return failure();
  if (isa<VectorType>(broadcastOp.getSourceType()))
    return failure();
  return broadcastOp.getSource();
}

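/// Marker attribute set on masked loads that this pass has already cloned
/// into the out-of-bounds fallback branch, so the pattern below does not
/// match them a second time.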
static constexpr char kMaskedloadNeedsMask[] =
    "amdgpu.buffer_maskedload_needs_mask";

namespace {

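// The guarded lowering below produces IR of roughly this shape (illustrative
// SSA names; a vector<4xf32> load is assumed):
//
//   %delta = arith.subi %bufferSize, %linearIndex : index
//   %oob = arith.cmpi ult, %delta, %c4 : index
//   %rem = arith.remui %delta, %elemsPerWord : index
//   %unaligned = arith.cmpi ne, %rem, %c0 : index
//   %cond = arith.andi %oob, %unaligned : i1
//   %v = scf.if %cond -> (vector<4xf32>) {
//     // Clone of the original vector.maskedload, tagged with
//     // amdgpu.buffer_maskedload_needs_mask so it is not matched again.
//   } else {
//     // vector.load followed by arith.select on the original mask.
//   }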
struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                PatternRewriter &rewriter) const override {
    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
      return failure();

    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
      return failure();
    }

    // Check if this is either a full inbounds load or an empty, oob load. If
    // so, take the fast path and don't generate an if condition, because we
    // know doing the oob load is always safe.
    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                 maskedOp, /*passthru=*/true);
      rewriter.replaceOp(maskedOp, load);
      return success();
    }
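
    // Otherwise, guard the unmasked load: fall back to the original masked
    // load only when the access is out of bounds and not word aligned.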

    Location loc = maskedOp.getLoc();
    Value src = maskedOp.getBase();

    VectorType vectorType = maskedOp.getVectorType();
    int64_t vectorSize = vectorType.getNumElements();
    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
    SmallVector<OpFoldResult> indices = maskedOp.getIndices();

    auto stridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    SmallVector<OpFoldResult> strides =
        stridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
    memref::LinearizedMemRefInfo linearizedInfo;
    OpFoldResult linearizedIndices;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
                                                 elementBitWidth, offset, sizes,
                                                 strides, indices);
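    // `linearizedIndices` is the flat offset, in elements, of the load's
    // first element within the buffer; `linearizedInfo.linearizedSize` is the
    // buffer's total size in elements.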

    // delta = bufferSize - linearizedOffset
    Value vectorSizeOffset =
        arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
    Value linearIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value totalSize = getValueOrCreateConstantIndexOp(
        rewriter, loc, linearizedInfo.linearizedSize);
    Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);

    // 1) check if delta < vectorSize
    Value isOutofBounds = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

    // 2) check if (delta % elements_per_word != 0)
    Value elementsPerWord = arith::ConstantIndexOp::create(
        rewriter, loc, llvm::divideCeil(32, elementBitWidth));
    Value isNotWordAligned = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::ne,
        arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
        arith::ConstantIndexOp::create(rewriter, loc, 0));

    // We take the fallback of the default maskedload lowering only if the
    // access is both out-of-bounds and not word aligned. The fallback ensures
    // correct results when loading at the boundary of the buffer, since a
    // buffer load returns inconsistent zeros for the whole word when the
    // boundary is crossed.
    Value ifCondition =
        arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);

    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
      Operation *read = builder.clone(*maskedOp.getOperation());
      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
      Value readResult = read->getResult(0);
      scf::YieldOp::create(builder, loc, readResult);
    };

    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
                                                /*passthru=*/true);
      scf::YieldOp::create(builder, loc, res);
    };

    auto ifOp =
        scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);

    rewriter.replaceOp(maskedOp, ifOp);

    return success();
  }
};

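// When the mask is a broadcast of a single i1, the load is either fully
// enabled or fully disabled, so it can become a conditional unmasked load.
// Roughly (illustrative types):
//
//   %v = scf.if %cond -> (vector<4xf32>) {
//     %l = vector.load %base[%i] : memref<8xf32>, vector<4xf32>
//     scf.yield %l : vector<4xf32>
//   } else {
//     scf.yield %passthru : vector<4xf32>
//   }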
struct FullMaskedLoadToConditionalLoad
    : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
    if (failed(maybeCond)) {
      return failure();
    }

    Value cond = maybeCond.value();
    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
                                                /*passthru=*/false);
      scf::YieldOp::create(builder, loc, res);
    };
    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
      scf::YieldOp::create(builder, loc, loadOp.getPassThru());
    };
    auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
                                  falseBuilder);
    rewriter.replaceOp(loadOp, ifOp);
    return success();
  }
};

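// The store analogue: with a broadcast mask the store either happens in full
// or not at all, so it becomes an scf.if with no else region. Roughly
// (illustrative types):
//
//   scf.if %cond {
//     vector.store %val, %base[%i] : memref<8xf32>, vector<4xf32>
//   }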
struct FullMaskedStoreToConditionalStore
    : OpRewritePattern<vector::MaskedStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
    if (failed(maybeCond)) {
      return failure();
    }
    Value cond = maybeCond.value();

    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
      vector::StoreOp::create(builder, loc, storeOp.getValueToStore(),
                              storeOp.getBase(), storeOp.getIndices());
      scf::YieldOp::create(builder, loc);
    };
    auto ifOp =
        scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
    rewriter.replaceOp(storeOp, ifOp);
    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
               FullMaskedStoreToConditionalStore>(patterns.getContext(),
                                                  benefit);
}

struct AmdgpuMaskedloadToLoadPass final
    : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuMaskedloadToLoadPatterns(patterns, PatternBenefit(1));
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};