//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"

#define DEBUG_TYPE "vector-gather-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
///        ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-D gather ops.
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value indexVec = op.getIndices();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

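    // `unrollVectorOp` drives the iteration over the leading dimension; this
    // callback only needs to build the sub-gather for a single outermost
    // position `index`, from matching slices of the indices, mask, and
    // pass-through operands.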
    auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
                              VectorType subTy, int64_t index) {
      int64_t thisIdx[1] = {index};

      Value indexSubVec =
          vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      Value maskSubVec =
          vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
      Value passThruSubVec =
          vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
      return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
                                      op.getOffsets(), indexSubVec, maskSubVec,
                                      passThruSubVec, op.getAlignmentAttr());
    };

    return unrollVectorOp(op, rewriter, unrollGatherFn);
  }
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated offsets/indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M[%i, %j] [100, 1] [1, 1]
///     : memref<100x3xf32> to memref<100xf32, strided<[3], offset: ?>>
/// %gather = vector.gather %subview[%c0] [%idxs] (...)
///     : memref<100xf32, strided<[3], offset: ?>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
///     : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %new_off = arith.addi %c0_scaled, %subview_offset : index
/// %gather = vector.gather %collapse_shape[%new_off] [%new_idxs] (...)
///     : memref<300xf32> (...)
/// ```
///
/// The subview's static offset (the linearized position of the first element
/// in the source memref) must be folded into the gather's base offsets, so a
/// subview that selects e.g. column `j_sub` of a row-major `MxN` memref still
/// reads from `M_base + j_sub + idx * N` instead of `M_base + idx * N`.
///
/// At the moment this is effectively limited to reading a 1-D vector from a
/// 2-D MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    Value base = op.getBase();

    // TODO: Strided accesses might be coming from other ops as well.
    auto subview = base.getDefiningOp<memref::SubViewOp>();
    if (!subview)
      return failure();

    auto sourceType = subview.getSource().getType();

    // TODO: Allow ranks > 2.
    if (sourceType.getRank() != 2)
      return failure();

    // Get strides.
    auto layout = subview.getResult().getType().getLayout();
    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayoutAttr)
      return failure();

    // TODO: Allow the access to be strided in multiple dimensions.
    if (stridedLayoutAttr.getStrides().size() != 1)
      return failure();

    int64_t srcTrailingDim = sourceType.getShape().back();

    // Assume that the stride matches the trailing dimension of the source
    // memref.
    // TODO: Relax this assumption.
    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
      return failure();

    // The result memref's offset is the linearized position of the subview's
    // first element within the source memref. Bail out on dynamic offsets so
    // we don't have to materialize them; the conditional-load fallback will
    // still produce correct code.
    // TODO: Support dynamic offsets.
    int64_t subviewOffset = stridedLayoutAttr.getOffset();
    if (ShapedType::isDynamic(subviewOffset))
      return failure();

    // 1. Collapse the input memref so that it's "flat".
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value collapsed = memref::CollapseShapeOp::create(
        rewriter, op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that will model the strided access.
    // Take `memref<4xf32, strided<[3], offset: 1>>` and lane k as an example.
    // For the rewrite to be correct, the flat positions must match:
    //    new_off + new_idxs[k] = 1 + (base_off + idxs[k]) * 3
    //                          = 1 + base_off * 3 + idxs[k] * 3
    // So the new indices are the old indices scaled by the stride, e.g. with
    // stride 3, idxs = [0, 2] becomes new_idxs = [0, 6].
    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
    VectorType vType = op.getIndices().getType();
    Value mulCst = arith::ConstantOp::create(
        rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
    Value newIdxs =
        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);

    // 3. Linearize the gather's base offsets through the source memref. On the
    // collapsed memref the trailing offset must be scaled by the source's
    // trailing dim and shifted by the subview's static offset.
    // Pick new_idxs[k] = idxs[k] * 3 (that's step 2), and solve for new_off:
    //    new_off = 1 + base_off * 3
    //            = subview_offset + base_off * stride
    // Note that createOrFold collapses the muli/addi when the trailing offset
    // is a constant zero or the subview offset is zero.
    SmallVector<Value> newOffsets(op.getOffsets());
    Value strideVal =
        arith::ConstantIndexOp::create(rewriter, op.getLoc(), srcTrailingDim);
    newOffsets.back() = rewriter.createOrFold<arith::MulIOp>(
        op.getLoc(), newOffsets.back(), strideVal);
    Value subviewOffsetValue =
        arith::ConstantIndexOp::create(rewriter, op.getLoc(), subviewOffset);
    newOffsets.back() = rewriter.createOrFold<arith::AddIOp>(
        op.getLoc(), newOffsets.back(), subviewOffsetValue);

    // 4. Create an updated gather op with the collapsed input memref and the
    // updated offsets/indices.
    Value newGather = vector::GatherOp::create(
        rewriter, op.getLoc(), op.getResult().getType(), collapsed, newOffsets,
        newIdxs, op.getMask(), op.getPassThru(), op.getAlignmentAttr());
    rewriter.replaceOp(op, newGather);

    return success();
  }
};

/// Turns a 1-D `vector.gather` into a scalarized sequence of `vector.load`s
/// or `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
///
/// For multi-dimensional memrefs (rank > 1), the gather index is combined
/// with the offsets via linearize-then-delinearize to produce correct
/// N-D load indices:
///   idx = indices[i]
///   flatIdx = linearize(offsets, memrefShape) + idx
///   loadIndices = delinearize(flatIdx, memrefShape)
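///
/// As an illustrative sketch (an assumption about the exact output, with
/// names and types abbreviated), lane 0 of a gather from a rank-1 memref
/// becomes a mask-guarded load:
/// ```mlir
/// %m0 = vector.extract %mask[0] : i1 from vector<2xi1>
/// %r0 = scf.if %m0 -> (vector<2xf32>) {
///   %pos = arith.addi %offset, %idx0 : index
///   %ld = vector.load %base[%pos] : memref<16xf32>, vector<1xf32>
///   %e = vector.extract %ld[0] : f32 from vector<1xf32>
///   %v = vector.insert %e, %pass_thru [0] : f32 into vector<2xf32>
///   scf.yield %v : vector<2xf32>
/// } else {
///   scf.yield %pass_thru : vector<2xf32>
/// }
/// ```
/// Each subsequent lane chains on the previous `scf.if` result rather than on
/// `%pass_thru`.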
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    if (resultTy.isScalable())
      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load`s.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();

    // For multi-dimensional memrefs, use linearize+delinearize to compute
    // correct N-D load indices from the 1-D gather index.
    bool useDelinearization = false;
    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
      // vector.load requires the most minor memref dim to have unit stride
      // (unless reading exactly 1 element).
      if (auto stridesAttr =
              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
        if (stridesAttr.getStrides().back() != 1 &&
            resultTy.getNumElements() != 1)
          return rewriter.notifyMatchFailure(
              op, "most minor memref dim must have unit stride");
      }

      if (memType.getRank() > 1)
        useDelinearization = true;
    }

    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndices());
    auto loadOffsets = llvm::to_vector(op.getOffsets());
    Value lastLoadOffset = loadOffsets.back();

    // Compute the memref shape and linearized offsets once, outside the
    // per-element loop.
    SmallVector<OpFoldResult> baseShape;
    Value linearizedOffsets;
    if (useDelinearization) {
      baseShape = memref::getMixedSizes(rewriter, loc, base);
      linearizedOffsets = affine::AffineLinearizeIndexOp::create(
          rewriter, loc, loadOffsets, baseShape, /*disjoint=*/false);
    }

    Value result = op.getPassThru();
    BoolAttr nontemporalAttr = nullptr;
    IntegerAttr alignmentAttr = op.getAlignmentAttr();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
      Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);

      if (useDelinearization) {
        // The gather index offsets the innermost dimension. Combine it with
        // the offsets by linearizing, adding the gather index, then
        // delinearizing back to N-D indices:
        //   flatIdx = linearize(offsets, shape) + idx
        //   loadIndices = delinearize(flatIdx, shape)
        Value flatIdx =
            rewriter.createOrFold<arith::AddIOp>(loc, linearizedOffsets, index);
        auto delinOp = affine::AffineDelinearizeIndexOp::create(
            rewriter, loc, flatIdx, baseShape, /*hasOuterBound=*/true);
        for (int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
          loadOffsets[d] = delinOp.getResult(d);
      } else {
        loadOffsets.back() =
            rewriter.createOrFold<arith::AddIOp>(loc, lastLoadOffset, index);
      }

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar results; emit a one-element
          // vector load and extract its single element instead.
          Value load =
              vector::LoadOp::create(b, loc, elemVecTy, base, loadOffsets,
                                     nontemporalAttr, alignmentAttr);
          int64_t zeroIdx[1] = {0};
          extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
        } else {
          extracted = tensor::ExtractOp::create(b, loc, base, loadOffsets);
        }

        Value newResult =
            vector::InsertOp::create(b, loc, extracted, result, thisIdx);
        scf::YieldOp::create(b, loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        scf::YieldOp::create(b, loc, result);
      };

      result = scf::IfOp::create(rewriter, loc, condition,
                                 /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
                   .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace
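
// Illustrative usage sketch (an assumption, not part of this file): a pass
// would typically hand these patterns to the greedy rewrite driver, e.g.:
//   RewritePatternSet patterns(&getContext());
//   populateVectorGatherLoweringPatterns(patterns);
//   populateVectorGatherToConditionalLoadPatterns(patterns);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));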

void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollGather>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
      patterns.getContext(), benefit);
}