MLIR 22.0.0git
LowerVectorGather.cpp
Go to the documentation of this file.
1//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.gather' operation.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/Location.h"
27
// Debug category name for this file (LLVM `DEBUG_TYPE` convention). It names
// the gather lowering rather than the broadcast lowering it was copied from.
#define DEBUG_TYPE "vector-gather-lowering"
29
30using namespace mlir;
31using namespace mlir::vector;
32
33namespace {
34/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
35/// outermost dimension. For example:
36/// ```
37/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
38/// ... into vector<2x3xf32>
39///
40/// ==>
41///
42/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
43/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
44/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
45/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
46/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
47/// ```
48///
49/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
50///
51/// Supports vector types with a fixed leading dimension.
52struct UnrollGather : OpRewritePattern<vector::GatherOp> {
53 using Base::Base;
54
55 LogicalResult matchAndRewrite(vector::GatherOp op,
56 PatternRewriter &rewriter) const override {
57 Value indexVec = op.getIndices();
58 Value maskVec = op.getMask();
59 Value passThruVec = op.getPassThru();
60
61 auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
62 VectorType subTy, int64_t index) {
63 int64_t thisIdx[1] = {index};
64
65 Value indexSubVec =
66 vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
67 Value maskSubVec =
68 vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
69 Value passThruSubVec =
70 vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
71 return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
72 op.getOffsets(), indexSubVec, maskSubVec,
73 passThruSubVec, op.getAlignmentAttr());
74 };
75
76 return unrollVectorOp(op, rewriter, unrollGatherFn);
77 }
78};
79
80/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
81/// MemRef with updated indices that model the strided access.
82///
83/// ```mlir
84/// %subview = memref.subview %M (...)
85/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
86/// %gather = vector.gather %subview[%idxs] (...)
87/// : memref<100xf32, strided<[3]>>
88/// ```
89/// ==>
90/// ```mlir
91/// %collapse_shape = memref.collapse_shape %M (...)
92/// : memref<100x3xf32> into memref<300xf32>
93/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
94/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
95/// : memref<300xf32> (...)
96/// ```
97///
98/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
99/// but should be fairly straightforward to extend beyond that.
100struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
101 using Base::Base;
102
103 LogicalResult matchAndRewrite(vector::GatherOp op,
104 PatternRewriter &rewriter) const override {
105 Value base = op.getBase();
106
107 // TODO: Strided accesses might be coming from other ops as well
108 auto subview = base.getDefiningOp<memref::SubViewOp>();
109 if (!subview)
110 return failure();
111
112 auto sourceType = subview.getSource().getType();
113
114 // TODO: Allow ranks > 2.
115 if (sourceType.getRank() != 2)
116 return failure();
117
118 // Get strides
119 auto layout = subview.getResult().getType().getLayout();
120 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
121 if (!stridedLayoutAttr)
122 return failure();
123
124 // TODO: Allow the access to be strided in multiple dimensions.
125 if (stridedLayoutAttr.getStrides().size() != 1)
126 return failure();
127
128 int64_t srcTrailingDim = sourceType.getShape().back();
129
130 // Assume that the stride matches the trailing dimension of the source
131 // memref.
132 // TODO: Relax this assumption.
133 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
134 return failure();
135
136 // 1. Collapse the input memref so that it's "flat".
137 SmallVector<ReassociationIndices> reassoc = {{0, 1}};
138 Value collapsed = memref::CollapseShapeOp::create(
139 rewriter, op.getLoc(), subview.getSource(), reassoc);
140
141 // 2. Generate new gather indices that will model the
142 // strided access.
143 IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
144 VectorType vType = op.getIndices().getType();
145 Value mulCst = arith::ConstantOp::create(
146 rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
147
148 Value newIdxs =
149 arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
150
151 // 3. Create an updated gather op with the collapsed input memref and the
152 // updated indices.
153 Value newGather = vector::GatherOp::create(
154 rewriter, op.getLoc(), op.getResult().getType(), collapsed,
155 op.getOffsets(), newIdxs, op.getMask(), op.getPassThru(),
156 op.getAlignmentAttr());
157 rewriter.replaceOp(op, newGather);
158
159 return success();
160 }
161};
162
163/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
164/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
165/// loads/extracts are made conditional using `scf.if` ops.
166struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
167 using Base::Base;
168
169 LogicalResult matchAndRewrite(vector::GatherOp op,
170 PatternRewriter &rewriter) const override {
171 VectorType resultTy = op.getType();
172 if (resultTy.getRank() != 1)
173 return rewriter.notifyMatchFailure(op, "unsupported rank");
174
175 if (resultTy.isScalable())
176 return rewriter.notifyMatchFailure(op, "not a fixed-width vector");
177
178 Location loc = op.getLoc();
179 Type elemTy = resultTy.getElementType();
180 // Vector type with a single element. Used to generate `vector.loads`.
181 VectorType elemVecTy = VectorType::get({1}, elemTy);
182
183 Value condMask = op.getMask();
184 Value base = op.getBase();
185
186 // vector.load requires the most minor memref dim to have unit stride
187 // (unless reading exactly 1 element)
188 if (auto memType = dyn_cast<MemRefType>(base.getType())) {
189 if (auto stridesAttr =
190 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
191 if (stridesAttr.getStrides().back() != 1 &&
192 resultTy.getNumElements() != 1)
193 return failure();
194 }
195 }
196
197 Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
198 loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
199 op.getIndices());
200 auto baseOffsets = llvm::to_vector(op.getOffsets());
201 Value lastBaseOffset = baseOffsets.back();
202
203 Value result = op.getPassThru();
204 BoolAttr nontemporalAttr = nullptr;
205 IntegerAttr alignmentAttr = op.getAlignmentAttr();
206
207 // Emit a conditional access for each vector element.
208 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
209 int64_t thisIdx[1] = {i};
210 Value condition =
211 vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
212 Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
213 baseOffsets.back() =
214 rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
215
216 auto loadBuilder = [&](OpBuilder &b, Location loc) {
217 Value extracted;
218 if (isa<MemRefType>(base.getType())) {
219 // `vector.load` does not support scalar result; emit a vector load
220 // and extract the single result instead.
221 Value load =
222 vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets,
223 nontemporalAttr, alignmentAttr);
224 int64_t zeroIdx[1] = {0};
225 extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
226 } else {
227 extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
228 }
229
230 Value newResult =
231 vector::InsertOp::create(b, loc, extracted, result, thisIdx);
232 scf::YieldOp::create(b, loc, newResult);
233 };
234 auto passThruBuilder = [result](OpBuilder &b, Location loc) {
235 scf::YieldOp::create(b, loc, result);
236 };
237
238 result = scf::IfOp::create(rewriter, loc, condition,
239 /*thenBuilder=*/loadBuilder,
240 /*elseBuilder=*/passThruBuilder)
241 .getResult(0);
242 }
243
244 rewriter.replaceOp(op, result);
245 return success();
246 }
247};
248} // namespace
249
252 patterns.add<UnrollGather>(patterns.getContext(), benefit);
253}
254
257 patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
258 patterns.getContext(), benefit);
259}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...