//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-gather-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
/// where %v0/%v1, %mask0/%mask1 and %pass_thru0/%pass_thru1 are the per-row
/// slices extracted from %v, %mask and %pass_thru, respectively.
///
/// When applied exhaustively, this produces a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value indexVec = op.getIndices();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

    auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
                              VectorType subTy, int64_t index) {
      int64_t thisIdx[1] = {index};

      Value indexSubVec =
          vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      Value maskSubVec =
          vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
      Value passThruSubVec =
          vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
      return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
                                      op.getOffsets(), indexSubVec, maskSubVec,
                                      passThruSubVec, op.getAlignmentAttr());
    };

    return unrollVectorOp(op, rewriter, unrollGatherFn);
  }
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
///   : memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview[%idxs] (...)
///   : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
///   : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
///   : memref<300xf32> (...)
/// ```
///
/// At the moment, this is effectively limited to reading a 1D vector from a
/// 2D MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well.
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get strides.
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = memref::CollapseShapeOp::create(
        rewriter, op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that will model the strided access.
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndices().getType();
    Value mulCst = arith::ConstantOp::create(
        rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

    Value newIdxs =
        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);

    // 3. Create an updated gather op with the collapsed input memref and the
    //    updated indices.
    Value newGather = vector::GatherOp::create(
        rewriter, op.getLoc(), op.getResult().getType(), collapsed,
        op.getOffsets(), newIdxs, op.getMask(), op.getPassThru(),
        op.getAlignmentAttr());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns a 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
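///
/// For example, assuming a gather of two f32 elements from a 1-D memref, the
/// rewrite produces IR roughly like the following sketch (value names are
/// illustrative, and the extract/addi/if sequence is repeated per element):
/// ```mlir
/// %mask0 = vector.extract %mask[0] : i1 from vector<2xi1>
/// %idx0  = vector.extract %idxs[0] : index from vector<2xindex>
/// %off0  = arith.addi %base_off, %idx0 : index
/// %res0 = scf.if %mask0 -> (vector<2xf32>) {
///   %ld  = vector.load %base[%off0] : memref<?xf32>, vector<1xf32>
///   %elt = vector.extract %ld[0] : f32 from vector<1xf32>
///   %ins = vector.insert %elt, %pass_thru [0] : f32 into vector<2xf32>
///   scf.yield %ins : vector<2xf32>
/// } else {
///   scf.yield %pass_thru : vector<2xf32>
/// }
/// // ... element 1 repeats the sequence, inserting into %res0 ...
/// ```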
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    if (resultTy.isScalable())
      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load`s.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // vector.load requires the most minor memref dim to have unit stride
    // (unless reading exactly 1 element).
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1 &&
            resultTy.getNumElements() != 1)
          return failure();
      }
    }

    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndices());
    auto baseOffsets = llvm::to_vector(op.getOffsets());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();
    BoolAttr nontemporalAttr = nullptr;
    IntegerAttr alignmentAttr = op.getAlignmentAttr();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
      Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar results; emit a vector load
          // and extract the single element instead.
          Value load =
              vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets,
                                     nontemporalAttr, alignmentAttr);
          int64_t zeroIdx[1] = {0};
          extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
        } else {
          extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
        }

        Value newResult =
            vector::InsertOp::create(b, loc, extracted, result, thisIdx);
        scf::YieldOp::create(b, loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        scf::YieldOp::create(b, loc, result);
      };

      result = scf::IfOp::create(rewriter, loc, condition,
                                 /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
                   .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollGather>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
      patterns.getContext(), benefit);
}
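
// Example usage (illustrative sketch): a pass would typically collect these
// patterns and run them with the greedy pattern rewrite driver declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h"; `op` here stands for
// whatever root operation the pass is applied to.
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorGatherLoweringPatterns(patterns);
//   vector::populateVectorGatherToConditionalLoadPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     return signalPassFailure();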