LowerVectorGather.cpp
//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-gather-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Unrolls `vector.gather` ops of rank 2 or higher by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-D gather ops.
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() < 2)
      return rewriter.notifyMatchFailure(op, "already 1-D");

    // Unrolling doesn't take vscale into account. Pattern is disabled for
    // vectors with leading scalable dim(s).
    if (resultTy.getScalableDims().front())
      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");

    Location loc = op.getLoc();
    Value indexVec = op.getIndexVec();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultTy, rewriter.getZeroAttr(resultTy));

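    // dropDim(0) erases the unrolled leading dimension, e.g.
    // vector<2x3xf32> -> vector<3xf32>.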
    VectorType subTy = VectorType::Builder(resultTy).dropDim(0);

    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
      int64_t thisIdx[1] = {i};

      Value indexSubVec =
          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      Value maskSubVec =
          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
      Value passThruSubVec =
          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
      Value subGather = rewriter.create<vector::GatherOp>(
          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
          passThruSubVec);
      result =
          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
///   : memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview[%idxs] (...)
///   : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
///   : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
///   : memref<300xf32> (...)
/// ```
///
/// At the moment this is effectively limited to reading a 1D vector from a 2D
/// MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well.
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get the strides from the layout.
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
        op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that will model the strided access.
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndexVec().getType();
    Value mulCst = rewriter.create<arith::ConstantOp>(
        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

    Value newIdxs =
        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);

    // 3. Create an updated gather op with the collapsed input memref and the
    // updated indices.
    Value newGather = rewriter.create<vector::GatherOp>(
        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
        newIdxs, op.getMask(), op.getPassThru());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns 1-D `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
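///
/// For example (an illustrative sketch, with types abbreviated and only the
/// first of the generated conditional accesses shown):
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru : ...
/// ... into vector<2xf32>
///
/// ==>
///
/// %m0 = vector.extract %mask[0] : i1 from vector<2xi1>
/// %i0 = vector.extract %v[0] : index from vector<2xindex>
/// %off0 = arith.addi %c0, %i0 : index
/// %r0 = scf.if %m0 -> (vector<2xf32>) {
///   %ld = vector.load %base[%off0] : memref<?xf32>, vector<1xf32>
///   %e = vector.extract %ld[0] : f32 from vector<1xf32>
///   %v0 = vector.insert %e, %pass_thru [0] : f32 into vector<2xf32>
///   scf.yield %v0 : vector<2xf32>
/// } else {
///   scf.yield %pass_thru : vector<2xf32>
/// }
/// ...
/// ```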
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    if (resultTy.isScalable())
      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load`s.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // vector.load requires the most minor memref dim to have unit stride
    // (unless reading exactly 1 element).
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1 &&
            resultTy.getNumElements() != 1)
          return failure();
      }
    }

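    // Convert the gather indices to `index` type so they can be folded into
    // the base offsets below (createOrFold elides the cast when the indices
    // are already index-typed).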
    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndexVec());
    auto baseOffsets = llvm::to_vector(op.getIndices());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar result; emit a vector load
          // and extract the single result instead.
          Value load =
              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
          int64_t zeroIdx[1] = {0};
          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
        } else {
          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
        }

        Value newResult =
            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
        b.create<scf::YieldOp>(loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, result);
      };

      result =
          rewriter
              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
              .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollGather>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
      patterns.getContext(), benefit);
}
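
// Example usage (an illustrative sketch, not part of the original file):
// these entry points are typically consumed by a pass that hands the
// populated set to a greedy driver, assuming applyPatternsGreedily from
// mlir/Transforms/GreedyPatternRewriteDriver.h is available:
//
//   RewritePatternSet patterns(&getContext());
//   populateVectorGatherLoweringPatterns(patterns);
//   populateVectorGatherToConditionalLoadPatterns(patterns);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));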