// Excerpted from MLIR (21.0.0git) — TransferReadToLoad.cpp.
// See the generated documentation page for this file.
1 //===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/Support/MathExtras.h"
26 
27 namespace mlir::amdgpu {
28 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
29 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
30 } // namespace mlir::amdgpu
31 
32 using namespace mlir;
33 using namespace mlir::amdgpu;
34 
35 /// This pattern supports lowering of:
36 /// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
37 /// `vector.broadcast` if all of the following hold:
38 /// - The transfer op is masked.
39 /// - The memref is in buffer address space.
40 /// - Stride of most minor memref dimension must be 1.
41 /// - Out-of-bounds masking is not required.
42 /// - If the memref's element type is a vector type then it coincides with the
43 /// result type.
44 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
45 /// Note: those conditions mostly come from TransferReadToVectorLoadLowering
46 /// pass.
47 static LogicalResult transferPreconditions(
48  PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
49  bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
50  if (!xferOp.getMask())
51  return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
52 
53  // Permutations are handled by VectorToSCF or
54  // populateVectorTransferPermutationMapLoweringPatterns.
55  // We let the 0-d corner case pass-through as it is supported.
56  SmallVector<unsigned> broadcastedDims;
57  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
58  &broadcastedDims))
59  return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
60 
61  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
62  if (!memRefType)
63  return rewriter.notifyMatchFailure(xferOp, "not a memref source");
64 
65  Attribute addrSpace = memRefType.getMemorySpace();
66  if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
67  return rewriter.notifyMatchFailure(xferOp, "no address space");
68 
69  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
70  amdgpu::AddressSpace::FatRawBuffer)
71  return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
72 
73  // Non-unit strides are handled by VectorToSCF.
74  if (!memRefType.isLastDimUnitStride())
75  return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
76 
77  if (memRefType.getElementTypeBitWidth() < 8)
78  return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
79 
80  // If there is broadcasting involved then we first load the unbroadcasted
81  // vector, and then broadcast it with `vector.broadcast`.
82  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
83  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
84  for (unsigned i : broadcastedDims)
85  unbroadcastedVectorShape[i] = 1;
86  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
87  unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
88  requiresBroadcasting = !broadcastedDims.empty();
89 
90  // `vector.load` supports vector types as memref's elements only when the
91  // resulting vector type is the same as the element type.
92  auto memrefElTy = memRefType.getElementType();
93  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
94  return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
95 
96  // Otherwise, element types of the memref and the vector must match.
97  if (!isa<VectorType>(memrefElTy) &&
98  memrefElTy != xferOp.getVectorType().getElementType())
99  return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
100 
101  // Out-of-bounds dims are handled by MaterializeTransferMask.
102  if (xferOp.hasOutOfBoundsDim())
103  return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
104 
105  if (xferOp.getVectorType().getRank() != 1)
106  // vector.maskedload operates on 1-D vectors.
107  return rewriter.notifyMatchFailure(
108  xferOp, "vector type is not rank 1, can't create masked load, needs "
109  "VectorToSCF");
110 
111  return success();
112 }
113 
114 /// Emits the unmasked lowering of a masked transfer read: a plain
115 /// `vector.load`, with lanes disabled by the mask replaced by the transfer's
116 /// padding value via `arith.select`, and a final `vector.broadcast` when the
117 /// permutation map broadcasts dimensions. (The scrape had dropped the
118 /// signature line; restored from this file's declaration.)
119 static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
120                                            vector::TransferReadOp readOp,
121                                            bool requiresBroadcasting,
122                                            VectorType unbroadcastedVectorType) {
123   Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
124                                                readOp.getPadding());
125   Value load = builder.create<vector::LoadOp>(
126       loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
127   Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
128                                               readOp.getMask(), load, fill);
129   // Insert a broadcasting op if required.
130   if (requiresBroadcasting) {
131     res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
132   }
133   return res;
134 }
130 
131 // Marker attribute placed on the cloned fallback transfer_read so that
132 // TransferReadLowering does not match it again (preventing infinite pattern
133 // recursion); presumably the masked clone is handled by a later lowering —
134 // confirm against MaterializeTransferMask.
135 static constexpr char kTransferReadNeedsMask[] =
136     "amdgpu.buffer_transfer_read_needs_mask";
133 
134 namespace {
135 
136 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
138 
139  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
140  PatternRewriter &rewriter) const override {
141  if (readOp->hasAttr(kTransferReadNeedsMask))
142  return failure();
143 
144  bool requiresBroadcasting = false;
145  VectorType unbroadcastedVectorType;
146  if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
147  unbroadcastedVectorType))) {
148  return failure();
149  }
150 
151  Location loc = readOp.getLoc();
152  Value src = readOp.getSource();
153 
154  VectorType vectorType = readOp.getVectorType();
155  int64_t vectorSize = vectorType.getNumElements();
156  int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
157  SmallVector<OpFoldResult> indices = readOp.getIndices();
158 
159  auto stridedMetadata =
160  rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
161  SmallVector<OpFoldResult> strides =
162  stridedMetadata.getConstifiedMixedStrides();
163  SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
164  OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
165  OpFoldResult linearizedIndices;
166  std::tie(std::ignore, linearizedIndices) =
167  memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
168  elementBitWidth, offset, sizes,
169  strides, indices);
170 
171  // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
172  // Note below doesn't give the correct result for the linearized size.
173  // Value totalSize = getValueOrCreateConstantIndexOp(
174  // rewriter, loc, linearizedInfo.linearizedSize);
175  // It computes the multiplied sizes of all dimensions instead of taking
176  // the maximum of each dimension size * stride.
177  SmallVector<AffineExpr> productExpressions;
178  SmallVector<Value> productResults;
179  unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
180 
181  SmallVector<AffineExpr> symbols(2 * sourceRank);
182  SmallVector<Value> offsetValues;
183  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
184 
185  size_t symbolIndex = 0;
186  for (size_t i = 0; i < sourceRank; ++i) {
187  AffineExpr strideExpr, sizeExpr;
188  OpFoldResult stride = strides[i];
189  OpFoldResult size = sizes[i];
190  if (auto constantStride = getConstantIntValue(stride)) {
191  strideExpr = rewriter.getAffineConstantExpr(*constantStride);
192  } else {
193  strideExpr = symbols[symbolIndex++];
194  offsetValues.push_back(
195  getValueOrCreateConstantIndexOp(rewriter, loc, stride));
196  }
197 
198  if (auto constantSize = getConstantIntValue(size)) {
199  sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
200  } else {
201  sizeExpr = symbols[symbolIndex++];
202  offsetValues.push_back(
203  getValueOrCreateConstantIndexOp(rewriter, loc, size));
204  }
205 
206  productExpressions.push_back(strideExpr * sizeExpr);
207  }
208 
209  AffineMap maxMap = AffineMap::get(
210  /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
211  rewriter.getContext());
212  Value totalSize =
213  rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
214 
215  // delta = bufferSize - linearizedOffset
216  Value vectorSizeOffset =
217  rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
218  Value linearIndex =
219  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
220  Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
221 
222  // 1) check if delta < vectorSize
223  Value isOutofBounds = rewriter.create<arith::CmpIOp>(
224  loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
225 
226  // 2) check if (detla % elements_per_word != 0)
227  Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
228  loc, llvm::divideCeil(32, elementBitWidth));
229  Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
230  loc, arith::CmpIPredicate::ne,
231  rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
232  rewriter.create<arith::ConstantIndexOp>(loc, 0));
233 
234  // We take the fallback of transfer_read default lowering only it is both
235  // out-of-bounds and not word aligned. The fallback ensures correct results
236  // when loading at the boundary of the buffer since buffer load returns
237  // inconsistent zeros for the whole word when boundary is crossed.
238  Value ifCondition =
239  rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
240 
241  auto thenBuilder = [&](OpBuilder &builder, Location loc) {
242  Operation *read = builder.clone(*readOp.getOperation());
243  read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
244  Value readResult = read->getResult(0);
245  builder.create<scf::YieldOp>(loc, readResult);
246  };
247 
248  auto elseBuilder = [&](OpBuilder &builder, Location loc) {
250  builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
251  rewriter.create<scf::YieldOp>(loc, res);
252  };
253 
254  auto ifOp =
255  rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
256 
257  rewriter.replaceOp(readOp, ifOp);
258 
259  return success();
260  }
261 };
262 
263 } // namespace
264 
265 /// Registers the buffer transfer_read lowering pattern. (Signature restored
266 /// from this file's declaration; it was dropped by the scrape.)
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}
269 
270 /// Pass wrapper: greedily applies the transfer_read-to-load patterns on the
/// operation. (The struct header and the pattern-set construction lines were
/// dropped by the scrape and are restored here in the standard form.)
struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
static MLIRContext * getContext(OpFoldResult val)
static std::optional< VectorShape > vectorShape(Type type)
static LogicalResult transferPreconditions(PatternRewriter &rewriter, VectorTransferOpInterface xferOp, bool &requiresBroadcasting, VectorType &unbroadcastedVectorType)
This pattern supports lowering of: vector.transfer_read to a combination of vector....
static constexpr char kTransferReadNeedsMask[]
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::TransferReadOp readOp, bool requiresBroadcasting, VectorType unbroadcastedVectorType)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
MLIRContext * getContext() const
Definition: Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:52
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Definition: AffineExpr.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319