//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-broadcast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value indexVec = op.getIndices();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

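    // Callback for `unrollVectorOp`: invoked once per index along the
    // unrolled leading dimension. It extracts the per-slice operands and
    // emits a rank-reduced gather for that slice.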
    auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
                              VectorType subTy, int64_t index) {
      int64_t thisIdx[1] = {index};

      Value indexSubVec =
          vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      Value maskSubVec =
          vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
      Value passThruSubVec =
          vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
      return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
                                      op.getOffsets(), indexSubVec, maskSubVec,
                                      passThruSubVec);
    };

    return unrollVectorOp(op, rewriter, unrollGatherFn);
  }
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview[%idxs] (...)
///     : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
///     : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
///     : memref<300xf32> (...)
/// ```
///
/// At the moment, this is effectively limited to reading a 1D vector from a
/// 2D MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get strides
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = memref::CollapseShapeOp::create(
        rewriter, op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that will model the
    // strided access.
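    // Each index is multiplied by the stride, which equals the trailing
    // source dimension here. E.g., with a trailing dimension of 3, indices
    // [0, 1, 2] become [0, 3, 6] in the collapsed memref.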
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndices().getType();
    Value mulCst = arith::ConstantOp::create(
        rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

    Value newIdxs =
        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);

    // 3. Create an updated gather op with the collapsed input memref and the
    // updated indices.
    Value newGather = vector::GatherOp::create(
        rewriter, op.getLoc(), op.getResult().getType(), collapsed,
        op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
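///
/// For example (a sketch only; types are abbreviated and only the first of
/// the per-element `scf.if` ops is shown):
/// ```mlir
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru
///        : ... into vector<2xf32>
///
/// ==>
///
/// %m0 = vector.extract %mask[0] : i1 from vector<2xi1>
/// %i0 = vector.extract %v[0] : index from vector<2xindex>
/// %off0 = arith.addi %c0, %i0 : index
/// %r0 = scf.if %m0 -> (vector<2xf32>) {
///   %ld = vector.load %base[%off0] : memref<?xf32>, vector<1xf32>
///   %e = vector.extract %ld[0] : f32 from vector<1xf32>
///   %ins = vector.insert %e, %pass_thru [0] : f32 into vector<2xf32>
///   scf.yield %ins : vector<2xf32>
/// } else {
///   scf.yield %pass_thru : vector<2xf32>
/// }
/// // ... and similarly for element 1, threading %r0 through.
/// ```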
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    if (resultTy.isScalable())
      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load`s.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // vector.load requires the most minor memref dim to have unit stride
    // (unless reading exactly 1 element).
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1 &&
            resultTy.getNumElements() != 1)
          return failure();
      }
    }

    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndices());
    auto baseOffsets = llvm::to_vector(op.getOffsets());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
      Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar results; emit a 1-element
          // vector load and extract the single element instead.
          Value load =
              vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets);
          int64_t zeroIdx[1] = {0};
          extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
        } else {
          extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
        }

        Value newResult =
            vector::InsertOp::create(b, loc, extracted, result, thisIdx);
        scf::YieldOp::create(b, loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        scf::YieldOp::create(b, loc, result);
      };

      result = scf::IfOp::create(rewriter, loc, condition,
                                 /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
                   .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollGather>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
      patterns.getContext(), benefit);
}
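
// Example usage (an illustrative sketch): both populate functions are
// typically run together under the greedy rewrite driver, declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h. `funcOp` stands for any
// operation whose body should be rewritten.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorGatherLoweringPatterns(patterns);
//   vector::populateVectorGatherToConditionalLoadPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));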