MLIR  19.0.0git
LowerVectorGather.cpp
Go to the documentation of this file.
1 //===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.gather' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
// Tag used by LLVM_DEBUG / `-debug-only=` for this file; it names the
// `vector.gather` lowering implemented here.
#define DEBUG_TYPE "vector-gather-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 namespace {
42 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
43 /// outermost dimension. For example:
44 /// ```
45 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
46 /// ... into vector<2x3xf32>
47 ///
48 /// ==>
49 ///
50 /// %0 = arith.constant dense<0.0> : vector<2x3xf32>
51 /// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
52 /// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
53 /// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
54 /// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
55 /// ```
56 ///
57 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
58 struct FlattenGather : OpRewritePattern<vector::GatherOp> {
60 
61  LogicalResult matchAndRewrite(vector::GatherOp op,
62  PatternRewriter &rewriter) const override {
63  VectorType resultTy = op.getType();
64  if (resultTy.getRank() < 2)
65  return rewriter.notifyMatchFailure(op, "already flat");
66 
67  Location loc = op.getLoc();
68  Value indexVec = op.getIndexVec();
69  Value maskVec = op.getMask();
70  Value passThruVec = op.getPassThru();
71 
72  Value result = rewriter.create<arith::ConstantOp>(
73  loc, resultTy, rewriter.getZeroAttr(resultTy));
74 
75  Type subTy = VectorType::get(resultTy.getShape().drop_front(),
76  resultTy.getElementType());
77 
78  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
79  int64_t thisIdx[1] = {i};
80 
81  Value indexSubVec =
82  rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
83  Value maskSubVec =
84  rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
85  Value passThruSubVec =
86  rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
87  Value subGather = rewriter.create<vector::GatherOp>(
88  loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
89  passThruSubVec);
90  result =
91  rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
92  }
93 
94  rewriter.replaceOp(op, result);
95  return success();
96  }
97 };
98 
99 /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
100 /// MemRef with updated indices that model the strided access.
101 ///
102 /// ```mlir
103 /// %subview = memref.subview %M (...)
104 /// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
105 /// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
106 /// ```
107 /// ==>
108 /// ```mlir
109 /// %collapse_shape = memref.collapse_shape %M (...)
110 /// : memref<100x3xf32> into memref<300xf32>
111 /// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
112 /// %gather = vector.gather %collapse_shape[%new_idxs] (...)
113 /// : memref<300xf32> (...)
114 /// ```
115 ///
116 /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
117 /// but should be fairly straightforward to extend beyond that.
118 struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
120 
121  LogicalResult matchAndRewrite(vector::GatherOp op,
122  PatternRewriter &rewriter) const override {
123  Value base = op.getBase();
124 
125  // TODO: Strided accesses might be coming from other ops as well
126  auto subview = base.getDefiningOp<memref::SubViewOp>();
127  if (!subview)
128  return failure();
129 
130  auto sourceType = subview.getSource().getType();
131 
132  // TODO: Allow ranks > 2.
133  if (sourceType.getRank() != 2)
134  return failure();
135 
136  // Get strides
137  auto layout = subview.getResult().getType().getLayout();
138  auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
139  if (!stridedLayoutAttr)
140  return failure();
141 
142  // TODO: Allow the access to be strided in multiple dimensions.
143  if (stridedLayoutAttr.getStrides().size() != 1)
144  return failure();
145 
146  int64_t srcTrailingDim = sourceType.getShape().back();
147 
148  // Assume that the stride matches the trailing dimension of the source
149  // memref.
150  // TODO: Relax this assumption.
151  if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
152  return failure();
153 
154  // 1. Collapse the input memref so that it's "flat".
155  SmallVector<ReassociationIndices> reassoc = {{0, 1}};
156  Value collapsed = rewriter.create<memref::CollapseShapeOp>(
157  op.getLoc(), subview.getSource(), reassoc);
158 
159  // 2. Generate new gather indices that will model the
160  // strided access.
161  IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
162  VectorType vType = op.getIndexVec().getType();
163  Value mulCst = rewriter.create<arith::ConstantOp>(
164  op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
165 
166  Value newIdxs =
167  rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
168 
169  // 3. Create an updated gather op with the collapsed input memref and the
170  // updated indices.
171  Value newGather = rewriter.create<vector::GatherOp>(
172  op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
173  newIdxs, op.getMask(), op.getPassThru());
174  rewriter.replaceOp(op, newGather);
175 
176  return success();
177  }
178 };
179 
180 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
181 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
182 /// loads/extracts are made conditional using `scf.if` ops.
183 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
185 
186  LogicalResult matchAndRewrite(vector::GatherOp op,
187  PatternRewriter &rewriter) const override {
188  VectorType resultTy = op.getType();
189  if (resultTy.getRank() != 1)
190  return rewriter.notifyMatchFailure(op, "unsupported rank");
191 
192  Location loc = op.getLoc();
193  Type elemTy = resultTy.getElementType();
194  // Vector type with a single element. Used to generate `vector.loads`.
195  VectorType elemVecTy = VectorType::get({1}, elemTy);
196 
197  Value condMask = op.getMask();
198  Value base = op.getBase();
199 
200  // vector.load requires the most minor memref dim to have unit stride
201  if (auto memType = dyn_cast<MemRefType>(base.getType())) {
202  if (auto stridesAttr =
203  dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
204  if (stridesAttr.getStrides().back() != 1)
205  return failure();
206  }
207  }
208 
209  Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
210  loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
211  op.getIndexVec());
212  auto baseOffsets = llvm::to_vector(op.getIndices());
213  Value lastBaseOffset = baseOffsets.back();
214 
215  Value result = op.getPassThru();
216 
217  // Emit a conditional access for each vector element.
218  for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
219  int64_t thisIdx[1] = {i};
220  Value condition =
221  rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
222  Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
223  baseOffsets.back() =
224  rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
225 
226  auto loadBuilder = [&](OpBuilder &b, Location loc) {
227  Value extracted;
228  if (isa<MemRefType>(base.getType())) {
229  // `vector.load` does not support scalar result; emit a vector load
230  // and extract the single result instead.
231  Value load =
232  b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
233  int64_t zeroIdx[1] = {0};
234  extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
235  } else {
236  extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
237  }
238 
239  Value newResult =
240  b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
241  b.create<scf::YieldOp>(loc, newResult);
242  };
243  auto passThruBuilder = [result](OpBuilder &b, Location loc) {
244  b.create<scf::YieldOp>(loc, result);
245  };
246 
247  result =
248  rewriter
249  .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
250  /*elseBuilder=*/passThruBuilder)
251  .getResult(0);
252  }
253 
254  rewriter.replaceOp(op, result);
255  return success();
256  }
257 };
258 } // namespace
259 
261  RewritePatternSet &patterns, PatternBenefit benefit) {
262  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
263  Gather1DToConditionalLoads>(patterns.getContext(), benefit);
264 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IndexType getIndexType()
Definition: Builders.cpp:71
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362