TransferReadToLoad.cpp
//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of `vector.transfer_read` to a combination
/// of `vector.load`, `arith.select` and `vector.broadcast` if all of the
/// following hold:
/// - The transfer op is masked.
/// - The memref is in the fat raw buffer address space.
/// - The stride of the most minor memref dimension is 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
/// Note: these conditions mostly come from the
/// TransferReadToVectorLoadLowering pattern.
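///
/// For illustration (a hypothetical example, not taken from the tests), a
/// masked 1-D read from a fat raw buffer such as
///
///   %v = vector.transfer_read %buf[%i], %pad, %mask {in_bounds = [true]}
///     : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///
/// is rewritten roughly to
///
///   %fill = vector.splat %pad : vector<4xf32>
///   %load = vector.load %buf[%i]
///     : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///   %v    = arith.select %mask, %load, %fill : vector<4xi1>, vector<4xf32>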
static LogicalResult transferPreconditions(
    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
  if (!xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");

  // Permutations are handled by VectorToSCF or
  // populateVectorTransferPermutationMapLoweringPatterns.
  // We let the 0-d corner case pass through since it is supported.
  SmallVector<unsigned> broadcastedDims;
  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
          &broadcastedDims))
    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");

  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(xferOp, "not a memref source");

  Attribute addrSpace = memRefType.getMemorySpace();
  auto amdgpuAddrSpace =
      dyn_cast_if_present<amdgpu::AddressSpaceAttr>(addrSpace);
  if (!amdgpuAddrSpace)
    return rewriter.notifyMatchFailure(xferOp, "no address space");

  if (amdgpuAddrSpace.getValue() != amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");

  // Non-unit strides are handled by VectorToSCF.
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

  // If there is broadcasting involved then we first load the unbroadcasted
  // vector, and then broadcast it with `vector.broadcast`.
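  // For instance (illustrative, not from a test): a read producing
  // vector<4xf32> through the permutation map (d0) -> (0) has broadcasted
  // dimension 0, so the unbroadcasted type computed below is vector<1xf32>;
  // that vector is loaded and then broadcast back to vector<4xf32>.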
  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
  for (unsigned i : broadcastedDims)
    unbroadcastedVectorShape[i] = 1;
  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
  requiresBroadcasting = !broadcastedDims.empty();

  // `vector.load` supports vector types as memref's elements only when the
  // resulting vector type is the same as the element type.
  auto memrefElTy = memRefType.getElementType();
  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");

  // Otherwise, element types of the memref and the vector must match.
  if (!isa<VectorType>(memrefElTy) &&
      memrefElTy != xferOp.getVectorType().getElementType())
    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");

  // Out-of-bounds dims are handled by MaterializeTransferMask.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");

  // vector.maskedload operates on 1-D vectors.
  if (xferOp.getVectorType().getRank() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "vector type is not rank 1, can't create masked load, needs "
                "VectorToSCF");

  return success();
}

namespace {

struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {

    bool requiresBroadcasting = false;
    VectorType unbroadcastedVectorType;
    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                     unbroadcastedVectorType))) {
      return failure();
    }

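    // Emulate the masked read without a masked load op: splat the padding
    // value, load the (unbroadcasted) vector unconditionally, and select
    // lane-wise between the loaded data and the padding using the mask.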
    Location loc = readOp.getLoc();
    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                                  readOp.getPadding());
    Value load = rewriter.create<vector::LoadOp>(
        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                                 readOp.getMask(), load, fill);

    // Insert a broadcasting op if required.
    if (requiresBroadcasting) {
      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
                                                 res);
    }

    rewriter.replaceOp(readOp, res);

    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}

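// A minimal sketch (hypothetical, not part of this file) of reusing the entry
// point above from a custom pass, assuming a greedy rewrite is acceptable:
//
//   RewritePatternSet patterns(op->getContext());
//   amdgpu::populateAmdgpuTransferReadToLoadPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();
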
struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
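    // The walk-based driver visits each op once; the ops created by
    // TransferReadLowering are not transfer_reads, so no fixed-point (greedy)
    // iteration is needed here.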
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    walkAndApplyPatterns(getOperation(), std::move(patterns));
  }
};