MLIR  22.0.0git
MaskedloadToLoad.cpp
Go to the documentation of this file.
1 //===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

26 namespace mlir::amdgpu {
27 #define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
28 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
29 } // namespace mlir::amdgpu
30 
31 using namespace mlir;
32 using namespace mlir::amdgpu;
33 
34 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
35 /// and `arith.select` if the memref is in buffer address space.
36 static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
37  vector::MaskedLoadOp maskedOp) {
38  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
39  if (!memRefType)
40  return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
41 
42  Attribute addrSpace = memRefType.getMemorySpace();
43  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
44  return rewriter.notifyMatchFailure(maskedOp, "no address space");
45 
46  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
47  amdgpu::AddressSpace::FatRawBuffer)
48  return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
49 
50  return success();
51 }
52 
54  vector::MaskedLoadOp maskedOp,
55  bool passthru) {
56  VectorType vectorType = maskedOp.getVectorType();
57  Value load = vector::LoadOp::create(
58  builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
59  if (passthru)
60  load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
61  load, maskedOp.getPassThru());
62  return load;
63 }
64 
65 /// Check if the given value comes from a broadcasted i1 condition.
66 static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
67  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
68  if (!broadcastOp)
69  return failure();
70  if (isa<VectorType>(broadcastOp.getSourceType()))
71  return failure();
72  return broadcastOp.getSource();
73 }
74 
// Marker attribute placed on cloned maskedload ops that must keep their mask;
// MaskedLoadLowering skips ops carrying it, preventing the pattern from
// re-matching its own fallback clone.
static constexpr char kMaskedloadNeedsMask[] =
    "amdgpu.buffer_maskedload_needs_mask";
77 
78 namespace {
79 
80 struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
82 
83  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
84  PatternRewriter &rewriter) const override {
85  if (maskedOp->hasAttr(kMaskedloadNeedsMask))
86  return failure();
87 
88  if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
89  return failure();
90  }
91 
92  // Check if this is either a full inbounds load or an empty, oob load. If
93  // so, take the fast path and don't generate an if condition, because we
94  // know doing the oob load is always safe.
95  if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
96  Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
97  maskedOp, /*passthru=*/true);
98  rewriter.replaceOp(maskedOp, load);
99  return success();
100  }
101 
102  Location loc = maskedOp.getLoc();
103  Value src = maskedOp.getBase();
104 
105  VectorType vectorType = maskedOp.getVectorType();
106  int64_t vectorSize = vectorType.getNumElements();
107  int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
108  SmallVector<OpFoldResult> indices = maskedOp.getIndices();
109 
110  auto stridedMetadata =
111  memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
112  SmallVector<OpFoldResult> strides =
113  stridedMetadata.getConstifiedMixedStrides();
114  SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
115  OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
116  memref::LinearizedMemRefInfo linearizedInfo;
117  OpFoldResult linearizedIndices;
118  std::tie(linearizedInfo, linearizedIndices) =
119  memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
120  elementBitWidth, offset, sizes,
121  strides, indices);
122 
123  // delta = bufferSize - linearizedOffset
124  Value vectorSizeOffset =
125  arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
126  Value linearIndex =
127  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
129  rewriter, loc, linearizedInfo.linearizedSize);
130  Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
131 
132  // 1) check if delta < vectorSize
133  Value isOutofBounds = arith::CmpIOp::create(
134  rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
135 
136  // 2) check if (detla % elements_per_word != 0)
137  Value elementsPerWord = arith::ConstantIndexOp::create(
138  rewriter, loc, llvm::divideCeil(32, elementBitWidth));
139  Value isNotWordAligned = arith::CmpIOp::create(
140  rewriter, loc, arith::CmpIPredicate::ne,
141  arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
142  arith::ConstantIndexOp::create(rewriter, loc, 0));
143 
144  // We take the fallback of maskedload default lowering only it is both
145  // out-of-bounds and not word aligned. The fallback ensures correct results
146  // when loading at the boundary of the buffer since buffer load returns
147  // inconsistent zeros for the whole word when boundary is crossed.
148  Value ifCondition =
149  arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
150 
151  auto thenBuilder = [&](OpBuilder &builder, Location loc) {
152  Operation *read = builder.clone(*maskedOp.getOperation());
153  read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
154  Value readResult = read->getResult(0);
155  scf::YieldOp::create(builder, loc, readResult);
156  };
157 
158  auto elseBuilder = [&](OpBuilder &builder, Location loc) {
159  Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
160  /*passthru=*/true);
161  scf::YieldOp::create(rewriter, loc, res);
162  };
163 
164  auto ifOp =
165  scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
166 
167  rewriter.replaceOp(maskedOp, ifOp);
168 
169  return success();
170  }
171 };
172 
173 struct FullMaskedLoadToConditionalLoad
174  : OpRewritePattern<vector::MaskedLoadOp> {
176 
177  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
178  PatternRewriter &rewriter) const override {
179  FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
180  if (failed(maybeCond)) {
181  return failure();
182  }
183 
184  Value cond = maybeCond.value();
185  auto trueBuilder = [&](OpBuilder &builder, Location loc) {
186  Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
187  /*passthru=*/false);
188  scf::YieldOp::create(rewriter, loc, res);
189  };
190  auto falseBuilder = [&](OpBuilder &builder, Location loc) {
191  scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
192  };
193  auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
194  falseBuilder);
195  rewriter.replaceOp(loadOp, ifOp);
196  return success();
197  }
198 };
199 
200 struct FullMaskedStoreToConditionalStore
201  : OpRewritePattern<vector::MaskedStoreOp> {
203 
204  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
205  PatternRewriter &rewriter) const override {
206  FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
207  if (failed(maybeCond)) {
208  return failure();
209  }
210  Value cond = maybeCond.value();
211 
212  auto trueBuilder = [&](OpBuilder &builder, Location loc) {
213  vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
214  storeOp.getBase(), storeOp.getIndices());
215  scf::YieldOp::create(rewriter, loc);
216  };
217  auto ifOp =
218  scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
219  rewriter.replaceOp(storeOp, ifOp);
220  return success();
221  }
222 };
223 
224 } // namespace
225 
228  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
229  FullMaskedStoreToConditionalStore>(patterns.getContext(),
230  benefit);
231 }
232 
234  : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
235  void runOnOperation() override {
238  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
239  return signalPassFailure();
240  }
241  }
242 };
static MLIRContext * getContext(OpFoldResult val)
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::MaskedLoadOp maskedOp, bool passthru)
static constexpr char kMaskedloadNeedsMask[]
static FailureOr< Value > matchFullMask(OpBuilder &b, Value val)
Check if the given value comes from a broadcasted i1 condition.
static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, vector::MaskedLoadOp maskedOp)
This pattern supports lowering of: vector.maskedload to vector.load and arith.select if the memref is...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
UnitAttr getUnitAttr()
Definition: Builders.cpp:93
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:51
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition: MemRefUtils.h:50