MLIR 22.0.0git
LowerVectorTranspose.cpp
Go to the documentation of this file.
1//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.transpose' operation.
11//
12//===----------------------------------------------------------------------===//
13
22#include "mlir/IR/Location.h"
25
26#define DEBUG_TYPE "lower-vector-transpose"
27
28using namespace mlir;
29using namespace mlir::vector;
30
31/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
32/// transposed.
35 size_t numTransposedDims = transpose.size();
36 for (size_t transpDim : llvm::reverse(transpose)) {
37 if (transpDim != numTransposedDims - 1)
38 break;
39 numTransposedDims--;
40 }
41
42 result.append(transpose.begin(), transpose.begin() + numTransposedDims);
43}
44
45/// Returns true if the lowering option is a vector shuffle based approach.
46static bool isShuffleLike(VectorTransposeLowering lowering) {
47 return lowering == VectorTransposeLowering::Shuffle1D ||
48 lowering == VectorTransposeLowering::Shuffle16x16;
49}
50
51/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
52/// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
53/// create the mask for `numBits` bits vector. The `numBits` have to be a
54/// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
55/// 512, there should be 16 elements in the final result. It constructs the
56/// below mask to get the unpack elements.
57/// [0, 1, 16, 17,
58/// 0+4, 1+4, 16+4, 17+4,
59/// 0+8, 1+8, 16+8, 17+8,
60/// 0+12, 1+12, 16+12, 17+12]
63 assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
64 int numElem = numBits / 32;
66 for (int i = 0; i < numElem; i += 4)
67 for (int64_t v : vals)
68 res.push_back(v + i);
69 return res;
70}
71
72/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
73/// example, if it is targeting 512 bit vector, returns
74/// vector.shuffle on v1, v2, [0, 1, 16, 17,
75/// 0+4, 1+4, 16+4, 17+4,
76/// 0+8, 1+8, 16+8, 17+8,
77/// 0+12, 1+12, 16+12, 17+12].
79 int numBits) {
80 int numElem = numBits / 32;
81 return vector::ShuffleOp::create(
82 b, v1, v2,
83 getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
84}
85
86/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
87/// example, if it is targeting 512 bit vector, returns
88/// vector.shuffle, v1, v2, [2, 3, 18, 19,
89/// 2+4, 3+4, 18+4, 19+4,
90/// 2+8, 3+8, 18+8, 19+8,
91/// 2+12, 3+12, 18+12, 19+12].
93 int numBits) {
94 int numElem = numBits / 32;
95 return vector::ShuffleOp::create(
96 b, v1, v2,
97 getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
98 numBits));
99}
100
101/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
102/// example, if it is targeting 512 bit vector, returns
103/// vector.shuffle, v1, v2, [0, 16, 1, 17,
104/// 0+4, 16+4, 1+4, 17+4,
105/// 0+8, 16+8, 1+8, 17+8,
106/// 0+12, 16+12, 1+12, 17+12].
108 int numBits) {
109 int numElem = numBits / 32;
110 auto shuffle = vector::ShuffleOp::create(
111 b, v1, v2,
112 getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
113 return shuffle;
114}
115
116/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
117/// example, if it is targeting 512 bit vector, returns
118/// vector.shuffle, v1, v2, [2, 18, 3, 19,
119/// 2+4, 18+4, 3+4, 19+4,
120/// 2+8, 18+8, 3+8, 19+8,
121/// 2+12, 18+12, 3+12, 19+12].
123 int numBits) {
124 int numElem = numBits / 32;
125 return vector::ShuffleOp::create(
126 b, v1, v2,
127 getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
128 numBits));
129}
130
131/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
132/// elements) selected by `mask` from `v1` and `v2`. I.e.,
133///
134/// DEFINE SELECT4(src, control) {
135/// CASE(control[1:0]) OF
136/// 0: tmp[127:0] := src[127:0]
137/// 1: tmp[127:0] := src[255:128]
138/// 2: tmp[127:0] := src[383:256]
139/// 3: tmp[127:0] := src[511:384]
140/// ESAC
141/// RETURN tmp[127:0]
142/// }
143/// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
144/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
145/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
146/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
148 uint8_t mask) {
149 assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
150 "expected a vector with length=16");
151 SmallVector<int64_t> shuffleMask;
152 auto appendToMask = [&](int64_t base, uint8_t control) {
153 switch (control) {
154 case 0:
155 llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
156 base + 2, base + 3});
157 break;
158 case 1:
159 llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
160 base + 6, base + 7});
161 break;
162 case 2:
163 llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
164 base + 10, base + 11});
165 break;
166 case 3:
167 llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
168 base + 14, base + 15});
169 break;
170 default:
171 llvm_unreachable("control > 3 : overflow");
172 }
173 };
174 uint8_t b01 = mask & 0x3;
175 uint8_t b23 = (mask >> 2) & 0x3;
176 uint8_t b45 = (mask >> 4) & 0x3;
177 uint8_t b67 = (mask >> 6) & 0x3;
178 appendToMask(0, b01);
179 appendToMask(0, b23);
180 appendToMask(16, b45);
181 appendToMask(16, b67);
182 return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
183}
184
185/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
186/// 1-D vector and have `m`x`n` elements.
187static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
189 mask.reserve(m * n);
190 for (int64_t j = 0; j < n; ++j)
191 for (int64_t i = 0; i < m; ++i)
192 mask.push_back(i * n + j);
193 return vector::ShuffleOp::create(b, source.getLoc(), source, source, mask);
194}
195
196/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
197/// expected to be a 16x16 vector.
198static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
199 int n) {
200 ImplicitLocOpBuilder b(source.getLoc(), builder);
202 for (int64_t i = 0; i < m; ++i)
203 vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
204
205 // Interleave 32-bit lanes using
206 // 8x _mm512_unpacklo_epi32
207 // 8x _mm512_unpackhi_epi32
208 Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
209 Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
210 Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
211 Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
212 Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
213 Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
214 Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
215 Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
216 Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
217 Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
218 Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
219 Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
220 Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
221 Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
222 Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
223 Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
224
225 // Interleave 64-bit lanes using
226 // 8x _mm512_unpacklo_epi64
227 // 8x _mm512_unpackhi_epi64
228 Value r0 = createUnpackLoPd(b, t0, t2, 512);
229 Value r1 = createUnpackHiPd(b, t0, t2, 512);
230 Value r2 = createUnpackLoPd(b, t1, t3, 512);
231 Value r3 = createUnpackHiPd(b, t1, t3, 512);
232 Value r4 = createUnpackLoPd(b, t4, t6, 512);
233 Value r5 = createUnpackHiPd(b, t4, t6, 512);
234 Value r6 = createUnpackLoPd(b, t5, t7, 512);
235 Value r7 = createUnpackHiPd(b, t5, t7, 512);
236 Value r8 = createUnpackLoPd(b, t8, ta, 512);
237 Value r9 = createUnpackHiPd(b, t8, ta, 512);
238 Value ra = createUnpackLoPd(b, t9, tb, 512);
239 Value rb = createUnpackHiPd(b, t9, tb, 512);
240 Value rc = createUnpackLoPd(b, tc, te, 512);
241 Value rd = createUnpackHiPd(b, tc, te, 512);
242 Value re = createUnpackLoPd(b, td, tf, 512);
243 Value rf = createUnpackHiPd(b, td, tf, 512);
244
245 // Permute 128-bit lanes using
246 // 16x _mm512_shuffle_i32x4
247 t0 = create4x128BitSuffle(b, r0, r4, 0x88);
248 t1 = create4x128BitSuffle(b, r1, r5, 0x88);
249 t2 = create4x128BitSuffle(b, r2, r6, 0x88);
250 t3 = create4x128BitSuffle(b, r3, r7, 0x88);
251 t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
252 t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
253 t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
254 t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
255 t8 = create4x128BitSuffle(b, r8, rc, 0x88);
256 t9 = create4x128BitSuffle(b, r9, rd, 0x88);
257 ta = create4x128BitSuffle(b, ra, re, 0x88);
258 tb = create4x128BitSuffle(b, rb, rf, 0x88);
259 tc = create4x128BitSuffle(b, r8, rc, 0xdd);
260 td = create4x128BitSuffle(b, r9, rd, 0xdd);
261 te = create4x128BitSuffle(b, ra, re, 0xdd);
262 tf = create4x128BitSuffle(b, rb, rf, 0xdd);
263
264 // Permute 256-bit lanes using again
265 // 16x _mm512_shuffle_i32x4
266 vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
267 vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
268 vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
269 vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
270 vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
271 vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
272 vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
273 vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
274 vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
275 vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
276 vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
277 vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
278 vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
279 vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
280 vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
281 vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
282
283 auto reshInputType = VectorType::get(
284 {m, n}, cast<VectorType>(source.getType()).getElementType());
285 Value res = ub::PoisonOp::create(b, reshInputType);
286 for (int64_t i = 0; i < m; ++i)
287 res = vector::InsertOp::create(b, vs[i], res, i);
288 return res;
289}
290
291namespace {
292/// Progressive lowering of TransposeOp.
293/// One:
294/// %x = vector.transpose %y, [1, 0]
295/// is replaced by:
296/// %z = arith.constant dense<0.000000e+00>
297/// %0 = vector.extract %y[0, 0]
298/// %1 = vector.insert %0, %z [0, 0]
299/// ..
300/// %x = vector.insert .., .. [.., ..]
301class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
302public:
303 using Base::Base;
304
305 TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
306 MLIRContext *context, PatternBenefit benefit = 1)
307 : OpRewritePattern<vector::TransposeOp>(context, benefit),
308 vectorTransposeLowering(vectorTransposeLowering) {}
309
310 LogicalResult matchAndRewrite(vector::TransposeOp op,
311 PatternRewriter &rewriter) const override {
312 auto loc = op.getLoc();
313
314 Value input = op.getVector();
315 VectorType inputType = op.getSourceVectorType();
316 VectorType resType = op.getResultVectorType();
317
318 if (inputType.isScalable())
319 return rewriter.notifyMatchFailure(
320 op, "This lowering does not support scalable vectors");
321
322 // Set up convenience transposition table.
323 ArrayRef<int64_t> transp = op.getPermutation();
324
325 if (isShuffleLike(vectorTransposeLowering) &&
326 succeeded(isTranspose2DSlice(op)))
327 return rewriter.notifyMatchFailure(
328 op, "Options specifies lowering to shuffle");
329
330 // Generate unrolled extract/insert ops. We do not unroll the rightmost
331 // (i.e., highest-order) dimensions that are not transposed and leave them
332 // in vector form to improve performance. Therefore, we prune those
333 // dimensions from the shape/transpose data structures used to generate the
334 // extract/insert ops.
335 SmallVector<int64_t> prunedTransp;
336 pruneNonTransposedDims(transp, prunedTransp);
337 size_t numPrunedDims = transp.size() - prunedTransp.size();
338 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
339 auto prunedInStrides = computeStrides(prunedInShape);
340
341 // Generates the extract/insert operations for every scalar/vector element
342 // of the leftmost transposed dimensions. We traverse every transpose
343 // element using a linearized index that we delinearize to generate the
344 // appropriate indices for the extract/insert operations.
345 Value result = ub::PoisonOp::create(rewriter, loc, resType);
346 int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
347
348 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
349 ++linearIdx) {
350 auto extractIdxs = delinearize(linearIdx, prunedInStrides);
351 SmallVector<int64_t> insertIdxs(extractIdxs);
352 applyPermutationToVector(insertIdxs, prunedTransp);
353 Value extractOp =
354 rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
355 result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
356 insertIdxs);
357 }
358
359 rewriter.replaceOp(op, result);
360 return success();
361 }
362
363private:
364 /// Options to control the vector patterns.
365 vector::VectorTransposeLowering vectorTransposeLowering;
366};
367
368/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
369/// to 2D vectors with at least one unit dim. For example:
370///
371/// Replace:
372/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
373/// vector<1x4xi32>
374/// with:
375/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
376///
377/// Source with leading unit dim (inverse) is also replaced. Unit dim must
378/// be fixed. Non-unit dim can be scalable.
379///
380/// TODO: This pattern was introduced specifically to help lower scalable
381/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
382/// to cancel out) would be preferable:
383///
384/// BEFORE:
385/// %0 = some_op
386/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
387/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
388/// AFTER:
389/// %0 = some_op
390/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
391///
392/// Given the context above, we may want to consider (re-)moving this pattern
393/// at some later time. I am leaving it for now in case there are other users
394/// that I am not aware of.
395class Transpose2DWithUnitDimToShapeCast
396 : public OpRewritePattern<vector::TransposeOp> {
397public:
398 using Base::Base;
399
400 Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
401 PatternBenefit benefit = 1)
402 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
403
404 LogicalResult matchAndRewrite(vector::TransposeOp op,
405 PatternRewriter &rewriter) const override {
406 Value input = op.getVector();
407 VectorType resType = op.getResultVectorType();
408
409 // Set up convenience transposition table.
410 ArrayRef<int64_t> transp = op.getPermutation();
411
412 if (resType.getRank() == 2 &&
413 ((resType.getShape().front() == 1 &&
414 !resType.getScalableDims().front()) ||
415 (resType.getShape().back() == 1 &&
416 !resType.getScalableDims().back())) &&
417 transp == ArrayRef<int64_t>({1, 0})) {
418 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
419 return success();
420 }
421
422 return failure();
423 }
424};
425
426/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
427/// If the strategy is Shuffle1D, it will be lowered to:
428/// vector.shape_cast 2D -> 1D
429/// vector.shuffle
430/// vector.shape_cast 1D -> 2D
431/// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
432/// ops on 16xf32 vectors.
433class TransposeOp2DToShuffleLowering
434 : public OpRewritePattern<vector::TransposeOp> {
435public:
436 using Base::Base;
437
438 TransposeOp2DToShuffleLowering(
439 vector::VectorTransposeLowering vectorTransposeLowering,
440 MLIRContext *context, PatternBenefit benefit = 1)
441 : OpRewritePattern<vector::TransposeOp>(context, benefit),
442 vectorTransposeLowering(vectorTransposeLowering) {}
443
444 LogicalResult matchAndRewrite(vector::TransposeOp op,
445 PatternRewriter &rewriter) const override {
446 if (!isShuffleLike(vectorTransposeLowering))
447 return rewriter.notifyMatchFailure(
448 op, "not using vector shuffle based lowering");
449
450 if (op.getSourceVectorType().isScalable())
451 return rewriter.notifyMatchFailure(
452 op, "vector shuffle lowering not supported for scalable vectors");
453
454 auto srcGtOneDims = isTranspose2DSlice(op);
455 if (failed(srcGtOneDims))
456 return rewriter.notifyMatchFailure(
457 op, "expected transposition on a 2D slice");
458
459 VectorType srcType = op.getSourceVectorType();
460 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
461 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
462
463 // Reshape the n-D input vector with only two dimensions greater than one
464 // to a 2-D vector.
465 Location loc = op.getLoc();
466 auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
467 auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
468 auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType,
469 op.getVector());
470
471 Value res;
472 if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
473 m == 16 && n == 16) {
474 reshInput =
475 vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput);
476 res = transposeToShuffle16x16(rewriter, reshInput, m, n);
477 } else {
478 // Fallback to shuffle on 1D approach.
479 res = transposeToShuffle1D(rewriter, reshInput, m, n);
480 }
481
482 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
483 op, op.getResultVectorType(), res);
484
485 return success();
486 }
487
488private:
489 /// Options to control the vector patterns.
490 vector::VectorTransposeLowering vectorTransposeLowering;
491};
492} // namespace
493
496 VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
497 patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
498 benefit);
499 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
500 vectorTransposeLowering, patterns.getContext(), benefit);
501}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.