MLIR  21.0.0git
VectorLinearize.cpp
Go to the documentation of this file.
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include <cstdint>
24 #include <numeric>
25 #include <optional>
26 
27 using namespace mlir;
28 
29 static FailureOr<Attribute>
31  VectorType resType, Attribute value) {
32 
33  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
34  if (resType.isScalable() && !isa<SplatElementsAttr>(value))
35  return rewriter.notifyMatchFailure(
36  loc,
37  "Cannot linearize a constant scalable vector that's not a splat");
38 
39  return dstElementsAttr.reshape(resType);
40  }
41 
42  if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
43  return poisonAttr;
44 
45  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
46 }
47 
48 namespace {
49 
50 struct LinearizeConstantLike final
51  : OpTraitConversionPattern<OpTrait::ConstantLike> {
53 
54  LinearizeConstantLike(const TypeConverter &typeConverter,
55  MLIRContext *context, PatternBenefit benefit = 1)
56  : OpTraitConversionPattern(typeConverter, context, benefit) {}
57  LogicalResult
58  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
59  ConversionPatternRewriter &rewriter) const override {
60  Location loc = op->getLoc();
61  if (op->getNumResults() != 1)
62  return rewriter.notifyMatchFailure(loc, "expected 1 result");
63 
64  const TypeConverter &typeConverter = *getTypeConverter();
65  auto resType =
66  typeConverter.convertType<VectorType>(op->getResult(0).getType());
67  assert(resType && "expected 1-D vector type");
68 
69  StringAttr attrName = rewriter.getStringAttr("value");
70  Attribute value = op->getAttr(attrName);
71  if (!value)
72  return rewriter.notifyMatchFailure(loc, "no 'value' attr");
73 
74  FailureOr<Attribute> newValue =
75  linearizeConstAttr(loc, rewriter, resType, value);
76  if (failed(newValue))
77  return failure();
78 
79  FailureOr<Operation *> convertResult =
80  convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
81  if (failed(convertResult))
82  return failure();
83 
84  Operation *newOp = *convertResult;
85  newOp->setAttr(attrName, *newValue);
86  rewriter.replaceOp(op, newOp);
87  return success();
88  }
89 };
90 
91 struct LinearizeVectorizable final
92  : OpTraitConversionPattern<OpTrait::Vectorizable> {
94 
95 public:
96  LinearizeVectorizable(const TypeConverter &typeConverter,
97  MLIRContext *context, PatternBenefit benefit = 1)
98  : OpTraitConversionPattern(typeConverter, context, benefit) {}
99  LogicalResult
100  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101  ConversionPatternRewriter &rewriter) const override {
102  FailureOr<Operation *> newOp =
103  convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
104  if (failed(newOp))
105  return failure();
106 
107  rewriter.replaceOp(op, (*newOp)->getResults());
108  return success();
109  }
110 };
111 
112 template <typename TOp>
113 static bool stridesAllOne(TOp op) {
114  static_assert(
115  std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116  std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117  "expected vector.extract_strided_slice or vector.insert_strided_slice");
118  ArrayAttr strides = op.getStrides();
119  return llvm::all_of(strides, isOneInteger);
120 }
121 
122 /// Convert an array of attributes into a vector of integers, if possible.
123 static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
124  if (!attrs)
125  return failure();
127  ints.reserve(attrs.size());
128  for (auto attr : attrs) {
129  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130  ints.push_back(intAttr.getInt());
131  } else {
132  return failure();
133  }
134  }
135  return ints;
136 }
137 
138 /// Consider inserting a vector of shape `small` into a vector of shape `large`,
139 /// at position `offsets`: this function enumeratates all the indices in `large`
140 /// that are written to. The enumeration is with row-major ordering.
141 ///
142 /// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143 /// positions written to are (1,3) and (1,4), which have linearized indices 8
144 /// and 9. So [8,9] is returned.
145 ///
146 /// The length of the returned vector is equal to the number of elements in
147 /// the shape `small` (i.e. the product of dimensions of `small`).
148 SmallVector<int64_t> static getStridedSliceInsertionIndices(
150  ArrayRef<int64_t> offsets) {
151 
152  // Example of alignment between, `large`, `small` and `offsets`:
153  // large = 4, 5, 6, 7, 8
154  // small = 1, 6, 7, 8
155  // offsets = 2, 3, 0
156  //
157  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158  assert((large.size() >= small.size()) &&
159  "rank of 'large' cannot be lower than rank of 'small'");
160  assert((large.size() >= offsets.size()) &&
161  "rank of 'large' cannot be lower than the number of offsets");
162  unsigned delta = large.size() - small.size();
163  unsigned nOffsets = offsets.size();
164  auto getSmall = [&](int64_t i) -> int64_t {
165  return i >= delta ? small[i - delta] : 1;
166  };
167  auto getOffset = [&](int64_t i) -> int64_t {
168  return i < nOffsets ? offsets[i] : 0;
169  };
170 
171  // Using 2 vectors of indices, at each iteration populate the updated set of
172  // indices based on the old set of indices, and the size of the small vector
173  // in the current iteration.
174  SmallVector<int64_t> indices{0};
175  int64_t stride = 1;
176  for (int i = large.size() - 1; i >= 0; --i) {
177  int64_t currentSize = indices.size();
178  int64_t smallSize = getSmall(i);
179  int64_t nextSize = currentSize * smallSize;
180  SmallVector<int64_t> nextIndices(nextSize);
181  int64_t *base = nextIndices.begin();
182  int64_t offset = getOffset(i) * stride;
183  for (int j = 0; j < smallSize; ++j) {
184  for (int k = 0; k < currentSize; ++k) {
185  base[k] = indices[k] + offset;
186  }
187  offset += stride;
188  base += currentSize;
189  }
190  stride *= large[i];
191  indices = std::move(nextIndices);
192  }
193  return indices;
194 }
195 
196 /// This pattern converts a vector.extract_strided_slice operation into a
197 /// vector.shuffle operation that has a rank-1 (linearized) operand and result.
198 ///
199 /// For example, the following:
200 ///
201 /// ```
202 /// vector.extract_strided_slice %source
203 /// { offsets = [..], strides = [..], sizes = [..] }
204 /// ```
205 ///
206 /// is converted to :
207 /// ```
208 /// %source_1d = vector.shape_cast %source
209 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
210 /// %out_nd = vector.shape_cast %out_1d
211 /// ```
212 ///
213 /// `shuffle_indices_1d` is computed using the offsets and sizes of the original
214 /// vector.extract_strided_slice operation.
215 struct LinearizeVectorExtractStridedSlice final
216  : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
218  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
219  MLIRContext *context,
220  PatternBenefit benefit = 1)
221  : OpConversionPattern(typeConverter, context, benefit) {}
222 
223  LogicalResult
224  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
225  OpAdaptor adaptor,
226  ConversionPatternRewriter &rewriter) const override {
227 
228  VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
229  extractStridedSliceOp.getType());
230  assert(flatOutputType && "vector type expected");
231 
232  // Expect a legalization failure if the strides are not all 1 (if ever the
233  // verifier for extract_strided_slice allows non-1 strides).
234  if (!stridesAllOne(extractStridedSliceOp)) {
235  return rewriter.notifyMatchFailure(
236  extractStridedSliceOp,
237  "extract_strided_slice with strides != 1 not supported");
238  }
239 
240  FailureOr<SmallVector<int64_t>> offsets =
241  intsFromArrayAttr(extractStridedSliceOp.getOffsets());
242  if (failed(offsets)) {
243  return rewriter.notifyMatchFailure(extractStridedSliceOp,
244  "failed to get integer offsets");
245  }
246 
247  ArrayRef<int64_t> inputShape =
248  extractStridedSliceOp.getSourceVectorType().getShape();
249 
250  ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
251 
252  SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
253  outputShape, inputShape, offsets.value());
254 
255  Value srcVector = adaptor.getVector();
256  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
257  extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
258  return success();
259  }
260 };
261 
262 /// This pattern converts a vector.insert_strided_slice operation into a
263 /// vector.shuffle operation that has rank-1 (linearized) operands and result.
264 ///
265 /// For example, the following:
266 /// ```
267 /// %0 = vector.insert_strided_slice %to_store, %into
268 /// {offsets = [1, 0, 0, 0], strides = [1, 1]}
269 /// : vector<2x2xi8> into vector<2x1x3x2xi8>
270 /// ```
271 ///
272 /// is converted to
273 /// ```
274 /// %to_store_1d
275 /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
276 /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
277 /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
278 /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
279 /// ```
280 ///
281 /// where shuffle_indices_1d in this case is
282 /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
283 /// ^^^^^^^^^^^^^^
284 /// to_store_1d
285 ///
286 struct LinearizeVectorInsertStridedSlice final
287  : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
289  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
290  MLIRContext *context,
291  PatternBenefit benefit = 1)
292  : OpConversionPattern(typeConverter, context, benefit) {}
293 
294  LogicalResult
295  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
296  OpAdaptor adaptor,
297  ConversionPatternRewriter &rewriter) const override {
298 
299  // Expect a legalization failure if the strides are not all 1 (if ever the
300  // verifier for insert_strided_slice allows non-1 strides).
301  if (!stridesAllOne(insertStridedSliceOp)) {
302  return rewriter.notifyMatchFailure(
303  insertStridedSliceOp,
304  "insert_strided_slice with strides != 1 not supported");
305  }
306 
307  VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
308  ArrayRef<int64_t> inputShape = inputType.getShape();
309 
310  VectorType outputType = insertStridedSliceOp.getType();
311  ArrayRef<int64_t> outputShape = outputType.getShape();
312  int64_t nOutputElements = outputType.getNumElements();
313 
314  FailureOr<SmallVector<int64_t>> offsets =
315  intsFromArrayAttr(insertStridedSliceOp.getOffsets());
316  if (failed(offsets)) {
317  return rewriter.notifyMatchFailure(insertStridedSliceOp,
318  "failed to get integer offsets");
319  }
320  SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
321  inputShape, outputShape, offsets.value());
322 
323  SmallVector<int64_t> indices(nOutputElements);
324  std::iota(indices.begin(), indices.end(), 0);
325  for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
326  indices[sliceIndex] = index + nOutputElements;
327  }
328 
329  Value flatToStore = adaptor.getValueToStore();
330  Value flatDest = adaptor.getDest();
331  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
332  flatDest.getType(), flatDest,
333  flatToStore, indices);
334  return success();
335  }
336 };
337 
338 /// This pattern converts the ShuffleOp that works on nD (n > 1)
339 /// vectors to a ShuffleOp that works on linearized vectors.
340 /// Following,
341 /// vector.shuffle %v1, %v2 [ shuffle_indices ]
342 /// is converted to :
343 /// %v1_1d = vector.shape_cast %v1
344 /// %v2_1d = vector.shape_cast %v2
345 /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
346 /// %out_nd = vector.shape_cast %out_1d
347 // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
348 /// of the original shuffle operation.
349 struct LinearizeVectorShuffle final
350  : public OpConversionPattern<vector::ShuffleOp> {
352  LinearizeVectorShuffle(const TypeConverter &typeConverter,
353  MLIRContext *context, PatternBenefit benefit = 1)
354  : OpConversionPattern(typeConverter, context, benefit) {}
355 
356  LogicalResult
357  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
358  ConversionPatternRewriter &rewriter) const override {
359  VectorType dstType =
360  getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
361  assert(dstType && "vector type destination expected.");
362 
363  Value vec1 = adaptor.getV1();
364  Value vec2 = adaptor.getV2();
365  int shuffleSliceLen = 1;
366  int rank = shuffleOp.getV1().getType().getRank();
367 
368  // If rank > 1, we need to do the shuffle in the granularity of slices
369  // instead of scalars. Size of the slice is equal to the rank-1 innermost
370  // dims. Mask of the shuffle op specifies which slice to take from the
371  // outermost dim.
372  if (rank > 1) {
373  llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
374  for (unsigned i = 1; i < shape.size(); ++i) {
375  shuffleSliceLen *= shape[i];
376  }
377  }
378 
379  // For each value in the mask, we generate the indices of the source vectors
380  // that need to be shuffled to the destination vector. If shuffleSliceLen >
381  // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
382  // elements) instead of scalars.
383  ArrayRef<int64_t> mask = shuffleOp.getMask();
384  int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
385  llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
386  for (auto [i, value] : llvm::enumerate(mask)) {
387  std::iota(indices.begin() + shuffleSliceLen * i,
388  indices.begin() + shuffleSliceLen * (i + 1),
389  shuffleSliceLen * value);
390  }
391 
392  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
393  vec2, indices);
394  return success();
395  }
396 };
397 
398 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
399 /// linearized vector.
400 /// Following,
401 /// vector.extract %source [ position ]
402 /// is converted to :
403 /// %source_1d = vector.shape_cast %source
404 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
405 /// %out_nd = vector.shape_cast %out_1d
406 /// `shuffle_indices_1d` is computed using the position of the original extract.
407 struct LinearizeVectorExtract final
408  : public OpConversionPattern<vector::ExtractOp> {
410  LinearizeVectorExtract(const TypeConverter &typeConverter,
411  MLIRContext *context, PatternBenefit benefit = 1)
412  : OpConversionPattern(typeConverter, context, benefit) {}
413  LogicalResult
414  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
415  ConversionPatternRewriter &rewriter) const override {
416  // Skip if result is not a vector type
417  if (!isa<VectorType>(extractOp.getType()))
418  return rewriter.notifyMatchFailure(extractOp,
419  "scalar extract not supported");
420  Type dstTy = getTypeConverter()->convertType(extractOp.getType());
421  assert(dstTy && "expected 1-D vector type");
422 
423  // Dynamic position is not supported.
424  if (extractOp.hasDynamicPosition())
425  return rewriter.notifyMatchFailure(extractOp,
426  "dynamic position is not supported.");
427 
428  llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
429  int64_t size = extractOp.getVector().getType().getNumElements();
430 
431  // Compute linearized offset.
432  int64_t linearizedOffset = 0;
433  llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
434  for (auto [i, off] : llvm::enumerate(offsets)) {
435  size /= shape[i];
436  linearizedOffset += offsets[i] * size;
437  }
438 
439  llvm::SmallVector<int64_t, 2> indices(size);
440  std::iota(indices.begin(), indices.end(), linearizedOffset);
441  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
442  extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
443 
444  return success();
445  }
446 };
447 
448 /// This pattern converts the InsertOp to a ShuffleOp that works on a
449 /// linearized vector.
450 /// Following,
451 /// vector.insert %source %destination [ position ]
452 /// is converted to :
453 /// %source_1d = vector.shape_cast %source
454 /// %destination_1d = vector.shape_cast %destination
455 /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
456 /// ] %out_nd = vector.shape_cast %out_1d
457 /// `shuffle_indices_1d` is computed using the position of the original insert.
458 struct LinearizeVectorInsert final
459  : public OpConversionPattern<vector::InsertOp> {
461  LinearizeVectorInsert(const TypeConverter &typeConverter,
462  MLIRContext *context, PatternBenefit benefit = 1)
463  : OpConversionPattern(typeConverter, context, benefit) {}
464  LogicalResult
465  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
466  ConversionPatternRewriter &rewriter) const override {
467  VectorType dstTy = getTypeConverter()->convertType<VectorType>(
468  insertOp.getDestVectorType());
469  assert(dstTy && "vector type destination expected.");
470 
471  // dynamic position is not supported
472  if (insertOp.hasDynamicPosition())
473  return rewriter.notifyMatchFailure(insertOp,
474  "dynamic position is not supported.");
475  auto srcTy = insertOp.getValueToStoreType();
476  auto srcAsVec = dyn_cast<VectorType>(srcTy);
477  uint64_t srcSize = 0;
478  if (srcAsVec) {
479  srcSize = srcAsVec.getNumElements();
480  } else {
481  return rewriter.notifyMatchFailure(insertOp,
482  "scalars are not supported.");
483  }
484 
485  auto dstShape = insertOp.getDestVectorType().getShape();
486  const auto dstSize = insertOp.getDestVectorType().getNumElements();
487  auto dstSizeForOffsets = dstSize;
488 
489  // compute linearized offset
490  int64_t linearizedOffset = 0;
491  auto offsetsNd = insertOp.getStaticPosition();
492  for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
493  dstSizeForOffsets /= dstShape[dim];
494  linearizedOffset += offset * dstSizeForOffsets;
495  }
496 
497  llvm::SmallVector<int64_t, 2> indices(dstSize);
498  auto *origValsUntil = indices.begin();
499  std::advance(origValsUntil, linearizedOffset);
500  std::iota(indices.begin(), origValsUntil,
501  0); // original values that remain [0, offset)
502  auto *newValsUntil = origValsUntil;
503  std::advance(newValsUntil, srcSize);
504  std::iota(origValsUntil, newValsUntil,
505  dstSize); // new values [offset, offset+srcNumElements)
506  std::iota(newValsUntil, indices.end(),
507  linearizedOffset + srcSize); // the rest of original values
508  // [offset+srcNumElements, end)
509 
510  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
511  insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
512 
513  return success();
514  }
515 };
516 
517 /// This pattern converts the BitCastOp that works on nD (n > 1)
518 /// vectors to a BitCastOp that works on linearized vectors.
519 /// Following,
520 /// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
521 /// is converted to :
522 /// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
523 /// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
524 /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
525 struct LinearizeVectorBitCast final
526  : public OpConversionPattern<vector::BitCastOp> {
528  LinearizeVectorBitCast(const TypeConverter &typeConverter,
529  MLIRContext *context, PatternBenefit benefit = 1)
530  : OpConversionPattern(typeConverter, context, benefit) {}
531  LogicalResult
532  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
533  ConversionPatternRewriter &rewriter) const override {
534  auto resType = getTypeConverter()->convertType(castOp.getType());
535  assert(resType && "expected 1-D vector type");
536  rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
537  adaptor.getSource());
538  return mlir::success();
539  }
540 };
541 
542 /// This pattern converts the SplatOp to work on a linearized vector.
543 /// Following,
544 /// vector.splat %value : vector<4x4xf32>
545 /// is converted to:
546 /// %out_1d = vector.splat %value : vector<16xf32>
547 /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
548 struct LinearizeVectorSplat final
549  : public OpConversionPattern<vector::SplatOp> {
551 
552  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
553  PatternBenefit benefit = 1)
554  : OpConversionPattern(typeConverter, context, benefit) {}
555 
556  LogicalResult
557  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
558  ConversionPatternRewriter &rewriter) const override {
559  auto dstTy = getTypeConverter()->convertType(splatOp.getType());
560  if (!dstTy)
561  return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
562  rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
563  dstTy);
564  return success();
565  }
566 };
567 
568 /// This pattern converts the CreateMaskOp to work on a linearized vector.
569 /// It currently supports only 2D masks with a unit outer dimension.
570 /// Following,
571 /// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
572 /// is converted to:
573 /// %zero = arith.constant 0 : index
574 /// %cmpi = arith.cmpi sgt, %arg0, %zero : index
575 /// %index = arith.index_cast %cmpi : i1 to index
576 /// %mul = arith.andi %index, %arg1 : index
577 /// %mask = vector.create_mask %mul : vector<4xi1>
578 /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
579 struct LinearizeVectorCreateMask final
580  : OpConversionPattern<vector::CreateMaskOp> {
582 
583  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
584  MLIRContext *context, PatternBenefit benefit = 1)
585  : OpConversionPattern(typeConverter, context, benefit) {}
586 
587  LogicalResult
588  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
589  ConversionPatternRewriter &rewriter) const override {
590  Location loc = createMaskOp.getLoc();
591  VectorType srcTy = createMaskOp.getType();
592  auto srcShape = srcTy.getShape();
593  if (srcShape.size() != 2)
594  return rewriter.notifyMatchFailure(createMaskOp,
595  "only 2D mask is supported.");
596 
597  if (srcShape[0] != 1)
598  return rewriter.notifyMatchFailure(
599  createMaskOp, "only unit outer dimension is supported.");
600 
601  auto dstTy = getTypeConverter()->convertType(srcTy);
602  if (!dstTy)
603  return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
604 
605  // Compare the first operand with 0. If it is greater than 0, the
606  // corresponding mask element is set to true, otherwise false.
607  // The result of the comparison is then multiplied with
608  // the second operand of create_mask to get the 1D mask.
609  auto firstOperand = adaptor.getOperands().front();
610  auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
611  auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
612  loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
613  auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
614  loc, rewriter.getIndexType(), isNonZero);
615  auto secondOperand = adaptor.getOperands().back();
616  auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
617  loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
618 
619  auto newMask =
620  rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
621  rewriter.replaceOp(createMaskOp, newMask);
622  return success();
623  }
624 };
625 
626 } // namespace
627 
628 /// This method defines the set of operations that are linearizable, and hence
629 /// that are considered illegal for the conversion target.
630 static bool isLinearizable(Operation *op) {
631 
632  // Only ops that are in the vector dialect, are ConstantLike, or
633  // are Vectorizable might be linearized currently.
634  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
635  StringRef opDialect = op->getDialect()->getNamespace();
636  bool supported = (opDialect == vectorDialect) ||
639  if (!supported)
640  return false;
641 
643  // As type legalization is done with vector.shape_cast, shape_cast
644  // itself cannot be linearized (will create new shape_casts to linearize
645  // ad infinitum).
646  .Case<vector::ShapeCastOp>([&](auto) { return false; })
647  // The operations
648  // - vector.extract_strided_slice
649  // - vector.extract
650  // - vector.insert_strided_slice
651  // - vector.insert
652  // are linearized to a rank-1 vector.shuffle by the current patterns.
653  // vector.shuffle only supports fixed size vectors, so it is impossible to
654  // use this approach to linearize these ops if they operate on scalable
655  // vectors.
656  .Case<vector::ExtractStridedSliceOp>(
657  [&](vector::ExtractStridedSliceOp extractOp) {
658  return !extractOp.getType().isScalable();
659  })
660  .Case<vector::InsertStridedSliceOp>(
661  [&](vector::InsertStridedSliceOp insertOp) {
662  return !insertOp.getType().isScalable();
663  })
664  .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
665  return !insertOp.getType().isScalable();
666  })
667  .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
668  return !extractOp.getSourceVectorType().isScalable();
669  })
670  .Default([&](auto) { return true; });
671 }
672 
673 void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
674  ConversionTarget &target) {
675 
676  auto convertType = [](Type type) -> std::optional<Type> {
677  VectorType vectorType = dyn_cast<VectorType>(type);
678  if (!vectorType || !isLinearizableVector(vectorType))
679  return type;
680 
681  VectorType linearizedType =
682  VectorType::get(vectorType.getNumElements(),
683  vectorType.getElementType(), vectorType.isScalable());
684  return linearizedType;
685  };
686  typeConverter.addConversion(convertType);
687 
688  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
689  Location loc) -> Value {
690  if (inputs.size() != 1)
691  return nullptr;
692 
693  Value value = inputs.front();
694  if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
695  return nullptr;
696 
697  return builder.create<vector::ShapeCastOp>(loc, type, value);
698  };
699  typeConverter.addSourceMaterialization(materializeCast);
700  typeConverter.addTargetMaterialization(materializeCast);
701 
703  [=](Operation *op) -> std::optional<bool> {
704  if (!isLinearizable(op))
705  return true;
706  // This will return true if, for all operand and result types `t`,
707  // convertType(t) = t. This is true if there are no rank>=2 vectors.
708  return typeConverter.isLegal(op);
709  });
710 }
711 
/// Adds the base linearization patterns to `patterns`: LinearizeConstantLike,
/// LinearizeVectorizable, LinearizeVectorBitCast, LinearizeVectorSplat and
/// LinearizeVectorCreateMask. Each is constructed with `typeConverter` and
/// the context of `patterns`. `target` is not read by this body.
712 void mlir::vector::populateVectorLinearizeBasePatterns(
713  const TypeConverter &typeConverter, const ConversionTarget &target,
715  patterns
716  .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
717  LinearizeVectorSplat, LinearizeVectorCreateMask>(
718  typeConverter, patterns.getContext());
719 }
720 
/// Adds the patterns that lower shuffle-like ops (vector.shuffle,
/// vector.extract, vector.insert, vector.extract_strided_slice and
/// vector.insert_strided_slice) to a rank-1 vector.shuffle. Each is
/// constructed with `typeConverter` and the context of `patterns`. `target` is
/// not read by this body.
721 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
722  const TypeConverter &typeConverter, const ConversionTarget &target,
724  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
725  LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
726  LinearizeVectorInsertStridedSlice>(typeConverter,
727  patterns.getContext());
728 }
static FailureOr< Attribute > linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value)
static bool isLinearizable(Operation *op)
This method defines the set of operations that are linearizable, and hence that are considered illegal for the conversion target.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
IndexType getIndexType()
Definition: Builders.cpp:53
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
StringRef getNamespace() const
Definition: Dialect.h:54
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Include the generated interface declarations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This trait tags elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.