MLIR 23.0.0git
LowerVectorGather.cpp
Go to the documentation of this file.
1//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.gather' operation.
11//
12//===----------------------------------------------------------------------===//
13
25#include "mlir/IR/Location.h"
28
// Debug channel for this file (`-debug-only=vector-gather-lowering`). It was
// previously a copy of the broadcast lowering's channel name.
#define DEBUG_TYPE "vector-gather-lowering"
30
31using namespace mlir;
32using namespace mlir::vector;
33
34namespace {
35/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
36/// outermost dimension. For example:
37/// ```
38/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
39/// ... into vector<2x3xf32>
40///
41/// ==>
42///
43/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
44/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
45/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
46/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
47/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
48/// ```
49///
50/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
51///
52/// Supports vector types with a fixed leading dimension.
53struct UnrollGather : OpRewritePattern<vector::GatherOp> {
54 using Base::Base;
55
56 LogicalResult matchAndRewrite(vector::GatherOp op,
57 PatternRewriter &rewriter) const override {
58 Value indexVec = op.getIndices();
59 Value maskVec = op.getMask();
60 Value passThruVec = op.getPassThru();
61
62 auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
63 VectorType subTy, int64_t index) {
64 int64_t thisIdx[1] = {index};
65
66 Value indexSubVec =
67 vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
68 Value maskSubVec =
69 vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
70 Value passThruSubVec =
71 vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
72 return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
73 op.getOffsets(), indexSubVec, maskSubVec,
74 passThruSubVec, op.getAlignmentAttr());
75 };
76
77 return unrollVectorOp(op, rewriter, unrollGatherFn);
78 }
79};
80
81/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
82/// MemRef with updated indices that model the strided access.
83///
84/// ```mlir
85/// %subview = memref.subview %M (...)
86/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
87/// %gather = vector.gather %subview[%idxs] (...)
88/// : memref<100xf32, strided<[3]>>
89/// ```
90/// ==>
91/// ```mlir
92/// %collapse_shape = memref.collapse_shape %M (...)
93/// : memref<100x3xf32> into memref<300xf32>
94/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
95/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
96/// : memref<300xf32> (...)
97/// ```
98///
99/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
100/// but should be fairly straightforward to extend beyond that.
101struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
102 using Base::Base;
103
104 LogicalResult matchAndRewrite(vector::GatherOp op,
105 PatternRewriter &rewriter) const override {
106 Value base = op.getBase();
107
108 // TODO: Strided accesses might be coming from other ops as well
109 auto subview = base.getDefiningOp<memref::SubViewOp>();
110 if (!subview)
111 return failure();
112
113 auto sourceType = subview.getSource().getType();
114
115 // TODO: Allow ranks > 2.
116 if (sourceType.getRank() != 2)
117 return failure();
118
119 // Get strides
120 auto layout = subview.getResult().getType().getLayout();
121 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
122 if (!stridedLayoutAttr)
123 return failure();
124
125 // TODO: Allow the access to be strided in multiple dimensions.
126 if (stridedLayoutAttr.getStrides().size() != 1)
127 return failure();
128
129 int64_t srcTrailingDim = sourceType.getShape().back();
130
131 // Assume that the stride matches the trailing dimension of the source
132 // memref.
133 // TODO: Relax this assumption.
134 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
135 return failure();
136
137 // 1. Collapse the input memref so that it's "flat".
138 SmallVector<ReassociationIndices> reassoc = {{0, 1}};
139 Value collapsed = memref::CollapseShapeOp::create(
140 rewriter, op.getLoc(), subview.getSource(), reassoc);
141
142 // 2. Generate new gather indices that will model the
143 // strided access.
144 IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
145 VectorType vType = op.getIndices().getType();
146 Value mulCst = arith::ConstantOp::create(
147 rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
148
149 Value newIdxs =
150 arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
151
152 // 3. Create an updated gather op with the collapsed input memref and the
153 // updated indices.
154 Value newGather = vector::GatherOp::create(
155 rewriter, op.getLoc(), op.getResult().getType(), collapsed,
156 op.getOffsets(), newIdxs, op.getMask(), op.getPassThru(),
157 op.getAlignmentAttr());
158 rewriter.replaceOp(op, newGather);
159
160 return success();
161 }
162};
163
164/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
165/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
166/// loads/extracts are made conditional using `scf.if` ops.
167///
168/// For multi-dimensional memrefs (rank > 1), the gather index is combined
169/// with the offsets via linearize-then-delinearize to produce correct
170/// N-D load indices:
171/// idx = indices[i]
172/// flatIdx = linearize(offsets, memrefShape) + idx
173/// loadIndices = delinearize(flatIdx, memrefShape)
174struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
175 using Base::Base;
176
177 LogicalResult matchAndRewrite(vector::GatherOp op,
178 PatternRewriter &rewriter) const override {
179 VectorType resultTy = op.getType();
180 if (resultTy.getRank() != 1)
181 return rewriter.notifyMatchFailure(op, "unsupported rank");
182
183 if (resultTy.isScalable())
184 return rewriter.notifyMatchFailure(op, "not a fixed-width vector");
185
186 Location loc = op.getLoc();
187 Type elemTy = resultTy.getElementType();
188 // Vector type with a single element. Used to generate `vector.loads`.
189 VectorType elemVecTy = VectorType::get({1}, elemTy);
190
191 Value condMask = op.getMask();
192 Value base = op.getBase();
193
194 // For multi-dimensional memrefs, use linearize+delinearize to compute
195 // correct N-D load indices from the 1-D gather index.
196 bool useDelinearization = false;
197 if (auto memType = dyn_cast<MemRefType>(base.getType())) {
198 // vector.load requires the most minor memref dim to have unit stride
199 // (unless reading exactly 1 element).
200 if (auto stridesAttr =
201 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
202 if (stridesAttr.getStrides().back() != 1 &&
203 resultTy.getNumElements() != 1)
204 return rewriter.notifyMatchFailure(
205 op, "most minor memref dim must have unit stride");
206 }
207
208 if (memType.getRank() > 1)
209 useDelinearization = true;
210 }
211
212 Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
213 loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
214 op.getIndices());
215 auto loadOffsets = llvm::to_vector(op.getOffsets());
216 Value lastLoadOffset = loadOffsets.back();
217
218 // Compute the memref shape and linearized offsets once, outside the
219 // per-element loop.
221 Value linearizedOffsets;
222 if (useDelinearization) {
223 baseShape = memref::getMixedSizes(rewriter, loc, base);
224 linearizedOffsets = affine::AffineLinearizeIndexOp::create(
225 rewriter, loc, loadOffsets, baseShape, /*disjoint=*/false);
226 }
227
228 Value result = op.getPassThru();
229 BoolAttr nontemporalAttr = nullptr;
230 IntegerAttr alignmentAttr = op.getAlignmentAttr();
231
232 // Emit a conditional access for each vector element.
233 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
234 int64_t thisIdx[1] = {i};
235 Value condition =
236 vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
237 Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
238
239 if (useDelinearization) {
240 // The gather index offsets the innermost dimension. Combine with
241 // the offsets by linearizing, adding the gather index, then
242 // delinearizing back to N-D indices:
243 // flatIdx = linearize(offsets, shape) + idx
244 // loadIndices = delinearize(flatIdx, shape)
245 Value flatIdx =
246 rewriter.createOrFold<arith::AddIOp>(loc, linearizedOffsets, index);
247 auto delinOp = affine::AffineDelinearizeIndexOp::create(
248 rewriter, loc, flatIdx, baseShape, /*hasOuterBound=*/true);
249 for (int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
250 loadOffsets[d] = delinOp.getResult(d);
251 } else {
252 loadOffsets.back() =
253 rewriter.createOrFold<arith::AddIOp>(loc, lastLoadOffset, index);
254 }
255
256 auto loadBuilder = [&](OpBuilder &b, Location loc) {
257 Value extracted;
258 if (isa<MemRefType>(base.getType())) {
259 // `vector.load` does not support scalar result; emit a vector load
260 // and extract the single result instead.
261 Value load =
262 vector::LoadOp::create(b, loc, elemVecTy, base, loadOffsets,
263 nontemporalAttr, alignmentAttr);
264 int64_t zeroIdx[1] = {0};
265 extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
266 } else {
267 extracted = tensor::ExtractOp::create(b, loc, base, loadOffsets);
268 }
269
270 Value newResult =
271 vector::InsertOp::create(b, loc, extracted, result, thisIdx);
272 scf::YieldOp::create(b, loc, newResult);
273 };
274 auto passThruBuilder = [result](OpBuilder &b, Location loc) {
275 scf::YieldOp::create(b, loc, result);
276 };
277
278 result = scf::IfOp::create(rewriter, loc, condition,
279 /*thenBuilder=*/loadBuilder,
280 /*elseBuilder=*/passThruBuilder)
281 .getResult(0);
282 }
283
284 rewriter.replaceOp(op, result);
285 return success();
286 }
287};
288} // namespace
289
291 RewritePatternSet &patterns, PatternBenefit benefit) {
292 patterns.add<UnrollGather>(patterns.getContext(), benefit);
293}
294
296 RewritePatternSet &patterns, PatternBenefit benefit) {
297 patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
298 patterns.getContext(), benefit);
299}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...