MLIR  22.0.0git
VectorLinearize.cpp
Go to the documentation of this file.
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstdint>
#include <numeric>
#include <optional>
26 
27 using namespace mlir;
28 
29 static FailureOr<Attribute>
31  VectorType resType, Attribute value) {
32 
33  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
34  if (resType.isScalable() && !isa<SplatElementsAttr>(value))
35  return rewriter.notifyMatchFailure(
36  loc,
37  "Cannot linearize a constant scalable vector that's not a splat");
38 
39  return dstElementsAttr.reshape(resType);
40  }
41 
42  if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
43  return poisonAttr;
44 
45  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
46 }
47 
48 namespace {
49 
50 struct LinearizeConstantLike final
51  : OpTraitConversionPattern<OpTrait::ConstantLike> {
53 
54  LinearizeConstantLike(const TypeConverter &typeConverter,
55  MLIRContext *context, PatternBenefit benefit = 1)
56  : OpTraitConversionPattern(typeConverter, context, benefit) {}
57  LogicalResult
58  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
59  ConversionPatternRewriter &rewriter) const override {
60  Location loc = op->getLoc();
61  if (op->getNumResults() != 1)
62  return rewriter.notifyMatchFailure(loc, "expected 1 result");
63 
64  const TypeConverter &typeConverter = *getTypeConverter();
65  auto resType =
66  typeConverter.convertType<VectorType>(op->getResult(0).getType());
67  assert(resType && "expected 1-D vector type");
68 
69  StringAttr attrName = rewriter.getStringAttr("value");
70  Attribute value = op->getAttr(attrName);
71  if (!value)
72  return rewriter.notifyMatchFailure(loc, "no 'value' attr");
73 
74  FailureOr<Attribute> newValue =
75  linearizeConstAttr(loc, rewriter, resType, value);
76  if (failed(newValue))
77  return failure();
78 
79  FailureOr<Operation *> convertResult =
80  convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
81  if (failed(convertResult))
82  return failure();
83 
84  Operation *newOp = *convertResult;
85  newOp->setAttr(attrName, *newValue);
86  rewriter.replaceOp(op, newOp);
87  return success();
88  }
89 };
90 
91 struct LinearizeVectorizable final
92  : OpTraitConversionPattern<OpTrait::Vectorizable> {
94 
95 public:
96  LinearizeVectorizable(const TypeConverter &typeConverter,
97  MLIRContext *context, PatternBenefit benefit = 1)
98  : OpTraitConversionPattern(typeConverter, context, benefit) {}
99  LogicalResult
100  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101  ConversionPatternRewriter &rewriter) const override {
102  FailureOr<Operation *> newOp =
103  convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
104  if (failed(newOp))
105  return failure();
106 
107  rewriter.replaceOp(op, (*newOp)->getResults());
108  return success();
109  }
110 };
111 
112 template <typename TOp>
113 static bool stridesAllOne(TOp op) {
114  static_assert(
115  std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116  std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117  "expected vector.extract_strided_slice or vector.insert_strided_slice");
118  ArrayAttr strides = op.getStrides();
119  return llvm::all_of(strides, isOneInteger);
120 }
121 
122 /// Convert an array of attributes into a vector of integers, if possible.
123 static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
124  if (!attrs)
125  return failure();
127  ints.reserve(attrs.size());
128  for (auto attr : attrs) {
129  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130  ints.push_back(intAttr.getInt());
131  } else {
132  return failure();
133  }
134  }
135  return ints;
136 }
137 
138 /// Consider inserting a vector of shape `small` into a vector of shape `large`,
139 /// at position `offsets`: this function enumeratates all the indices in `large`
140 /// that are written to. The enumeration is with row-major ordering.
141 ///
142 /// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143 /// positions written to are (1,3) and (1,4), which have linearized indices 8
144 /// and 9. So [8,9] is returned.
145 ///
146 /// The length of the returned vector is equal to the number of elements in
147 /// the shape `small` (i.e. the product of dimensions of `small`).
148 SmallVector<int64_t> static getStridedSliceInsertionIndices(
150  ArrayRef<int64_t> offsets) {
151 
152  // Example of alignment between, `large`, `small` and `offsets`:
153  // large = 4, 5, 6, 7, 8
154  // small = 1, 6, 7, 8
155  // offsets = 2, 3, 0
156  //
157  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158  assert((large.size() >= small.size()) &&
159  "rank of 'large' cannot be lower than rank of 'small'");
160  assert((large.size() >= offsets.size()) &&
161  "rank of 'large' cannot be lower than the number of offsets");
162  unsigned delta = large.size() - small.size();
163  unsigned nOffsets = offsets.size();
164  auto getSmall = [&](int64_t i) -> int64_t {
165  return i >= delta ? small[i - delta] : 1;
166  };
167  auto getOffset = [&](int64_t i) -> int64_t {
168  return i < nOffsets ? offsets[i] : 0;
169  };
170 
171  // Using 2 vectors of indices, at each iteration populate the updated set of
172  // indices based on the old set of indices, and the size of the small vector
173  // in the current iteration.
174  SmallVector<int64_t> indices{0};
175  int64_t stride = 1;
176  for (int i = large.size() - 1; i >= 0; --i) {
177  int64_t currentSize = indices.size();
178  int64_t smallSize = getSmall(i);
179  int64_t nextSize = currentSize * smallSize;
180  SmallVector<int64_t> nextIndices(nextSize);
181  int64_t *base = nextIndices.begin();
182  int64_t offset = getOffset(i) * stride;
183  for (int j = 0; j < smallSize; ++j) {
184  for (int k = 0; k < currentSize; ++k) {
185  base[k] = indices[k] + offset;
186  }
187  offset += stride;
188  base += currentSize;
189  }
190  stride *= large[i];
191  indices = std::move(nextIndices);
192  }
193  return indices;
194 }
195 
196 /// This pattern converts a vector.extract_strided_slice operation into a
197 /// vector.shuffle operation that has a rank-1 (linearized) operand and result.
198 ///
199 /// For example, the following:
200 ///
201 /// ```
202 /// vector.extract_strided_slice %source
203 /// { offsets = [..], strides = [..], sizes = [..] }
204 /// ```
205 ///
206 /// is converted to :
207 /// ```
208 /// %source_1d = vector.shape_cast %source
209 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
210 /// %out_nd = vector.shape_cast %out_1d
211 /// ```
212 ///
213 /// `shuffle_indices_1d` is computed using the offsets and sizes of the original
214 /// vector.extract_strided_slice operation.
215 struct LinearizeVectorExtractStridedSlice final
216  : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
218  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
219  MLIRContext *context,
220  PatternBenefit benefit = 1)
221  : OpConversionPattern(typeConverter, context, benefit) {}
222 
223  LogicalResult
224  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
225  OpAdaptor adaptor,
226  ConversionPatternRewriter &rewriter) const override {
227 
228  VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
229  extractStridedSliceOp.getType());
230  assert(flatOutputType && "vector type expected");
231 
232  // Expect a legalization failure if the strides are not all 1 (if ever the
233  // verifier for extract_strided_slice allows non-1 strides).
234  if (!stridesAllOne(extractStridedSliceOp)) {
235  return rewriter.notifyMatchFailure(
236  extractStridedSliceOp,
237  "extract_strided_slice with strides != 1 not supported");
238  }
239 
240  FailureOr<SmallVector<int64_t>> offsets =
241  intsFromArrayAttr(extractStridedSliceOp.getOffsets());
242  if (failed(offsets)) {
243  return rewriter.notifyMatchFailure(extractStridedSliceOp,
244  "failed to get integer offsets");
245  }
246 
247  ArrayRef<int64_t> inputShape =
248  extractStridedSliceOp.getSourceVectorType().getShape();
249 
250  ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
251 
252  SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
253  outputShape, inputShape, offsets.value());
254 
255  Value srcVector = adaptor.getVector();
256  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
257  extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
258  return success();
259  }
260 };
261 
262 /// This pattern converts a vector.insert_strided_slice operation into a
263 /// vector.shuffle operation that has rank-1 (linearized) operands and result.
264 ///
265 /// For example, the following:
266 /// ```
267 /// %0 = vector.insert_strided_slice %to_store, %into
268 /// {offsets = [1, 0, 0, 0], strides = [1, 1]}
269 /// : vector<2x2xi8> into vector<2x1x3x2xi8>
270 /// ```
271 ///
272 /// is converted to
273 /// ```
274 /// %to_store_1d
275 /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
276 /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
277 /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
278 /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
279 /// ```
280 ///
281 /// where shuffle_indices_1d in this case is
282 /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
283 /// ^^^^^^^^^^^^^^
284 /// to_store_1d
285 ///
286 struct LinearizeVectorInsertStridedSlice final
287  : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
289  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
290  MLIRContext *context,
291  PatternBenefit benefit = 1)
292  : OpConversionPattern(typeConverter, context, benefit) {}
293 
294  LogicalResult
295  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
296  OpAdaptor adaptor,
297  ConversionPatternRewriter &rewriter) const override {
298 
299  // Expect a legalization failure if the strides are not all 1 (if ever the
300  // verifier for insert_strided_slice allows non-1 strides).
301  if (!stridesAllOne(insertStridedSliceOp)) {
302  return rewriter.notifyMatchFailure(
303  insertStridedSliceOp,
304  "insert_strided_slice with strides != 1 not supported");
305  }
306 
307  VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
308  ArrayRef<int64_t> inputShape = inputType.getShape();
309 
310  VectorType outputType = insertStridedSliceOp.getType();
311  ArrayRef<int64_t> outputShape = outputType.getShape();
312  int64_t nOutputElements = outputType.getNumElements();
313 
314  FailureOr<SmallVector<int64_t>> offsets =
315  intsFromArrayAttr(insertStridedSliceOp.getOffsets());
316  if (failed(offsets)) {
317  return rewriter.notifyMatchFailure(insertStridedSliceOp,
318  "failed to get integer offsets");
319  }
320  SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
321  inputShape, outputShape, offsets.value());
322 
323  SmallVector<int64_t> indices(nOutputElements);
324  std::iota(indices.begin(), indices.end(), 0);
325  for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
326  indices[sliceIndex] = index + nOutputElements;
327  }
328 
329  Value flatToStore = adaptor.getValueToStore();
330  Value flatDest = adaptor.getDest();
331  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
332  flatDest.getType(), flatDest,
333  flatToStore, indices);
334  return success();
335  }
336 };
337 
338 /// This pattern converts the ShuffleOp that works on nD (n > 1)
339 /// vectors to a ShuffleOp that works on linearized vectors.
340 /// Following,
341 /// vector.shuffle %v1, %v2 [ shuffle_indices ]
342 /// is converted to :
343 /// %v1_1d = vector.shape_cast %v1
344 /// %v2_1d = vector.shape_cast %v2
345 /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
346 /// %out_nd = vector.shape_cast %out_1d
347 // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
348 /// of the original shuffle operation.
349 struct LinearizeVectorShuffle final
350  : public OpConversionPattern<vector::ShuffleOp> {
352  LinearizeVectorShuffle(const TypeConverter &typeConverter,
353  MLIRContext *context, PatternBenefit benefit = 1)
354  : OpConversionPattern(typeConverter, context, benefit) {}
355 
356  LogicalResult
357  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
358  ConversionPatternRewriter &rewriter) const override {
359  VectorType dstType =
360  getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
361  assert(dstType && "vector type destination expected.");
362 
363  Value vec1 = adaptor.getV1();
364  Value vec2 = adaptor.getV2();
365  int shuffleSliceLen = 1;
366  int rank = shuffleOp.getV1().getType().getRank();
367 
368  // If rank > 1, we need to do the shuffle in the granularity of slices
369  // instead of scalars. Size of the slice is equal to the rank-1 innermost
370  // dims. Mask of the shuffle op specifies which slice to take from the
371  // outermost dim.
372  if (rank > 1) {
373  llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
374  for (unsigned i = 1; i < shape.size(); ++i) {
375  shuffleSliceLen *= shape[i];
376  }
377  }
378 
379  // For each value in the mask, we generate the indices of the source vectors
380  // that need to be shuffled to the destination vector. If shuffleSliceLen >
381  // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
382  // elements) instead of scalars.
383  ArrayRef<int64_t> mask = shuffleOp.getMask();
384  int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
385  llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
386  for (auto [i, value] : llvm::enumerate(mask)) {
387  std::iota(indices.begin() + shuffleSliceLen * i,
388  indices.begin() + shuffleSliceLen * (i + 1),
389  shuffleSliceLen * value);
390  }
391 
392  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
393  vec2, indices);
394  return success();
395  }
396 };
397 
398 /// This pattern linearizes `vector.extract` operations. It generates a 1-D
399 /// version of the `vector.extract` operation when extracting a scalar from a
400 /// vector. It generates a 1-D `vector.shuffle` operation when extracting a
401 /// subvector from a larger vector.
402 ///
403 /// Example #1:
404 ///
405 /// %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
406 ///
407 /// is converted to:
408 ///
409 /// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
410 /// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
411 /// 24, 25, 26, 27, 28, 29, 30, 31] :
412 /// vector<32xf32>, vector<32xf32>
413 /// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
414 ///
415 /// Example #2:
416 ///
417 /// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
418 ///
419 /// is converted to:
420 ///
421 /// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
422 /// %1 = vector.extract %0[6] : i32 from vector<8xi32>
423 ///
424 struct LinearizeVectorExtract final
425  : public OpConversionPattern<vector::ExtractOp> {
427  LinearizeVectorExtract(const TypeConverter &typeConverter,
428  MLIRContext *context, PatternBenefit benefit = 1)
429  : OpConversionPattern(typeConverter, context, benefit) {}
430  LogicalResult
431  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
432  ConversionPatternRewriter &rewriter) const override {
433  Type dstTy = getTypeConverter()->convertType(extractOp.getType());
434  assert(dstTy && "expected 1-D vector type");
435 
436  // Dynamic position is not supported.
437  if (extractOp.hasDynamicPosition())
438  return rewriter.notifyMatchFailure(extractOp,
439  "dynamic position is not supported.");
440 
441  llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
442  int64_t size = extractOp.getVector().getType().getNumElements();
443 
444  // Compute linearized offset.
445  int64_t linearizedOffset = 0;
446  llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
447  for (auto [i, off] : llvm::enumerate(offsets)) {
448  size /= shape[i];
449  linearizedOffset += offsets[i] * size;
450  }
451 
452  Value srcVector = adaptor.getVector();
453  if (!isa<VectorType>(extractOp.getType())) {
454  // Scalar case: generate a 1-D extract.
455  Value result = rewriter.createOrFold<vector::ExtractOp>(
456  extractOp.getLoc(), srcVector, linearizedOffset);
457  rewriter.replaceOp(extractOp, result);
458  return success();
459  }
460 
461  // Vector case: generate a shuffle.
462 
463  llvm::SmallVector<int64_t, 2> indices(size);
464  std::iota(indices.begin(), indices.end(), linearizedOffset);
465  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, srcVector,
466  srcVector, indices);
467 
468  return success();
469  }
470 };
471 
472 /// This pattern linearizes `vector.insert` operations. It generates a 1-D
473 /// version of the `vector.insert` operation when inserting a scalar into a
474 /// vector. It generates a 1-D `vector.shuffle` operation when inserting a
475 /// vector into another vector.
476 ///
477 /// Example #1:
478 ///
479 /// %0 = vector.insert %source, %destination[0] :
480 /// vector<2x4xf32> into vector<2x2x4xf32>
481 ///
482 /// is converted to:
483 ///
484 /// %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
485 /// %1 = vector.shape_cast %destination :
486 /// vector<2x2x4xf32> to vector<16xf32>
487 /// %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23
488 /// 8, 9, 10, 11, 12, 13, 14, 15] :
489 /// vector<16xf32>, vector<8xf32>
490 /// %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
491 ///
492 /// Example #2:
493 ///
494 /// %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
495 ///
496 /// is converted to:
497 ///
498 /// %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
499 /// %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
500 /// %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
501 ///
502 struct LinearizeVectorInsert final
503  : public OpConversionPattern<vector::InsertOp> {
505  LinearizeVectorInsert(const TypeConverter &typeConverter,
506  MLIRContext *context, PatternBenefit benefit = 1)
507  : OpConversionPattern(typeConverter, context, benefit) {}
508  LogicalResult
509  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
510  ConversionPatternRewriter &rewriter) const override {
511  VectorType dstTy = getTypeConverter()->convertType<VectorType>(
512  insertOp.getDestVectorType());
513  assert(dstTy && "vector type destination expected.");
514 
515  // Dynamic position is not supported.
516  if (insertOp.hasDynamicPosition())
517  return rewriter.notifyMatchFailure(insertOp,
518  "dynamic position is not supported.");
519  auto srcTy = insertOp.getValueToStoreType();
520  auto srcAsVec = dyn_cast<VectorType>(srcTy);
521  uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;
522 
523  auto dstShape = insertOp.getDestVectorType().getShape();
524  const auto dstSize = insertOp.getDestVectorType().getNumElements();
525  auto dstSizeForOffsets = dstSize;
526 
527  // Compute linearized offset.
528  int64_t linearizedOffset = 0;
529  auto offsetsNd = insertOp.getStaticPosition();
530  for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
531  dstSizeForOffsets /= dstShape[dim];
532  linearizedOffset += offset * dstSizeForOffsets;
533  }
534 
535  Location loc = insertOp.getLoc();
536  Value valueToStore = adaptor.getValueToStore();
537 
538  if (!isa<VectorType>(valueToStore.getType())) {
539  // Scalar case: generate a 1-D insert.
540  Value result = rewriter.createOrFold<vector::InsertOp>(
541  loc, valueToStore, adaptor.getDest(), linearizedOffset);
542  rewriter.replaceOp(insertOp, result);
543  return success();
544  }
545 
546  // Vector case: generate a shuffle.
547  llvm::SmallVector<int64_t, 2> indices(dstSize);
548  auto *origValsUntil = indices.begin();
549  std::advance(origValsUntil, linearizedOffset);
550 
551  // Original values that remain [0, offset).
552  std::iota(indices.begin(), origValsUntil, 0);
553  auto *newValsUntil = origValsUntil;
554  std::advance(newValsUntil, srcSize);
555  // New values [offset, offset+srcNumElements).
556  std::iota(origValsUntil, newValsUntil, dstSize);
557  // The rest of original values [offset+srcNumElements, end);
558  std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);
559 
560  Value result = rewriter.createOrFold<vector::ShuffleOp>(
561  loc, dstTy, adaptor.getDest(), valueToStore, indices);
562 
563  rewriter.replaceOp(insertOp, result);
564  return success();
565  }
566 };
567 
568 /// This pattern converts the BitCastOp that works on nD (n > 1)
569 /// vectors to a BitCastOp that works on linearized vectors.
570 /// Following,
571 /// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
572 /// is converted to :
573 /// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
574 /// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
575 /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
576 struct LinearizeVectorBitCast final
577  : public OpConversionPattern<vector::BitCastOp> {
579  LinearizeVectorBitCast(const TypeConverter &typeConverter,
580  MLIRContext *context, PatternBenefit benefit = 1)
581  : OpConversionPattern(typeConverter, context, benefit) {}
582  LogicalResult
583  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
584  ConversionPatternRewriter &rewriter) const override {
585  auto resType = getTypeConverter()->convertType(castOp.getType());
586  assert(resType && "expected 1-D vector type");
587  rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
588  adaptor.getSource());
589  return mlir::success();
590  }
591 };
592 
593 /// This pattern converts the SplatOp to work on a linearized vector.
594 /// Following,
595 /// vector.splat %value : vector<4x4xf32>
596 /// is converted to:
597 /// %out_1d = vector.splat %value : vector<16xf32>
598 /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
599 struct LinearizeVectorSplat final
600  : public OpConversionPattern<vector::SplatOp> {
602 
603  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
604  PatternBenefit benefit = 1)
605  : OpConversionPattern(typeConverter, context, benefit) {}
606 
607  LogicalResult
608  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
609  ConversionPatternRewriter &rewriter) const override {
610  auto dstTy = getTypeConverter()->convertType(splatOp.getType());
611  if (!dstTy)
612  return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
613  rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
614  dstTy);
615  return success();
616  }
617 };
618 
619 /// This pattern converts the CreateMaskOp to work on a linearized vector.
620 /// It currently supports only 2D masks with a unit outer dimension.
621 /// Following,
622 /// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
623 /// is converted to:
624 /// %zero = arith.constant 0 : index
625 /// %cmpi = arith.cmpi sgt, %arg0, %zero : index
626 /// %index = arith.index_cast %cmpi : i1 to index
627 /// %mul = arith.andi %index, %arg1 : index
628 /// %mask = vector.create_mask %mul : vector<4xi1>
629 /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
630 struct LinearizeVectorCreateMask final
631  : OpConversionPattern<vector::CreateMaskOp> {
633 
634  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
635  MLIRContext *context, PatternBenefit benefit = 1)
636  : OpConversionPattern(typeConverter, context, benefit) {}
637 
638  LogicalResult
639  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
640  ConversionPatternRewriter &rewriter) const override {
641  Location loc = createMaskOp.getLoc();
642  VectorType srcTy = createMaskOp.getType();
643  auto srcShape = srcTy.getShape();
644  if (srcShape.size() != 2)
645  return rewriter.notifyMatchFailure(createMaskOp,
646  "only 2D mask is supported.");
647 
648  if (srcShape[0] != 1)
649  return rewriter.notifyMatchFailure(
650  createMaskOp, "only unit outer dimension is supported.");
651 
652  auto dstTy = getTypeConverter()->convertType(srcTy);
653  if (!dstTy)
654  return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
655 
656  // Compare the first operand with 0. If it is greater than 0, the
657  // corresponding mask element is set to true, otherwise false.
658  // The result of the comparison is then multiplied with
659  // the second operand of create_mask to get the 1D mask.
660  auto firstOperand = adaptor.getOperands().front();
661  auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
662  auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
663  loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
664  auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
665  loc, rewriter.getIndexType(), isNonZero);
666  auto secondOperand = adaptor.getOperands().back();
667  auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
668  loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
669 
670  auto newMask =
671  rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
672  rewriter.replaceOp(createMaskOp, newMask);
673  return success();
674  }
675 };
676 
677 /// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
678 /// It currently supports linearization where all but the last dimension are 1
679 /// The following,
680 /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
681 /// is converted to:
682 /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
683 /// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
684 /// For generic cases, the vector unroll pass should be used to unroll the load
685 /// to vector<1x1x...xN> form and then linearized
686 struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
688  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
689  PatternBenefit benefit = 1)
690  : OpConversionPattern(typeConverter, context, benefit) {}
691 
692  LogicalResult
693  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
694  ConversionPatternRewriter &rewriter) const override {
695  VectorType vecTy = loadOp.getType();
696  if (!vecTy)
697  return rewriter.notifyMatchFailure(loadOp, "expected vector type");
698 
699  auto shape = vecTy.getShape();
700  auto scalableDims = vecTy.getScalableDims();
701  // All but the last dim must be 1, and only the last dim may be scalable (if
702  // any).
703  if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
704  return rewriter.notifyMatchFailure(loadOp,
705  "only vector<1x1x...xN> supported");
706 
707  if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
708  return rewriter.notifyMatchFailure(loadOp,
709  "only innermost dim may be scalable");
710 
711  auto linearTy = typeConverter->convertType<VectorType>(vecTy);
712 
713  auto newLoad = rewriter.create<vector::LoadOp>(
714  loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
715  rewriter.replaceOp(loadOp, newLoad.getResult());
716  return success();
717  }
718 };
719 
720 /// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
721 /// It currently supports linearization where all but the last dimension are 1
722 /// The following,
723 /// vector.store %arg0, %arg1[%c0, %c0]s
724 /// : vector<1x4xf32>, memref<1x4xf32>
725 /// is converted to:
726 /// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
727 /// vector.store %arg0, %arg1[%c0, %c0]
728 /// : vector<4xf32>, memref<1x4xf32>
729 /// For generic cases, the vector unroll pass should be used to unroll the store
730 /// to vector<1x1x...xN> form and then linearized
731 struct LinearizeVectorStore final
732  : public OpConversionPattern<vector::StoreOp> {
734  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
735  PatternBenefit benefit = 1)
736  : OpConversionPattern(typeConverter, context, benefit) {}
737 
738  LogicalResult
739  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
740  ConversionPatternRewriter &rewriter) const override {
741  VectorType vecTy = storeOp.getValueToStore().getType();
742  if (!vecTy)
743  return rewriter.notifyMatchFailure(storeOp, "expected vector type");
744 
745  auto shape = vecTy.getShape();
746  auto scalableDims = vecTy.getScalableDims();
747  // All but the last dim must be 1, and only the last dim may be scalable (if
748  // any).
749  if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
750  return rewriter.notifyMatchFailure(storeOp,
751  "only vector<1x1x...xN> supported");
752 
753  if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
754  return rewriter.notifyMatchFailure(storeOp,
755  "only innermost dim may be scalable");
756 
757  rewriter.replaceOpWithNewOp<vector::StoreOp>(
758  storeOp, adaptor.getValueToStore(), adaptor.getBase(),
759  adaptor.getIndices());
760  return success();
761  }
762 };
763 
764 } // namespace
765 
766 /// This method defines the set of operations that are linearizable, and hence
767 /// that are considered illegal for the conversion target.
768 static bool isLinearizable(Operation *op) {
769 
770  // Only ops that are in the vector dialect, are ConstantLike, or
771  // are Vectorizable might be linearized currently.
772  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
773  StringRef opDialect = op->getDialect()->getNamespace();
774  bool supported = (opDialect == vectorDialect) ||
777  if (!supported)
778  return false;
779 
781  // As type legalization is done with vector.shape_cast, shape_cast
782  // itself cannot be linearized (will create new shape_casts to linearize
783  // ad infinitum).
784  .Case<vector::ShapeCastOp>([&](auto) { return false; })
785  // The operations
786  // - vector.extract_strided_slice
787  // - vector.extract
788  // - vector.insert_strided_slice
789  // - vector.insert
790  // are linearized to a rank-1 vector.shuffle by the current patterns.
791  // vector.shuffle only supports fixed size vectors, so it is impossible to
792  // use this approach to linearize these ops if they operate on scalable
793  // vectors.
794  .Case<vector::ExtractStridedSliceOp>(
795  [&](vector::ExtractStridedSliceOp extractOp) {
796  return !extractOp.getType().isScalable();
797  })
798  .Case<vector::InsertStridedSliceOp>(
799  [&](vector::InsertStridedSliceOp insertOp) {
800  return !insertOp.getType().isScalable();
801  })
802  .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
803  return !insertOp.getType().isScalable();
804  })
805  .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
806  return !extractOp.getSourceVectorType().isScalable();
807  })
808  .Default([&](auto) { return true; });
809 }
810 
811 void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
812  ConversionTarget &target) {
813 
814  auto convertType = [](Type type) -> std::optional<Type> {
815  VectorType vectorType = dyn_cast<VectorType>(type);
816  if (!vectorType || !isLinearizableVector(vectorType))
817  return type;
818 
819  VectorType linearizedType =
820  VectorType::get(vectorType.getNumElements(),
821  vectorType.getElementType(), vectorType.isScalable());
822  return linearizedType;
823  };
824  typeConverter.addConversion(convertType);
825 
826  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
827  Location loc) -> Value {
828  if (inputs.size() != 1)
829  return nullptr;
830 
831  Value value = inputs.front();
832  if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
833  return nullptr;
834 
835  return builder.create<vector::ShapeCastOp>(loc, type, value);
836  };
837  typeConverter.addSourceMaterialization(materializeCast);
838  typeConverter.addTargetMaterialization(materializeCast);
839 
841  [=](Operation *op) -> std::optional<bool> {
842  if (!isLinearizable(op))
843  return true;
844  // This will return true if, for all operand and result types `t`,
845  // convertType(t) = t. This is true if there are no rank>=2 vectors.
846  return typeConverter.isLegal(op);
847  });
848 }
849 
850 void mlir::vector::populateVectorLinearizeBasePatterns(
851  const TypeConverter &typeConverter, const ConversionTarget &target,
853  patterns
854  .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
855  LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
856  LinearizeVectorStore>(typeConverter, patterns.getContext());
857 }
858 
859 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
860  const TypeConverter &typeConverter, const ConversionTarget &target,
862  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
863  LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
864  LinearizeVectorInsertStridedSlice>(typeConverter,
865  patterns.getContext());
866 }
static FailureOr< Attribute > linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value)
static bool isLinearizable(Operation *op)
This method defines the set of operations that are linearizable, and hence that are considered illega...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
StringRef getNamespace() const
Definition: Dialect.h:54
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:700
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:97
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Include the generated interface declarations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This trait tags Elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.