MLIR  21.0.0git
TransferReadToLoad.cpp
Go to the documentation of this file.
1 //===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
25 #include "llvm/Support/MathExtras.h"
26 
27 namespace mlir::amdgpu {
28 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
29 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
30 } // namespace mlir::amdgpu
31 
32 using namespace mlir;
33 using namespace mlir::amdgpu;
34 
35 /// This pattern supports lowering of:
36 /// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
37 /// `vector.broadcast` if all of the following hold:
38 /// - The transfer op is masked.
39 /// - The memref is in buffer address space.
40 /// - Stride of most minor memref dimension must be 1.
41 /// - Out-of-bounds masking is not required.
42 /// - If the memref's element type is a vector type then it coincides with the
43 /// result type.
44 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
45 /// Note: those conditions mostly come from TransferReadToVectorLoadLowering
46 /// pass.
47 static LogicalResult transferPreconditions(
48  PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
49  bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
50  if (!xferOp.getMask())
51  return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
52 
53  // Permutations are handled by VectorToSCF or
54  // populateVectorTransferPermutationMapLoweringPatterns.
55  // We let the 0-d corner case pass-through as it is supported.
56  SmallVector<unsigned> broadcastedDims;
57  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
58  &broadcastedDims))
59  return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
60 
61  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
62  if (!memRefType)
63  return rewriter.notifyMatchFailure(xferOp, "not a memref source");
64 
65  Attribute addrSpace = memRefType.getMemorySpace();
66  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
67  return rewriter.notifyMatchFailure(xferOp, "no address space");
68 
69  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
70  amdgpu::AddressSpace::FatRawBuffer)
71  return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
72 
73  // Non-unit strides are handled by VectorToSCF.
74  if (!memRefType.isLastDimUnitStride())
75  return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
76 
77  if (memRefType.getElementTypeBitWidth() < 8)
78  return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
79 
80  // If there is broadcasting involved then we first load the unbroadcasted
81  // vector, and then broadcast it with `vector.broadcast`.
82  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
83  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
84  for (unsigned i : broadcastedDims)
85  unbroadcastedVectorShape[i] = 1;
86  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
87  unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
88  requiresBroadcasting = !broadcastedDims.empty();
89 
90  // `vector.load` supports vector types as memref's elements only when the
91  // resulting vector type is the same as the element type.
92  auto memrefElTy = memRefType.getElementType();
93  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
94  return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
95 
96  // Otherwise, element types of the memref and the vector must match.
97  if (!isa<VectorType>(memrefElTy) &&
98  memrefElTy != xferOp.getVectorType().getElementType())
99  return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
100 
101  // Out-of-bounds dims are handled by MaterializeTransferMask.
102  if (xferOp.hasOutOfBoundsDim())
103  return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
104 
105  if (xferOp.getVectorType().getRank() != 1)
106  // vector.maskedload operates on 1-D vectors.
107  return rewriter.notifyMatchFailure(
108  xferOp, "vector type is not rank 1, can't create masked load, needs "
109  "VectorToSCF");
110 
111  return success();
112 }
113 
115  vector::TransferReadOp readOp,
116  bool requiresBroadcasting,
117  VectorType unbroadcastedVectorType) {
118  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
119  readOp.getPadding());
120  Value load = builder.create<vector::LoadOp>(
121  loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
122  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
123  readOp.getMask(), load, fill);
124  // Insert a broadcasting op if required.
125  if (requiresBroadcasting) {
126  res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
127  }
128  return res;
129 }
130 
// Marker attribute placed on a cloned transfer_read to signal that it must
// keep its mask (the out-of-bounds fallback path) and that the pattern below
// must not match it again.
static constexpr char kTransferReadNeedsMask[] =
    "amdgpu.buffer_transfer_read_needs_mask";
133 
134 namespace {
135 
136 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
138 
139  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
140  PatternRewriter &rewriter) const override {
141  if (readOp->hasAttr(kTransferReadNeedsMask))
142  return failure();
143 
144  bool requiresBroadcasting = false;
145  VectorType unbroadcastedVectorType;
146  if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
147  unbroadcastedVectorType))) {
148  return failure();
149  }
150 
151  Location loc = readOp.getLoc();
152  Value src = readOp.getBase();
153 
154  VectorType vectorType = readOp.getVectorType();
155  int64_t vectorSize = vectorType.getNumElements();
156  int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
157  SmallVector<OpFoldResult> indices = readOp.getIndices();
158 
159  auto stridedMetadata =
160  rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
161  SmallVector<OpFoldResult> strides =
162  stridedMetadata.getConstifiedMixedStrides();
163  SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
164  OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
165  memref::LinearizedMemRefInfo linearizedInfo;
166  OpFoldResult linearizedIndices;
167  std::tie(linearizedInfo, linearizedIndices) =
168  memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
169  elementBitWidth, offset, sizes,
170  strides, indices);
171 
172  // delta = bufferSize - linearizedOffset
173  Value vectorSizeOffset =
174  rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
175  Value linearIndex =
176  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
178  rewriter, loc, linearizedInfo.linearizedSize);
179  Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
180 
181  // 1) check if delta < vectorSize
182  Value isOutofBounds = rewriter.create<arith::CmpIOp>(
183  loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
184 
185  // 2) check if (detla % elements_per_word != 0)
186  Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
187  loc, llvm::divideCeil(32, elementBitWidth));
188  Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
189  loc, arith::CmpIPredicate::ne,
190  rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
191  rewriter.create<arith::ConstantIndexOp>(loc, 0));
192 
193  // We take the fallback of transfer_read default lowering only it is both
194  // out-of-bounds and not word aligned. The fallback ensures correct results
195  // when loading at the boundary of the buffer since buffer load returns
196  // inconsistent zeros for the whole word when boundary is crossed.
197  Value ifCondition =
198  rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
199 
200  auto thenBuilder = [&](OpBuilder &builder, Location loc) {
201  Operation *read = builder.clone(*readOp.getOperation());
202  read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
203  Value readResult = read->getResult(0);
204  builder.create<scf::YieldOp>(loc, readResult);
205  };
206 
207  auto elseBuilder = [&](OpBuilder &builder, Location loc) {
209  builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
210  rewriter.create<scf::YieldOp>(loc, res);
211  };
212 
213  auto ifOp =
214  rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
215 
216  rewriter.replaceOp(readOp, ifOp);
217 
218  return success();
219  }
220 };
221 
222 } // namespace
223 
226  patterns.add<TransferReadLowering>(patterns.getContext());
227 }
228 
230  : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
231  AmdgpuTransferReadToLoadPass> {
232  void runOnOperation() override {
235  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
236  return signalPassFailure();
237  }
238  }
239 };
static MLIRContext * getContext(OpFoldResult val)
static std::optional< VectorShape > vectorShape(Type type)
static LogicalResult transferPreconditions(PatternRewriter &rewriter, VectorTransferOpInterface xferOp, bool &requiresBroadcasting, VectorType &unbroadcastedVectorType)
This pattern supports lowering of: vector.transfer_read to a combination of vector....
static constexpr char kTransferReadNeedsMask[]
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::TransferReadOp readOp, bool requiresBroadcasting, VectorType unbroadcastedVectorType)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
UnitAttr getUnitAttr()
Definition: Builders.cpp:96
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:551
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:52
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition: MemRefUtils.h:50