//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstdint>
#include <numeric>
#include <optional>

using namespace mlir;

static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                   VectorType resType, Attribute value) {

  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
    if (resType.isScalable() && !isa<SplatElementsAttr>(value))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    return dstElementsAttr.reshape(resType);
  }

  if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
    return poisonAttr;

  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
}

namespace {

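/// Convert a ConstantLike operation (e.g. arith.constant) producing an ND
/// vector into one producing the linearized 1-D vector, by reshaping its
/// 'value' attribute to the converted result type.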
struct LinearizeConstantLike final
    : OpTraitConversionPattern<OpTrait::ConstantLike> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

  LinearizeConstantLike(const TypeConverter &typeConverter,
                        MLIRContext *context, PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    if (op->getNumResults() != 1)
      return rewriter.notifyMatchFailure(loc, "expected 1 result");

    const TypeConverter &typeConverter = *getTypeConverter();
    auto resType =
        typeConverter.convertType<VectorType>(op->getResult(0).getType());
    assert(resType && "expected 1-D vector type");

    StringAttr attrName = rewriter.getStringAttr("value");
    Attribute value = op->getAttr(attrName);
    if (!value)
      return rewriter.notifyMatchFailure(loc, "no 'value' attr");

    FailureOr<Attribute> newValue =
        linearizeConstAttr(loc, rewriter, resType, value);
    if (failed(newValue))
      return failure();

    FailureOr<Operation *> convertResult =
        convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
    if (failed(convertResult))
      return failure();

    Operation *newOp = *convertResult;
    newOp->setAttr(attrName, *newValue);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};

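/// Convert an operation with the Vectorizable trait (i.e. an elementwise
/// operation that can be systematically vectorized) so that its vector
/// results are replaced by their linearized 1-D equivalents.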
struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(const TypeConverter &typeConverter,
                        MLIRContext *context, PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }
};

template <typename TOp>
static bool stridesAllOne(TOp op) {
  static_assert(
      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
      "expected vector.extract_strided_slice or vector.insert_strided_slice");
  ArrayAttr strides = op.getStrides();
  return llvm::all_of(strides, isOneInteger);
}

/// Convert an array of attributes into a vector of integers, if possible.
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
  if (!attrs)
    return failure();
  SmallVector<int64_t> ints;
  ints.reserve(attrs.size());
  for (auto attr : attrs) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
      ints.push_back(intAttr.getInt());
    } else {
      return failure();
    }
  }
  return ints;
}

/// Consider inserting a vector of shape `small` into a vector of shape `large`,
/// at position `offsets`: this function enumerates all the indices in `large`
/// that are written to. The enumeration is with row-major ordering.
///
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
/// positions written to are (1,3) and (1,4), which have linearized indices 8
/// and 9. So [8,9] is returned.
///
/// The length of the returned vector is equal to the number of elements in
/// the shape `small` (i.e. the product of dimensions of `small`).
SmallVector<int64_t> static getStridedSliceInsertionIndices(
    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
    ArrayRef<int64_t> offsets) {

  // Example of alignment between `large`, `small` and `offsets`:
  // large   = 4, 5, 6, 7, 8
  // small   =    1, 6, 7, 8
  // offsets = 2, 3, 0
  //
  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
  assert((large.size() >= small.size()) &&
         "rank of 'large' cannot be lower than rank of 'small'");
  assert((large.size() >= offsets.size()) &&
         "rank of 'large' cannot be lower than the number of offsets");
  unsigned delta = large.size() - small.size();
  unsigned nOffsets = offsets.size();
  auto getSmall = [&](int64_t i) -> int64_t {
    return i >= delta ? small[i - delta] : 1;
  };
  auto getOffset = [&](int64_t i) -> int64_t {
    return i < nOffsets ? offsets[i] : 0;
  };

  // Using 2 vectors of indices, at each iteration populate the updated set of
  // indices based on the old set of indices, and the size of the small vector
  // in the current iteration.
  SmallVector<int64_t> indices{0};
  int64_t stride = 1;
  for (int i = large.size() - 1; i >= 0; --i) {
    int64_t currentSize = indices.size();
    int64_t smallSize = getSmall(i);
    int64_t nextSize = currentSize * smallSize;
    SmallVector<int64_t> nextIndices(nextSize);
    int64_t *base = nextIndices.begin();
    int64_t offset = getOffset(i) * stride;
    for (int j = 0; j < smallSize; ++j) {
      for (int k = 0; k < currentSize; ++k) {
        base[k] = indices[k] + offset;
      }
      offset += stride;
      base += currentSize;
    }
    stride *= large[i];
    indices = std::move(nextIndices);
  }
  return indices;
}

/// This pattern converts a vector.extract_strided_slice operation into a
/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
///
/// For example, the following:
///
/// ```
/// vector.extract_strided_slice %source
///     { offsets = [..], strides = [..], sizes = [..] }
/// ```
///
/// is converted to :
/// ```
/// %source_1d = vector.shape_cast %source
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// ```
///
/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
/// vector.extract_strided_slice operation.
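///
/// For instance (an illustrative example, reusing the numbers from the
/// comment on `getStridedSliceInsertionIndices` above), extracting a 1x2
/// slice at offsets [1, 3] from a vector<4x5xf32> uses shuffle indices
/// [8, 9]:
/// ```
/// %source_1d = vector.shape_cast %source : vector<4x5xf32> to vector<20xf32>
/// %out_1d = vector.shuffle %source_1d, %source_1d [8, 9]
///     : vector<20xf32>, vector<20xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<2xf32> to vector<1x2xf32>
/// ```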
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using Base::Base;
  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
                                     MLIRContext *context,
                                     PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
        extractStridedSliceOp.getType());
    assert(flatOutputType && "vector type expected");

    // Expect a legalization failure if the strides are not all 1 (if ever the
    // verifier for extract_strided_slice allows non-1 strides).
    if (!stridesAllOne(extractStridedSliceOp)) {
      return rewriter.notifyMatchFailure(
          extractStridedSliceOp,
          "extract_strided_slice with strides != 1 not supported");
    }

    FailureOr<SmallVector<int64_t>> offsets =
        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
    if (failed(offsets)) {
      return rewriter.notifyMatchFailure(extractStridedSliceOp,
                                         "failed to get integer offsets");
    }

    ArrayRef<int64_t> inputShape =
        extractStridedSliceOp.getSourceVectorType().getShape();

    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();

    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
        outputShape, inputShape, offsets.value());

    Value srcVector = adaptor.getSource();
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
    return success();
  }
};

/// This pattern converts a vector.insert_strided_slice operation into a
/// vector.shuffle operation that has rank-1 (linearized) operands and result.
///
/// For example, the following:
/// ```
/// %0 = vector.insert_strided_slice %to_store, %into
///     {offsets = [1, 0, 0, 0], strides = [1, 1]}
///     : vector<2x2xi8> into vector<2x1x3x2xi8>
/// ```
///
/// is converted to
/// ```
/// %to_store_1d
///     = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
/// ```
///
/// where shuffle_indices_1d in this case is
///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
///                        ^^^^^^^^^^^^^^
///                          to_store_1d
///
struct LinearizeVectorInsertStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
  using Base::Base;
  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
                                    MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Expect a legalization failure if the strides are not all 1 (if ever the
    // verifier for insert_strided_slice allows non-1 strides).
    if (!stridesAllOne(insertStridedSliceOp)) {
      return rewriter.notifyMatchFailure(
          insertStridedSliceOp,
          "insert_strided_slice with strides != 1 not supported");
    }

    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
    ArrayRef<int64_t> inputShape = inputType.getShape();

    VectorType outputType = insertStridedSliceOp.getType();
    ArrayRef<int64_t> outputShape = outputType.getShape();
    int64_t nOutputElements = outputType.getNumElements();

    FailureOr<SmallVector<int64_t>> offsets =
        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
    if (failed(offsets)) {
      return rewriter.notifyMatchFailure(insertStridedSliceOp,
                                         "failed to get integer offsets");
    }
    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
        inputShape, outputShape, offsets.value());

    SmallVector<int64_t> indices(nOutputElements);
    std::iota(indices.begin(), indices.end(), 0);
    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
      indices[sliceIndex] = index + nOutputElements;
    }

    Value flatToStore = adaptor.getValueToStore();
    Value flatDest = adaptor.getDest();
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
                                                   flatDest.getType(), flatDest,
                                                   flatToStore, indices);
    return success();
  }
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
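///
/// For instance (an illustrative example with assumed operand types):
///   %0 = vector.shuffle %v1, %v2 [0, 3] : vector<2x2xf32>, vector<2x2xf32>
/// has a slice length of 2 (the product of the trailing dims), and becomes
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [0, 1, 6, 7]
///       : vector<4xf32>, vector<4xf32>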
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using Base::Base;
  LinearizeVectorShuffle(const TypeConverter &typeConverter,
                         MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. Size of the slice is equal to the rank-1 innermost
    // dims. Mask of the shuffle op specifies which slice to take from the
    // outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source vectors
    // that need to be shuffled to the destination vector. If shuffleSliceLen >
    // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
    // elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }
};

/// This pattern linearizes `vector.extract` operations. It generates a 1-D
/// version of the `vector.extract` operation when extracting a scalar from a
/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
/// subvector from a larger vector.
///
/// Example #1:
///
///   %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
///   %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
///                               24, 25, 26, 27, 28, 29, 30, 31] :
///       vector<32xf32>, vector<32xf32>
///   %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
///
/// Example #2:
///
///   %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
///   %1 = vector.extract %0[6] : i32 from vector<8xi32>
///
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using Base::Base;
  LinearizeVectorExtract(const TypeConverter &typeConverter,
                         MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    assert(dstTy && "expected 1-D vector type");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getSource().getType().getShape();
    int64_t size = extractOp.getSource().getType().getNumElements();

    // Compute linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += offsets[i] * size;
    }

    Value srcVector = adaptor.getSource();
    if (!isa<VectorType>(extractOp.getType())) {
      // Scalar case: generate a 1-D extract.
      Value result = rewriter.createOrFold<vector::ExtractOp>(
          extractOp.getLoc(), srcVector, linearizedOffset);
      rewriter.replaceOp(extractOp, result);
      return success();
    }

    // Vector case: generate a shuffle.

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, srcVector,
                                                   srcVector, indices);

    return success();
  }
};

/// This pattern linearizes `vector.insert` operations. It generates a 1-D
/// version of the `vector.insert` operation when inserting a scalar into a
/// vector. It generates a 1-D `vector.shuffle` operation when inserting a
/// vector into another vector.
///
/// Example #1:
///
///   %0 = vector.insert %source, %destination[0] :
///       vector<2x4xf32> into vector<2x2x4xf32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
///   %1 = vector.shape_cast %destination :
///       vector<2x2x4xf32> to vector<16xf32>
///   %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
///                               8, 9, 10, 11, 12, 13, 14, 15] :
///       vector<16xf32>, vector<8xf32>
///   %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
///
/// Example #2:
///
///   %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
///   %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
///   %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
///
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using Base::Base;
  LinearizeVectorInsert(const TypeConverter &typeConverter,
                        MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getValueToStoreType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    Location loc = insertOp.getLoc();
    Value valueToStore = adaptor.getValueToStore();

    if (!isa<VectorType>(valueToStore.getType())) {
      // Scalar case: generate a 1-D insert.
      Value result = rewriter.createOrFold<vector::InsertOp>(
          loc, valueToStore, adaptor.getDest(), linearizedOffset);
      rewriter.replaceOp(insertOp, result);
      return success();
    }

    // Vector case: generate a shuffle.
    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto *origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);

    // Original values that remain [0, offset).
    std::iota(indices.begin(), origValsUntil, 0);
    auto *newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    // New values [offset, offset+srcNumElements).
    std::iota(origValsUntil, newValsUntil, dstSize);
    // The rest of original values [offset+srcNumElements, end);
    std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);

    Value result = rewriter.createOrFold<vector::ShuffleOp>(
        loc, dstTy, adaptor.getDest(), valueToStore, indices);

    rewriter.replaceOp(insertOp, result);
    return success();
  }
};

/// This pattern converts the BitCastOp that works on nD (n > 1)
/// vectors to a BitCastOp that works on linearized vectors.
/// Following,
///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
/// is converted to :
///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
    : public OpConversionPattern<vector::BitCastOp> {
  using Base::Base;
  LinearizeVectorBitCast(const TypeConverter &typeConverter,
                         MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = getTypeConverter()->convertType(castOp.getType());
    assert(resType && "expected 1-D vector type");
    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                   adaptor.getSource());
    return mlir::success();
  }
};

/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
/// is converted to:
///   %zero = arith.constant 0 : index
///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
///   %index = arith.index_cast %cmpi : i1 to index
///   %mul = arith.andi %index, %arg1 : index
///   %mask = vector.create_mask %mul : vector<4xi1>
///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
    : OpConversionPattern<vector::CreateMaskOp> {
  using Base::Base;

  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
                            MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = createMaskOp.getLoc();
    VectorType srcTy = createMaskOp.getType();
    auto srcShape = srcTy.getShape();
    if (srcShape.size() != 2)
      return rewriter.notifyMatchFailure(createMaskOp,
                                         "only 2D mask is supported.");

    if (srcShape[0] != 1)
      return rewriter.notifyMatchFailure(
          createMaskOp, "only unit outer dimension is supported.");

    auto dstTy = getTypeConverter()->convertType(srcTy);
    if (!dstTy)
      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");

    // Compare the first operand with 0. If it is greater than 0, the
    // corresponding mask element is set to true, otherwise false.
    // The i1 result of the comparison is cast to index (sign-extending it to
    // all-ones for true and zero for false) and then AND-ed with the second
    // operand of create_mask to get the 1D mask size.
    auto firstOperand = adaptor.getOperands().front();
    auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
    auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
        loc, rewriter.getIndexType(), isNonZero);
    auto secondOperand = adaptor.getOperands().back();
    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
        loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);

    auto newMask =
        mlir::vector::CreateMaskOp::create(rewriter, loc, dstTy, maskSize);
    rewriter.replaceOp(createMaskOp, newMask);
    return success();
  }
};

/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
/// It currently supports linearization where all but the last dimension are 1
/// The following,
///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
/// is converted to:
///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
/// For generic cases, the vector unroll pass should be used to unroll the load
/// to vector<1x1x...xN> form and then linearized
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
  using Base::Base;
  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
                      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vecTy = loadOp.getType();
    if (!vecTy)
      return rewriter.notifyMatchFailure(loadOp, "expected vector type");

    auto shape = vecTy.getShape();
    auto scalableDims = vecTy.getScalableDims();
    // All but the last dim must be 1, and only the last dim may be scalable (if
    // any).
    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
      return rewriter.notifyMatchFailure(loadOp,
                                         "only vector<1x1x...xN> supported");

    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
      return rewriter.notifyMatchFailure(loadOp,
                                         "only innermost dim may be scalable");

    auto linearTy = typeConverter->convertType<VectorType>(vecTy);

    auto newLoad =
        vector::LoadOp::create(rewriter, loadOp.getLoc(), linearTy,
                               adaptor.getBase(), adaptor.getIndices());
    rewriter.replaceOp(loadOp, newLoad.getResult());
    return success();
  }
};

/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
/// It currently supports linearization where all but the last dimension are 1
/// The following,
///   vector.store %arg0, %arg1[%c0, %c0]
///     : vector<1x4xf32>, memref<1x4xf32>
/// is converted to:
///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
///   vector.store %arg0, %arg1[%c0, %c0]
///     : vector<4xf32>, memref<1x4xf32>
/// For generic cases, the vector unroll pass should be used to unroll the store
/// to vector<1x1x...xN> form and then linearized
struct LinearizeVectorStore final
    : public OpConversionPattern<vector::StoreOp> {
  using Base::Base;
  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
                       PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vecTy = storeOp.getValueToStore().getType();
    if (!vecTy)
      return rewriter.notifyMatchFailure(storeOp, "expected vector type");

    auto shape = vecTy.getShape();
    auto scalableDims = vecTy.getScalableDims();
    // All but the last dim must be 1, and only the last dim may be scalable (if
    // any).
    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
      return rewriter.notifyMatchFailure(storeOp,
                                         "only vector<1x1x...xN> supported");

    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
      return rewriter.notifyMatchFailure(storeOp,
                                         "only innermost dim may be scalable");

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        storeOp, adaptor.getValueToStore(), adaptor.getBase(),
        adaptor.getIndices());
    return success();
  }
};

/// This pattern linearizes `vector.from_elements` operations by converting
/// the result type to a 1-D vector while preserving all element values.
/// The transformation creates a linearized `vector.from_elements` followed by
/// a `vector.shape_cast` to restore the original multidimensional shape.
///
/// Example:
///
///   %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
///
/// is converted to:
///
///   %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
///   %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
///
struct LinearizeVectorFromElements final
    : public OpConversionPattern<vector::FromElementsOp> {
  using Base::Base;
  LinearizeVectorFromElements(const TypeConverter &typeConverter,
                              MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy =
        getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
    assert(dstTy && "vector type destination expected.");

    OperandRange elements = fromElementsOp.getElements();
    assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
           "expected same number of elements");
    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
                                                        elements);
    return success();
  }
};

/// This pattern linearizes the operand in `vector.to_elements` operations
/// by converting the source type to a 1-D vector while preserving all element
/// values. The transformation creates a linearized `vector.shape_cast`
/// followed by a `vector.to_elements`.
///
/// Example:
///
///   %0:4 = vector.to_elements %v : vector<2x2xf32>
///
/// is converted to:
///
///   %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
///   %0:4 = vector.to_elements %vector_cast : vector<4xf32>
///
struct LinearizeVectorToElements final
    : public OpConversionPattern<vector::ToElementsOp> {
  using Base::Base;

  LinearizeVectorToElements(const TypeConverter &typeConverter,
                            MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    VectorType vecType = toElementsOp.getSource().getType();
    if (vecType.getRank() <= 1)
      return rewriter.notifyMatchFailure(
          toElementsOp, "the rank is already less than or equal to 1");

    assert(vecType.getNumScalableDims() == 0 &&
           "to_elements does not support scalable vectors");
    auto vec1DType =
        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
    Value shapeCast = vector::ShapeCastOp::create(
        rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
    auto newToElementsOp =
        vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(),
                                     toElementsOp.getResultTypes(), shapeCast);
    rewriter.replaceOp(toElementsOp, newToElementsOp);
    return success();
  }
};

/// Convert broadcasts from scalars or 1-element vectors, such as
///
/// ```mlir
/// vector.broadcast %value : f32 to vector<4x4xf32>
/// ```
///
/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
/// The above becomes,
///
/// ```mlir
/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
/// ```
struct LinearizeVectorBroadcast final
    : public OpConversionPattern<vector::BroadcastOp> {
  using Base::Base;

  LinearizeVectorBroadcast(const TypeConverter &typeConverter,
                           MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    int numElements = 1;
    Type sourceType = broadcastOp.getSourceType();
    if (auto vecType = dyn_cast<VectorType>(sourceType)) {
      numElements = vecType.getNumElements();
    }

    if (numElements != 1) {
      return rewriter.notifyMatchFailure(
          broadcastOp, "only broadcasts of single elements can be linearized.");
    }

    auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
                                                     adaptor.getSource());

    return success();
  }
};

} // namespace

/// This method defines the set of operations that are linearizable, and hence
/// that are considered illegal for the conversion target.
static bool isLinearizable(Operation *op) {

  // Only ops that are in the vector dialect, are ConstantLike, or
  // are Vectorizable might be linearized currently.
  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
  StringRef opDialect = op->getDialect()->getNamespace();
  bool supported = (opDialect == vectorDialect) ||
                   op->hasTrait<OpTrait::ConstantLike>() ||
                   op->hasTrait<OpTrait::Vectorizable>();
  if (!supported)
    return false;

  return TypeSwitch<Operation *, bool>(op)
      // As type legalization is done with vector.shape_cast, shape_cast
      // itself cannot be linearized (will create new shape_casts to linearize
      // ad infinitum).
      .Case<vector::ShapeCastOp>([&](auto) { return false; })
      // The operations
      // - vector.extract_strided_slice
      // - vector.extract
      // - vector.insert_strided_slice
      // - vector.insert
      // are linearized to a rank-1 vector.shuffle by the current patterns.
      // vector.shuffle only supports fixed size vectors, so it is impossible to
      // use this approach to linearize these ops if they operate on scalable
      // vectors.
      .Case<vector::ExtractStridedSliceOp>(
          [&](vector::ExtractStridedSliceOp extractOp) {
            return !extractOp.getType().isScalable();
          })
      .Case<vector::InsertStridedSliceOp>(
          [&](vector::InsertStridedSliceOp insertOp) {
            return !insertOp.getType().isScalable();
          })
      .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
        return !insertOp.getType().isScalable();
      })
      .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
        return !extractOp.getSourceVectorType().isScalable();
      })
      .Default([&](auto) { return true; });
}

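/// Configure `typeConverter` and `target` for linearization: rank>=2 vector
/// types are converted to their flattened rank-1 equivalents, and an
/// operation is dynamically legal iff it is not linearizable or all of its
/// operand and result types are already legal for `typeConverter`.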
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
                                              ConversionTarget &target) {

  auto convertType = [](Type type) -> std::optional<Type> {
    VectorType vectorType = dyn_cast<VectorType>(type);
    if (!vectorType || !isLinearizableVector(vectorType))
      return type;

    VectorType linearizedType =
        VectorType::get(vectorType.getNumElements(),
                        vectorType.getElementType(), vectorType.isScalable());
    return linearizedType;
  };
  typeConverter.addConversion(convertType);

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1)
      return nullptr;

    Value value = inputs.front();
    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
      return nullptr;

    return vector::ShapeCastOp::create(builder, loc, type, value);
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);

  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if (!isLinearizable(op))
          return true;
        // This will return true if, for all operand and result types `t`,
        // convertType(t) = t. This is true if there are no rank>=2 vectors.
        return typeConverter.isLegal(op);
      });
}

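/// Populate `patterns` with the "base" linearization patterns: constant-like
/// and vectorizable (elementwise) ops, vector.bitcast, vector.create_mask,
/// vector.load, vector.store, vector.broadcast, vector.from_elements and
/// vector.to_elements.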
void mlir::vector::populateVectorLinearizeBasePatterns(
    const TypeConverter &typeConverter, const ConversionTarget &target,
    RewritePatternSet &patterns) {
  patterns
      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
           LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
           LinearizeVectorBroadcast, LinearizeVectorFromElements,
           LinearizeVectorToElements>(typeConverter, patterns.getContext());
}

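/// Populate `patterns` with the patterns that linearize via a rank-1
/// vector.shuffle: vector.shuffle, vector.extract, vector.insert,
/// vector.extract_strided_slice and vector.insert_strided_slice.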
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, const ConversionTarget &target,
    RewritePatternSet &patterns) {
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
               LinearizeVectorInsertStridedSlice>(typeConverter,
                                                  patterns.getContext());
}
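
// A minimal sketch of how a conversion pass might drive these patterns
// (illustrative only; `op` is the operation being converted and an enclosing
// pass providing signalPassFailure() is assumed):
//
//   MLIRContext *ctx = op->getContext();
//   TypeConverter typeConverter;
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   vector::populateForVectorLinearize(typeConverter, target);
//   vector::populateVectorLinearizeBasePatterns(typeConverter, target,
//                                               patterns);
//   vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
//                                                         target, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();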