//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstdint>
#include <numeric>
#include <optional>

using namespace mlir;

static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                   VectorType resType, Attribute value) {

  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
    if (resType.isScalable() && !isa<SplatElementsAttr>(value))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    return dstElementsAttr.reshape(resType);
  }

  if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
    return poisonAttr;

  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
}

namespace {

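/// Convert a ConstantLike operation (e.g. arith.constant) producing an ND
/// vector into one producing the linearized 1-D vector type, reshaping its
/// 'value' attribute accordingly. For example (illustrative):
///
///   %0 = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32>
///
/// becomes
///
///   %0 = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
///
/// with a vector.shape_cast materialized back to vector<2x2xi32> wherever the
/// original type is still required.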
struct LinearizeConstantLike final
    : OpTraitConversionPattern<OpTrait::ConstantLike> {

  LinearizeConstantLike(const TypeConverter &typeConverter,
                        MLIRContext *context, PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    if (op->getNumResults() != 1)
      return rewriter.notifyMatchFailure(loc, "expected 1 result");

    const TypeConverter &typeConverter = *getTypeConverter();
    auto resType =
        typeConverter.convertType<VectorType>(op->getResult(0).getType());
    assert(resType && "expected 1-D vector type");

    StringAttr attrName = rewriter.getStringAttr("value");
    Attribute value = op->getAttr(attrName);
    if (!value)
      return rewriter.notifyMatchFailure(loc, "no 'value' attr");

    FailureOr<Attribute> newValue =
        linearizeConstAttr(loc, rewriter, resType, value);
    if (failed(newValue))
      return failure();

    FailureOr<Operation *> convertResult =
        convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
    if (failed(convertResult))
      return failure();

    Operation *newOp = *convertResult;
    newOp->setAttr(attrName, *newValue);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};

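/// Convert an operation with the Vectorizable trait (e.g. elementwise arith
/// ops) so that its results use the linearized 1-D vector types, using the
/// already-converted operands. For example (illustrative):
///
///   %0 = arith.addf %a, %b : vector<2x2xf32>
///
/// becomes
///
///   %0 = arith.addf %a_1d, %b_1d : vector<4xf32>
///
/// where %a_1d and %b_1d are the linearized operands produced by the type
/// converter's materializations.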
struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {

public:
  LinearizeVectorizable(const TypeConverter &typeConverter,
                        MLIRContext *context, PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }
};

template <typename TOp>
static bool stridesAllOne(TOp op) {
  static_assert(
      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
      "expected vector.extract_strided_slice or vector.insert_strided_slice");
  ArrayAttr strides = op.getStrides();
  return llvm::all_of(strides, isOneInteger);
}

/// Convert an array of attributes into a vector of integers, if possible.
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
  if (!attrs)
    return failure();
  SmallVector<int64_t> ints;
  ints.reserve(attrs.size());
  for (auto attr : attrs) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
      ints.push_back(intAttr.getInt());
    } else {
      return failure();
    }
  }
  return ints;
}

/// Consider inserting a vector of shape `small` into a vector of shape `large`
/// at position `offsets`: this function enumerates all the indices in `large`
/// that are written to. The enumeration is with row-major ordering.
///
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
/// positions written to are (1,3) and (1,4), which have linearized indices 8
/// and 9. So [8,9] is returned.
///
/// The length of the returned vector is equal to the number of elements in
/// the shape `small` (i.e. the product of dimensions of `small`).
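///
/// Illustrative trace of the example above, following the implementation
/// below: starting from indices = [0], the loop over the dimensions of
/// `large` (right to left) first expands along the size-2 inner dimension
/// with offset 3, giving [3, 4], and then shifts by offset 1 times stride 5
/// along the outer dimension, giving [8, 9].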
SmallVector<int64_t> static getStridedSliceInsertionIndices(
    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
    ArrayRef<int64_t> offsets) {

  // Example of alignment between `large`, `small`, and `offsets`:
  //    large    =  4, 5, 6, 7, 8
  //    small    =     1, 6, 7, 8
  //    offsets  =  2, 3, 0
  //
  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
  assert((large.size() >= small.size()) &&
         "rank of 'large' cannot be lower than rank of 'small'");
  assert((large.size() >= offsets.size()) &&
         "rank of 'large' cannot be lower than the number of offsets");
  unsigned delta = large.size() - small.size();
  unsigned nOffsets = offsets.size();
  auto getSmall = [&](int64_t i) -> int64_t {
    return i >= delta ? small[i - delta] : 1;
  };
  auto getOffset = [&](int64_t i) -> int64_t {
    return i < nOffsets ? offsets[i] : 0;
  };

  // Using 2 vectors of indices, at each iteration populate the updated set of
  // indices based on the old set of indices, and the size of the small vector
  // in the current iteration.
  SmallVector<int64_t> indices{0};
  int64_t stride = 1;
  for (int i = large.size() - 1; i >= 0; --i) {
    int64_t currentSize = indices.size();
    int64_t smallSize = getSmall(i);
    int64_t nextSize = currentSize * smallSize;
    SmallVector<int64_t> nextIndices(nextSize);
    int64_t *base = nextIndices.begin();
    int64_t offset = getOffset(i) * stride;
    for (int j = 0; j < smallSize; ++j) {
      for (int k = 0; k < currentSize; ++k) {
        base[k] = indices[k] + offset;
      }
      offset += stride;
      base += currentSize;
    }
    stride *= large[i];
    indices = std::move(nextIndices);
  }
  return indices;
}

/// This pattern converts a vector.extract_strided_slice operation into a
/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
///
/// For example, the following:
///
/// ```
/// vector.extract_strided_slice %source
///     { offsets = [..], strides = [..], sizes = [..] }
/// ```
///
/// is converted to:
/// ```
/// %source_1d = vector.shape_cast %source
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// ```
///
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// original vector.extract_strided_slice operation.
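///
/// As a concrete (illustrative) instance, assuming the operand and result
/// types are linearized by the type converter:
///
/// ```
/// %0 = vector.extract_strided_slice %source
///     {offsets = [1, 3], sizes = [1, 2], strides = [1, 1]}
///     : vector<4x5xf32> to vector<1x2xf32>
/// ```
///
/// becomes a shuffle over the linearized source with indices [8, 9]:
///
/// ```
/// %source_1d = vector.shape_cast %source : vector<4x5xf32> to vector<20xf32>
/// %out_1d = vector.shuffle %source_1d, %source_1d [8, 9]
///     : vector<20xf32>, vector<20xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<2xf32> to vector<1x2xf32>
/// ```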
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using Base::Base;
  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
                                     MLIRContext *context,
                                     PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
        extractStridedSliceOp.getType());
    assert(flatOutputType && "vector type expected");

    // Expect a legalization failure if the strides are not all 1 (if ever the
    // verifier for extract_strided_slice allows non-1 strides).
    if (!stridesAllOne(extractStridedSliceOp)) {
      return rewriter.notifyMatchFailure(
          extractStridedSliceOp,
          "extract_strided_slice with strides != 1 not supported");
    }

    FailureOr<SmallVector<int64_t>> offsets =
        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
    if (failed(offsets)) {
      return rewriter.notifyMatchFailure(extractStridedSliceOp,
                                         "failed to get integer offsets");
    }

    ArrayRef<int64_t> inputShape =
        extractStridedSliceOp.getSourceVectorType().getShape();

    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();

    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
        outputShape, inputShape, offsets.value());

    Value srcVector = adaptor.getSource();
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
    return success();
  }
};

/// This pattern converts a vector.insert_strided_slice operation into a
/// vector.shuffle operation that has rank-1 (linearized) operands and result.
///
/// For example, the following:
/// ```
/// %0 = vector.insert_strided_slice %to_store, %into
///     {offsets = [1, 0, 0, 0], strides = [1, 1]}
///     : vector<2x2xi8> into vector<2x1x3x2xi8>
/// ```
///
/// is converted to
/// ```
/// %to_store_1d
///     = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
/// ```
///
/// where shuffle_indices_1d in this case is
///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
///                        ^^^^^^^^^^^^^^
///                          to_store_1d
///
struct LinearizeVectorInsertStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
  using Base::Base;
  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
                                    MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Expect a legalization failure if the strides are not all 1 (if ever the
    // verifier for insert_strided_slice allows non-1 strides).
    if (!stridesAllOne(insertStridedSliceOp)) {
      return rewriter.notifyMatchFailure(
          insertStridedSliceOp,
          "insert_strided_slice with strides != 1 not supported");
    }

    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
    ArrayRef<int64_t> inputShape = inputType.getShape();

    VectorType outputType = insertStridedSliceOp.getType();
    ArrayRef<int64_t> outputShape = outputType.getShape();
    int64_t nOutputElements = outputType.getNumElements();

    FailureOr<SmallVector<int64_t>> offsets =
        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
    if (failed(offsets)) {
      return rewriter.notifyMatchFailure(insertStridedSliceOp,
                                         "failed to get integer offsets");
    }
    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
        inputShape, outputShape, offsets.value());

    SmallVector<int64_t> indices(nOutputElements);
    std::iota(indices.begin(), indices.end(), 0);
    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
      indices[sliceIndex] = index + nOutputElements;
    }

    Value flatToStore = adaptor.getValueToStore();
    Value flatDest = adaptor.getDest();
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
                                                   flatDest.getType(), flatDest,
                                                   flatToStore, indices);
    return success();
  }
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to:
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
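///
/// For example (an illustrative case with rank-2 fixed-size operands):
///   vector.shuffle %v1, %v2 [1, 2] : vector<2x3xf32>, vector<2x3xf32>
/// selects rows 1 and 2 of the concatenated outer dimension, so the
/// linearized shuffle uses one index per scalar of each selected row:
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [3, 4, 5, 6, 7, 8]
///       : vector<6xf32>, vector<6xf32>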
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using Base::Base;
  LinearizeVectorShuffle(const TypeConverter &typeConverter,
                         MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. Size of the slice is equal to the rank-1 innermost
    // dims. Mask of the shuffle op specifies which slice to take from the
    // outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1 we need to shuffle the slices (consecutive
    // shuffleSliceLen number of elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }
};

/// This pattern linearizes `vector.extract` operations. It generates a 1-D
/// version of the `vector.extract` operation when extracting a scalar from a
/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
/// subvector from a larger vector.
///
/// Example #1:
///
///   %0 = vector.extract %arg0[1] : vector<8x2xf32> from vector<2x8x2xf32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
///   %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
///                               24, 25, 26, 27, 28, 29, 30, 31] :
///       vector<32xf32>, vector<32xf32>
///   %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
///
/// Example #2:
///
///   %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
///   %1 = vector.extract %0[6] : i32 from vector<8xi32>
///
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using Base::Base;
  LinearizeVectorExtract(const TypeConverter &typeConverter,
                         MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    assert(dstTy && "expected 1-D vector type");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getSource().getType().getShape();
    int64_t size = extractOp.getSource().getType().getNumElements();

    // Compute linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += offsets[i] * size;
    }

    Value srcVector = adaptor.getSource();
    if (!isa<VectorType>(extractOp.getType())) {
      // Scalar case: generate a 1-D extract.
      Value result = rewriter.createOrFold<vector::ExtractOp>(
          extractOp.getLoc(), srcVector, linearizedOffset);
      rewriter.replaceOp(extractOp, result);
      return success();
    }

    // Vector case: generate a shuffle.
    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, srcVector,
                                                   srcVector, indices);

    return success();
  }
};

/// This pattern linearizes `vector.insert` operations. It generates a 1-D
/// version of the `vector.insert` operation when inserting a scalar into a
/// vector. It generates a 1-D `vector.shuffle` operation when inserting a
/// vector into another vector.
///
/// Example #1:
///
///   %0 = vector.insert %source, %destination[0] :
///       vector<2x4xf32> into vector<2x2x4xf32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
///   %1 = vector.shape_cast %destination :
///       vector<2x2x4xf32> to vector<16xf32>
///   %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
///                               8, 9, 10, 11, 12, 13, 14, 15] :
///       vector<16xf32>, vector<8xf32>
///   %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
///
/// Example #2:
///
///   %0 = vector.insert %source, %destination[1, 2] : f32 into vector<2x4xf32>
///
/// is converted to:
///
///   %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
///   %1 = vector.insert %source, %0[6] : f32 into vector<8xf32>
///   %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
///
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using Base::Base;
  LinearizeVectorInsert(const TypeConverter &typeConverter,
                        MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getValueToStoreType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    Location loc = insertOp.getLoc();
    Value valueToStore = adaptor.getValueToStore();

    if (!isa<VectorType>(valueToStore.getType())) {
      // Scalar case: generate a 1-D insert.
      Value result = rewriter.createOrFold<vector::InsertOp>(
          loc, valueToStore, adaptor.getDest(), linearizedOffset);
      rewriter.replaceOp(insertOp, result);
      return success();
    }

    // Vector case: generate a shuffle.
    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto *origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);

    // Original values that remain [0, offset).
    std::iota(indices.begin(), origValsUntil, 0);
    auto *newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    // New values [offset, offset+srcNumElements).
    std::iota(origValsUntil, newValsUntil, dstSize);
    // The rest of original values [offset+srcNumElements, end).
    std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);

    Value result = rewriter.createOrFold<vector::ShuffleOp>(
        loc, dstTy, adaptor.getDest(), valueToStore, indices);

    rewriter.replaceOp(insertOp, result);
    return success();
  }
};

/// This pattern converts the BitCastOp that works on nD (n > 1)
/// vectors to a BitCastOp that works on linearized vectors.
/// Following,
///   vector.bitcast %v1 : vector<4x2xf32> to vector<4x4xf16>
/// is converted to:
///   %v1_1d = vector.shape_cast %v1 : vector<4x2xf32> to vector<8xf32>
///   %out_1d = vector.bitcast %v1_1d : vector<8xf32> to vector<16xf16>
///   %out_nd = vector.shape_cast %out_1d : vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
    : public OpConversionPattern<vector::BitCastOp> {
  using Base::Base;
  LinearizeVectorBitCast(const TypeConverter &typeConverter,
                         MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = getTypeConverter()->convertType(castOp.getType());
    assert(resType && "expected 1-D vector type");
    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                   adaptor.getSource());
    return mlir::success();
  }
};

/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
/// is converted to:
///   %zero = arith.constant 0 : index
///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
///   %index = arith.index_cast %cmpi : i1 to index
///   %mul = arith.andi %index, %arg1 : index
///   %mask = vector.create_mask %mul : vector<4xi1>
///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
    : OpConversionPattern<vector::CreateMaskOp> {
  using Base::Base;

  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
                            MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = createMaskOp.getLoc();
    VectorType srcTy = createMaskOp.getType();
    auto srcShape = srcTy.getShape();
    if (srcShape.size() != 2)
      return rewriter.notifyMatchFailure(createMaskOp,
                                         "only 2D mask is supported.");

    if (srcShape[0] != 1)
      return rewriter.notifyMatchFailure(
          createMaskOp, "only unit outer dimension is supported.");

    auto dstTy = getTypeConverter()->convertType(srcTy);
    if (!dstTy)
      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");

    // Compare the first operand with 0. If it is greater than 0, the
    // corresponding mask element is set to true, otherwise false. The result
    // of the comparison is then combined with the second operand of
    // create_mask to get the 1D mask size: index_cast sign-extends the i1, so
    // the andi below yields either the second operand (non-empty outer dim)
    // or 0 (empty outer dim), i.e. it acts as a multiplication.
    auto firstOperand = adaptor.getOperands().front();
    auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
    auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
        loc, rewriter.getIndexType(), isNonZero);
    auto secondOperand = adaptor.getOperands().back();
    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
        loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);

    auto newMask =
        mlir::vector::CreateMaskOp::create(rewriter, loc, dstTy, maskSize);
    rewriter.replaceOp(createMaskOp, newMask);
    return success();
  }
};

/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>.
/// It currently supports linearization where all but the last dimension are 1.
/// The following,
///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
/// is converted to:
///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
/// For generic cases, the vector unroll pass should be used to unroll the load
/// to vector<1x1x...xN> form first, which can then be linearized.
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
  using Base::Base;
  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
                      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vecTy = loadOp.getType();
    if (!vecTy)
      return rewriter.notifyMatchFailure(loadOp, "expected vector type");

    auto shape = vecTy.getShape();
    auto scalableDims = vecTy.getScalableDims();
    // All but the last dim must be 1, and only the last dim may be scalable
    // (if any).
    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
      return rewriter.notifyMatchFailure(loadOp,
                                         "only vector<1x1x...xN> supported");

    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
      return rewriter.notifyMatchFailure(loadOp,
                                         "only innermost dim may be scalable");

    auto linearTy = typeConverter->convertType<VectorType>(vecTy);

    auto newLoad =
        vector::LoadOp::create(rewriter, loadOp.getLoc(), linearTy,
                               adaptor.getBase(), adaptor.getIndices());
    rewriter.replaceOp(loadOp, newLoad.getResult());
    return success();
  }
};

/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>.
/// It currently supports linearization where all but the last dimension are 1.
/// The following,
///   vector.store %arg0, %arg1[%c0, %c0]
///       : vector<1x4xf32>, memref<1x4xf32>
/// is converted to:
///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
///   vector.store %arg0, %arg1[%c0, %c0]
///       : vector<4xf32>, memref<1x4xf32>
/// For generic cases, the vector unroll pass should be used to unroll the
/// store to vector<1x1x...xN> form first, which can then be linearized.
struct LinearizeVectorStore final
    : public OpConversionPattern<vector::StoreOp> {
  using Base::Base;
  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
                       PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vecTy = storeOp.getValueToStore().getType();
    if (!vecTy)
      return rewriter.notifyMatchFailure(storeOp, "expected vector type");

    auto shape = vecTy.getShape();
    auto scalableDims = vecTy.getScalableDims();
    // All but the last dim must be 1, and only the last dim may be scalable
    // (if any).
    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
      return rewriter.notifyMatchFailure(storeOp,
                                         "only vector<1x1x...xN> supported");

    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
      return rewriter.notifyMatchFailure(storeOp,
                                         "only innermost dim may be scalable");

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        storeOp, adaptor.getValueToStore(), adaptor.getBase(),
        adaptor.getIndices());
    return success();
  }
};

/// This pattern linearizes `vector.from_elements` operations by converting
/// the result type to a 1-D vector while preserving all element values.
/// The transformation creates a linearized `vector.from_elements` followed by
/// a `vector.shape_cast` to restore the original multidimensional shape.
///
/// Example:
///
///   %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
///
/// is converted to:
///
///   %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
///   %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
///
struct LinearizeVectorFromElements final
    : public OpConversionPattern<vector::FromElementsOp> {
  using Base::Base;
  LinearizeVectorFromElements(const TypeConverter &typeConverter,
                              MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}
  LogicalResult
  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy =
        getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
    assert(dstTy && "vector type destination expected.");

    OperandRange elements = fromElementsOp.getElements();
    assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
           "expected same number of elements");
    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
                                                        elements);
    return success();
  }
};

/// This pattern linearizes the operand in `vector.to_elements` operations
/// by converting the source type to a 1-D vector while preserving all element
/// values. The transformation creates a linearized `vector.shape_cast`
/// followed by a `vector.to_elements`.
///
/// Example:
///
///   %0:4 = vector.to_elements %v : vector<2x2xf32>
///
/// is converted to:
///
///   %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
///   %0:4 = vector.to_elements %vector_cast : vector<4xf32>
///
struct LinearizeVectorToElements final
    : public OpConversionPattern<vector::ToElementsOp> {
  using Base::Base;

  LinearizeVectorToElements(const TypeConverter &typeConverter,
                            MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    VectorType vecType = toElementsOp.getSource().getType();
    if (vecType.getRank() <= 1)
      return rewriter.notifyMatchFailure(
          toElementsOp, "the rank is already less than or equal to 1");

    assert(vecType.getNumScalableDims() == 0 &&
           "to_elements does not support scalable vectors");
    auto vec1DType =
        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
    Value shapeCast = vector::ShapeCastOp::create(
        rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
    auto newToElementsOp =
        vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(),
                                     toElementsOp.getResultTypes(), shapeCast);
    rewriter.replaceOp(toElementsOp, newToElementsOp);
    return success();
  }
};

/// Convert broadcasts from scalars or 1-element vectors, such as
///
/// ```mlir
/// vector.broadcast %value : f32 to vector<4x4xf32>
/// ```
///
/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
/// The above becomes,
///
/// ```mlir
/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
/// ```
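///
/// A 1-element vector source is handled the same way; for example (an
/// illustrative case, assuming the source is itself linearized to rank 1 by
/// the type converter):
///
/// ```mlir
/// vector.broadcast %v : vector<1x1xf32> to vector<4x4xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %v_1d = vector.shape_cast %v : vector<1x1xf32> to vector<1xf32>
/// %out_1d = vector.broadcast %v_1d : vector<1xf32> to vector<16xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
/// ```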
struct LinearizeVectorBroadcast final
    : public OpConversionPattern<vector::BroadcastOp> {
  using Base::Base;

  LinearizeVectorBroadcast(const TypeConverter &typeConverter,
                           MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit) {}

  LogicalResult
  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    int numElements = 1;
    Type sourceType = broadcastOp.getSourceType();
    if (auto vecType = dyn_cast<VectorType>(sourceType)) {
      numElements = vecType.getNumElements();
    }

    if (numElements != 1) {
      return rewriter.notifyMatchFailure(
          broadcastOp, "only broadcasts of single elements can be linearized.");
    }

    auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
                                                     adaptor.getSource());

    return success();
  }
};

} // namespace

/// This method defines the set of operations that are linearizable, and hence
/// that are considered illegal for the conversion target.
static bool isLinearizable(Operation *op) {

  // Only ops that are in the vector dialect, are ConstantLike, or
  // are Vectorizable might be linearized currently.
  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
  StringRef opDialect = op->getDialect()->getNamespace();
  bool supported = (opDialect == vectorDialect) ||
                   op->hasTrait<OpTrait::ConstantLike>() ||
                   op->hasTrait<OpTrait::Vectorizable>();
  if (!supported)
    return false;

  return llvm::TypeSwitch<Operation *, bool>(op)
      // As type legalization is done with vector.shape_cast, shape_cast
      // itself cannot be linearized (will create new shape_casts to linearize
      // ad infinitum).
      .Case<vector::ShapeCastOp>([&](auto) { return false; })
      // The operations
      // - vector.extract_strided_slice
      // - vector.extract
      // - vector.insert_strided_slice
      // - vector.insert
      // are linearized to a rank-1 vector.shuffle by the current patterns.
      // vector.shuffle only supports fixed size vectors, so it is impossible
      // to use this approach to linearize these ops if they operate on
      // scalable vectors.
      .Case<vector::ExtractStridedSliceOp>(
          [&](vector::ExtractStridedSliceOp extractOp) {
            return !extractOp.getType().isScalable();
          })
      .Case<vector::InsertStridedSliceOp>(
          [&](vector::InsertStridedSliceOp insertOp) {
            return !insertOp.getType().isScalable();
          })
      .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
        return !insertOp.getType().isScalable();
      })
      .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
        return !extractOp.getSourceVectorType().isScalable();
      })
      .Default([&](auto) { return true; });
}

void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
                                              ConversionTarget &target) {

  auto convertType = [](Type type) -> std::optional<Type> {
    VectorType vectorType = dyn_cast<VectorType>(type);
    if (!vectorType || !isLinearizableVector(vectorType))
      return type;

    VectorType linearizedType =
        VectorType::get(vectorType.getNumElements(),
                        vectorType.getElementType(), vectorType.isScalable());
    return linearizedType;
  };
  typeConverter.addConversion(convertType);

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1)
      return nullptr;

    Value value = inputs.front();
    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
      return nullptr;

    return vector::ShapeCastOp::create(builder, loc, type, value);
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);

  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if (!isLinearizable(op))
          return true;
        // This will return true if, for all operand and result types `t`,
        // convertType(t) = t. This is true if there are no rank>=2 vectors.
        return typeConverter.isLegal(op);
      });
}

void mlir::vector::populateVectorLinearizeBasePatterns(
    const TypeConverter &typeConverter, const ConversionTarget &target,
    RewritePatternSet &patterns) {
  patterns
      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
           LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
           LinearizeVectorBroadcast, LinearizeVectorFromElements,
           LinearizeVectorToElements>(typeConverter, patterns.getContext());
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, const ConversionTarget &target,
    RewritePatternSet &patterns) {
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
               LinearizeVectorInsertStridedSlice>(typeConverter,
                                                  patterns.getContext());
}
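
// Example usage (an illustrative sketch, not part of the original file): a
// conversion-based pass would typically wire the entry points above together
// roughly as follows; `root` and `signalPassFailure` are assumed to come from
// the surrounding pass.
//
//   MLIRContext *ctx = root->getContext();
//   TypeConverter typeConverter;
//   ConversionTarget target(*ctx);
//   vector::populateForVectorLinearize(typeConverter, target);
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorLinearizeBasePatterns(typeConverter, target,
//                                               patterns);
//   vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
//                                                         target, patterns);
//   if (failed(applyPartialConversion(root, target, std::move(patterns))))
//     signalPassFailure();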