MLIR  20.0.0git
LowerVectorGather.cpp
Go to the documentation of this file.
1 //===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.gather' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
34 
// Debug tag for this file (used with `-debug-only=`). Named after the gather
// lowering implemented here rather than the broadcast lowering it was copied
// from.
#define DEBUG_TYPE "vector-gather-lowering"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 namespace {
41 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
42 /// outermost dimension. For example:
43 /// ```
44 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
45 /// ... into vector<2x3xf32>
46 ///
47 /// ==>
48 ///
49 /// %0 = arith.constant dense<0.0> : vector<2x3xf32>
50 /// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
51 /// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
52 /// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
53 /// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
54 /// ```
55 ///
56 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
57 ///
58 /// Supports vector types with a fixed leading dimension.
59 struct FlattenGather : OpRewritePattern<vector::GatherOp> {
61 
62  LogicalResult matchAndRewrite(vector::GatherOp op,
63  PatternRewriter &rewriter) const override {
64  VectorType resultTy = op.getType();
65  if (resultTy.getRank() < 2)
66  return rewriter.notifyMatchFailure(op, "already flat");
67 
68  // Unrolling doesn't take vscale into account. Pattern is disabled for
69  // vectors with leading scalable dim(s).
70  if (resultTy.getScalableDims().front())
71  return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
72 
73  Location loc = op.getLoc();
74  Value indexVec = op.getIndexVec();
75  Value maskVec = op.getMask();
76  Value passThruVec = op.getPassThru();
77 
78  Value result = rewriter.create<arith::ConstantOp>(
79  loc, resultTy, rewriter.getZeroAttr(resultTy));
80 
81  VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
82 
83  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
84  int64_t thisIdx[1] = {i};
85 
86  Value indexSubVec =
87  rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
88  Value maskSubVec =
89  rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
90  Value passThruSubVec =
91  rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
92  Value subGather = rewriter.create<vector::GatherOp>(
93  loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
94  passThruSubVec);
95  result =
96  rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
97  }
98 
99  rewriter.replaceOp(op, result);
100  return success();
101  }
102 };
103 
104 /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
105 /// MemRef with updated indices that model the strided access.
106 ///
107 /// ```mlir
108 /// %subview = memref.subview %M (...)
109 /// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
110 /// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
111 /// ```
112 /// ==>
113 /// ```mlir
114 /// %collapse_shape = memref.collapse_shape %M (...)
115 /// : memref<100x3xf32> into memref<300xf32>
116 /// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
117 /// %gather = vector.gather %collapse_shape[%new_idxs] (...)
118 /// : memref<300xf32> (...)
119 /// ```
120 ///
121 /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
122 /// but should be fairly straightforward to extend beyond that.
123 struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
125 
126  LogicalResult matchAndRewrite(vector::GatherOp op,
127  PatternRewriter &rewriter) const override {
128  Value base = op.getBase();
129 
130  // TODO: Strided accesses might be coming from other ops as well
131  auto subview = base.getDefiningOp<memref::SubViewOp>();
132  if (!subview)
133  return failure();
134 
135  auto sourceType = subview.getSource().getType();
136 
137  // TODO: Allow ranks > 2.
138  if (sourceType.getRank() != 2)
139  return failure();
140 
141  // Get strides
142  auto layout = subview.getResult().getType().getLayout();
143  auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
144  if (!stridedLayoutAttr)
145  return failure();
146 
147  // TODO: Allow the access to be strided in multiple dimensions.
148  if (stridedLayoutAttr.getStrides().size() != 1)
149  return failure();
150 
151  int64_t srcTrailingDim = sourceType.getShape().back();
152 
153  // Assume that the stride matches the trailing dimension of the source
154  // memref.
155  // TODO: Relax this assumption.
156  if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
157  return failure();
158 
159  // 1. Collapse the input memref so that it's "flat".
160  SmallVector<ReassociationIndices> reassoc = {{0, 1}};
161  Value collapsed = rewriter.create<memref::CollapseShapeOp>(
162  op.getLoc(), subview.getSource(), reassoc);
163 
164  // 2. Generate new gather indices that will model the
165  // strided access.
166  IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
167  VectorType vType = op.getIndexVec().getType();
168  Value mulCst = rewriter.create<arith::ConstantOp>(
169  op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
170 
171  Value newIdxs =
172  rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
173 
174  // 3. Create an updated gather op with the collapsed input memref and the
175  // updated indices.
176  Value newGather = rewriter.create<vector::GatherOp>(
177  op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
178  newIdxs, op.getMask(), op.getPassThru());
179  rewriter.replaceOp(op, newGather);
180 
181  return success();
182  }
183 };
184 
185 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
186 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
187 /// loads/extracts are made conditional using `scf.if` ops.
188 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
190 
191  LogicalResult matchAndRewrite(vector::GatherOp op,
192  PatternRewriter &rewriter) const override {
193  VectorType resultTy = op.getType();
194  if (resultTy.getRank() != 1)
195  return rewriter.notifyMatchFailure(op, "unsupported rank");
196 
197  if (resultTy.isScalable())
198  return rewriter.notifyMatchFailure(op, "not a fixed-width vector");
199 
200  Location loc = op.getLoc();
201  Type elemTy = resultTy.getElementType();
202  // Vector type with a single element. Used to generate `vector.loads`.
203  VectorType elemVecTy = VectorType::get({1}, elemTy);
204 
205  Value condMask = op.getMask();
206  Value base = op.getBase();
207 
208  // vector.load requires the most minor memref dim to have unit stride
209  if (auto memType = dyn_cast<MemRefType>(base.getType())) {
210  if (auto stridesAttr =
211  dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
212  if (stridesAttr.getStrides().back() != 1)
213  return failure();
214  }
215  }
216 
217  Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
218  loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
219  op.getIndexVec());
220  auto baseOffsets = llvm::to_vector(op.getIndices());
221  Value lastBaseOffset = baseOffsets.back();
222 
223  Value result = op.getPassThru();
224 
225  // Emit a conditional access for each vector element.
226  for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
227  int64_t thisIdx[1] = {i};
228  Value condition =
229  rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
230  Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
231  baseOffsets.back() =
232  rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
233 
234  auto loadBuilder = [&](OpBuilder &b, Location loc) {
235  Value extracted;
236  if (isa<MemRefType>(base.getType())) {
237  // `vector.load` does not support scalar result; emit a vector load
238  // and extract the single result instead.
239  Value load =
240  b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
241  int64_t zeroIdx[1] = {0};
242  extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
243  } else {
244  extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
245  }
246 
247  Value newResult =
248  b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
249  b.create<scf::YieldOp>(loc, newResult);
250  };
251  auto passThruBuilder = [result](OpBuilder &b, Location loc) {
252  b.create<scf::YieldOp>(loc, result);
253  };
254 
255  result =
256  rewriter
257  .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
258  /*elseBuilder=*/passThruBuilder)
259  .getResult(0);
260  }
261 
262  rewriter.replaceOp(op, result);
263  return success();
264  }
265 };
266 } // namespace
267 
270  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
271  Gather1DToConditionalLoads>(patterns.getContext(), benefit);
272 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
IndexType getIndexType()
Definition: Builders.cpp:95
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:342
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362