//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include <numeric>

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;

/// Perform the in-place update
///     rhs <- lhs + rhs
///
/// where `rhs` is a number expressed in mixed base `base` with most
/// significant dimensions on the left. For example if `rhs` is {a,b,c} and
/// `base` is {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
///
/// Some examples where `base` is {5,3,2}:
///   rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
///   rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
///   rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
///
/// Invalid:
///   rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
///
/// Overflow is not handled: the carry out of the most significant dimension is
/// dropped, so the result wraps modulo 5*3*2 = 30:
///   rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,1} (29 + 2 = 31 wraps to 1)
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
                       MutableArrayRef<int64_t> rhs) {

  // For dimensions in [rhs.size() - 1, ..., 3, 2, 1, 0]:
  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
    int64_t dimBase = base[dim];
    assert(rhs[dim] < dimBase && "rhs not in base");

    int64_t incremented = rhs[dim] + lhs;

    // If the incremented value exceeds the dimension base, we must spill to
    // the next most significant dimension and repeat (we might need to spill
    // to more significant dimensions multiple times).
    lhs = incremented / dimBase;
    rhs[dim] = incremented % dimBase;
    if (lhs == 0)
      break;
  }
}
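
// A hypothetical usage sketch (not part of this file's logic): `inplaceAdd`
// acts as a mixed-base counter, so incrementing by 1 in a loop steps through
// multi-dimensional indices in row-major order. For base {2,3} it visits
// {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}:
//
//   SmallVector<int64_t> idx(2, 0);
//   const int64_t shape[2] = {2, 3};
//   for (int i = 0; i < 6; ++i) {
//     // ... use `idx` as an extract/insert position ...
//     inplaceAdd(1, shape, idx);
//   }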

namespace {

/// shape_cast is converted to a sequence of extract, extract_strided_slice,
/// insert_strided_slice, and insert operations. The running example will be:
///
///  %0 = vector.shape_cast %arg0 :
///                  vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///
/// In this example the source and result shapes share a common suffix of 7x11.
/// This means we can always decompose the shape_cast into extract, insert, and
/// their strided equivalents, on vectors with shape suffix 7x11.
///
/// The greatest common divisor (gcd) of the dimensions immediately preceding
/// the common suffix is gcd(4,6) = 2. The algorithm implemented here will
/// operate on vectors with shapes that are `multiples` of (what we define as)
/// the 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
///
///  vector<2x2x3x4x7x11xi8> to
///  vector<8x6x7x11xi8>
///           | ||||
///           | ++++----------> common suffix of 7x11
///           +---------------> gcd(4,6) is 2
///
///           atomic shape <--- gcd x common-suffix = 2x7x11
///
///
/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
///  (1) Extract vectors from the suffix of the source.
///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
///  (2) Do extract_strided_slice down to the atomic shape.
///      In our example this is 4x7x11 -> 2x7x11.
///
///  (3) Do insert_strided_slice to the suffix of the result.
///      In our example this is 2x7x11 -> 6x7x11.
///
///  (4) Insert these vectors into the result vector.
///      In our example this is 6x7x11 -> 8x6x7x11.
///
/// These steps occur with different periods. In this example
///  (1) occurs 12 times,
///  (2) and (3) occur 24 times, and
///  (4) occurs 8 times.
///
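/// These counts follow from the atomic-slice arithmetic: there are
/// 2*2*3*4*7*11 / (2*7*11) = 24 atomic slices in total; step (1) runs once
/// every extractPeriod = 4/2 = 2 slices, i.e. 12 times, and step (4) runs
/// once every insertPeriod = 6/2 = 3 slices, i.e. 8 times.
///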
/// Two special cases are handled independently in this pattern:
///  (i)  A shape_cast that just does leading 1 insertion/removal.
///  (ii) A shape_cast where the gcd is 1.
///
/// These two cases can have more compact IR generated by not using the generic
/// algorithm described above.
///
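/// As an illustrative sketch of the generic algorithm (approximate syntax,
/// with trivially foldable ops omitted), a small cast such as
///
///  %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
///
/// has an empty common suffix and gcd(4,2) = 2, so the atomic shape is 2, and
/// it decomposes to roughly:
///
///  %p  = ub.poison : vector<2x2xi8>
///  %s0 = vector.extract_strided_slice %arg0
///          {offsets = [0], sizes = [2], strides = [1]}
///          : vector<4xi8> to vector<2xi8>
///  %r0 = vector.insert %s0, %p [0] : vector<2xi8> into vector<2x2xi8>
///  %s1 = vector.extract_strided_slice %arg0
///          {offsets = [2], sizes = [2], strides = [1]}
///          : vector<4xi8> to vector<2xi8>
///  %r1 = vector.insert %s1, %r0 [1] : vector<2xi8> into vector<2x2xi8>
///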
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  // Case (i) of description.
  // Assumes source and result shapes are identical up to some leading ones.
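  //
  // A hedged example of this case:
  //   %r = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
  // has no leading ones to drop from the source, so (once the zero-index
  // extract folds away) it lowers to roughly:
  //   %p = ub.poison : vector<1x1x4xi8>
  //   %r = vector.insert %arg0, %p [0, 0] : vector<4xi8> into vector<1x1x4xi8>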
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value extracted = vector::ExtractOp::create(
        rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result =
        vector::InsertOp::create(rewriter, loc, extracted, poison,
                                 SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }

  // Case (ii) of description.
  // Assumes a shape_cast where the suffix shape of the source starting at
  // `sourceDim` and the suffix shape of the result starting at `resultDim` are
  // identical.
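  //
  // A hedged example of this case:
  //   %r = vector.shape_cast %arg0 : vector<2x3xi8> to vector<3x2xi8>
  // has gcd(3,2) = 1, so every slice is a single scalar, and the lowering is
  // roughly:
  //   %p  = ub.poison : vector<3x2xi8>
  //   %e0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
  //   %r0 = vector.insert %e0, %p [0, 0] : i8 into vector<3x2xi8>
  //   %e1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
  //   %r1 = vector.insert %e1, %r0 [0, 1] : i8 into vector<3x2xi8>
  //   ... (6 extract/insert pairs in total: extract indices count in base
  //   (2,3), insert indices count in base (3,2))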
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim));
    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);

    for (int i = 0; i < nSlices; ++i) {
      Value extracted =
          vector::ExtractOp::create(rewriter, loc, source, extractIndex);

      result = vector::InsertOp::create(rewriter, loc, extracted, result,
                                        insertIndex);

      inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
      inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }

public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Find the first dimension, counting from the end, at which the source
    // and result dimension sizes differ. Using the running example:
    //
    //    dimensions:  [0 1 2 3 4 5 ]     [0 1 2 3 ]
    //    shapes:      (2,2,3,4,7,11)  -> (8,6,7,11)
    //                        ^                ^
    //                        |                |
    //    sourceSuffixStartDim is 3            |
    //                                         |
    //               resultSuffixStartDim is 1
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is the case (i) where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;

    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    const int64_t numAtomicElms = std::accumulate(
        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
    const size_t nAtomicSlices = numElms / numAtomicElms;

    // This is the case (ii) where the greatest common divisor is 1, so the
    // strided slices would have a leading dimension of size 1. More compact IR
    // is generated in this case if we just extract and insert the elements
    // directly. In other words, we don't use extract_strided_slice and
    // insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    const SmallVector<int64_t> strides(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value partResult =
        ub::PoisonOp::create(rewriter, loc, insertStridedType);

    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract
      if (extractStridedPhase == 0) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, source, extractIndex);
        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
                   extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = vector::ExtractStridedSliceOp::create(
          rewriter, loc, extracted, extractOffsets, atomicShape, strides);

      // vector.insert_strided_slice
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = vector::InsertStridedSliceOp::create(
          rewriter, loc, extractedStrided, insertedSlice, insertOffsets,
          strides);

      // vector.insert
      if (insertStridedPhase + 1 == insertPeriod) {
        result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
                                          insertIndex);
        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
                   insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't
    // be any time soon). There are also (currently) no operations that can
    // index or extract from >= 2-D scalable vectors or scalable vectors of
    // fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return rewriter.notifyMatchFailure(
          op, "trailing dims are not scalable, not handled by this pattern");
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case) that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector =
              vector::ExtractOp::create(rewriter, loc, op.getSource(),
                                        llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = vector::ScalableExtractOp::create(
            rewriter, loc, extractionVectorType, sourceSubVector,
            srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = vector::ExtractOp::create(
              rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = vector::ScalableInsertOp::create(
            rewriter, loc, sourceSubVector, currentResultScalableVector,
            resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = vector::InsertOp::create(rewriter, loc,
                                          currentResultScalableVector, result,
                                          llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

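  /// Returns true iff the trailing dimension of `type` is scalable and no
  /// other dimension is. For example (illustrative): true for
  /// vector<[4]xi32> and vector<2x[4]xi32>; false for vector<[2]x4xi32> and
  /// vector<[2]x[4]xi32>.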
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}
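
// A hypothetical usage sketch (names assumed, not from this file): these
// patterns are typically gathered into a set and applied greedily from a
// pass, e.g.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();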