//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>
#include <optional>

using namespace mlir;
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}
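// For example, with targetBitWidth = 512, isLessThanTargetBitWidth returns
// true for an op producing vector<4x16xf16> (trailing dimension: 16 * 16 =
// 256 bits) and false for one producing vector<4x64xf16> (64 * 16 = 1024
// bits), which is therefore left in its n-D form.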

static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                   VectorType resType, Attribute value) {
  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
    if (resType.isScalable() && !isa<SplatElementsAttr>(value))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    return dstElementsAttr.reshape(resType);
  }

  if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
    return poisonAttr;

  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
}
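// For example, linearizeConstAttr reshapes dense<[[0, 1], [2, 3]]> :
// vector<2x2xi32> to dense<[0, 1, 2, 3]> : vector<4xi32> when given the
// linearized result type vector<4xi32>; a ub.poison value is returned
// unchanged because it carries no shape-dependent data.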

namespace {
struct LinearizeConstantLike final
    : OpTraitConversionPattern<OpTrait::ConstantLike> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

  LinearizeConstantLike(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    if (op->getNumResults() != 1)
      return rewriter.notifyMatchFailure(loc, "expected 1 result");

    const TypeConverter &converter = *getTypeConverter();
    auto resType =
        converter.convertType<VectorType>(op->getResult(0).getType());

    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");

    StringAttr attrName = rewriter.getStringAttr("value");
    Attribute value = op->getAttr(attrName);
    if (!value)
      return rewriter.notifyMatchFailure(loc, "no 'value' attr");

    FailureOr<Attribute> newValue =
        linearizeConstAttr(loc, rewriter, resType, value);
    if (failed(newValue))
      return failure();

    FailureOr<Operation *> convertResult =
        convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
    if (failed(convertResult))
      return failure();

    Operation *newOp = *convertResult;
    newOp->setAttr(attrName, *newValue);
    rewriter.replaceOp(op, newOp);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
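// For example (illustrative), an elementwise op such as
//   %0 = arith.addf %a, %b : vector<2x4xf32>
// is rewritten by LinearizeVectorizable to operate on vector<8xf32>, with
// vector.shape_cast ops materialized by the type converter to bridge the
// 2-D and 1-D views of the operands and result.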

/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
/// vector.extract_strided_slice %source
///     { offsets = [..], strides = [..], sizes = [..] }
/// is converted to :
/// %source_1d = vector.shape_cast %source
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
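/// For example (illustrative),
///   vector.extract_strided_slice %source
///       { offsets = [1], sizes = [2], strides = [1] }
///       : vector<4x4xf32> to vector<2x4xf32>
/// becomes
///   %source_1d = vector.shape_cast %source : vector<4x4xf32> to vector<16xf32>
///   %out_1d = vector.shuffle %source_1d, %source_1d [4, 5, 6, 7, 8, 9, 10, 11]
///       : vector<16xf32>, vector<16xf32>
///   %out_nd = vector.shape_cast %out_1d : vector<8xf32> to vector<2x4xf32>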
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case the last
    // (n - k) dimensions form the extraction granularity.
    // Example:
    //   vector.extract_strided_slice %src {
    //     offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //     vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, extraction granularity is 8.
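    // With these offsets and sizes, the four extracted slices start at
    // linearized positions 0, 8, 64 and 72 of the flattened vector<256xf32>,
    // so the shuffle indices computed below are [0..15, 64..79].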
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering first k dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    // Compute extractedStrides.
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector, indices);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
/// vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
/// %v1_1d = vector.shape_cast %v1
/// %v2_1d = vector.shape_cast %v2
/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
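/// For example (illustrative),
///   %out = vector.shuffle %v1, %v2 [0, 2] : vector<2x2xf32>, vector<2x2xf32>
/// becomes
///   %v1_1d = vector.shape_cast %v1 : vector<2x2xf32> to vector<4xf32>
///   %v2_1d = vector.shape_cast %v2 : vector<2x2xf32> to vector<4xf32>
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [0, 1, 4, 5]
///       : vector<4xf32>, vector<4xf32>
///   %out_nd = vector.shape_cast %out_1d : vector<4xf32> to vector<2x2xf32>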
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. The size of a slice is the product of the rank-1
    // innermost dims. The mask of the shuffle op specifies which slice to
    // take from the outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1 we need to shuffle the slices (consecutive
    // shuffleSliceLen number of elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
/// vector.extract %source [ position ]
/// is converted to :
/// %source_1d = vector.shape_cast %source
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original extract.
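/// For example (illustrative),
///   %out = vector.extract %source[1] : vector<2xf32> from vector<4x2xf32>
/// becomes
///   %source_1d = vector.shape_cast %source : vector<4x2xf32> to vector<8xf32>
///   %out = vector.shuffle %source_1d, %source_1d [2, 3]
///       : vector<8xf32>, vector<8xf32>
/// where the 1-D result type needs no trailing shape_cast.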
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    if (!dstTy)
      return rewriter.notifyMatchFailure(extractOp,
                                         "expected n-D vector type.");

    if (extractOp.getVector().getType().isScalable() ||
        cast<VectorType>(dstTy).isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += offsets[i] * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
/// vector.insert %source %destination [ position ]
/// is converted to :
/// %source_1d = vector.shape_cast %source
/// %destination_1d = vector.shape_cast %destination
/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
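/// For example (illustrative),
///   %out = vector.insert %source, %destination[1]
///       : vector<2xf32> into vector<4x2xf32>
/// becomes
///   %destination_1d = vector.shape_cast %destination
///       : vector<4x2xf32> to vector<8xf32>
///   %out_1d = vector.shuffle %destination_1d, %source [0, 1, 8, 9, 4, 5, 6, 7]
///       : vector<8xf32>, vector<2xf32>
///   %out_nd = vector.shape_cast %out_1d : vector<8xf32> to vector<4x2xf32>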
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of the original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the BitCastOp that works on nD (n > 1)
/// vectors to a BitCastOp that works on linearized vectors.
/// Following,
/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
/// is converted to :
/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
    : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorBitCast(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = castOp.getLoc();
    auto resType = getTypeConverter()->convertType(castOp.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type.");

    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                   adaptor.getSource());
    return mlir::success();
  }

private:
  unsigned targetVectorBitWidth;
};

} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });
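  // For example, the conversion registered above maps vector<4x8xf32> to
  // vector<32xf32>; types that are not linearizable (such as 0-D and 1-D
  // vectors) are returned unchanged.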

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<vector::BitCastOp>(op) ||
             op->hasTrait<OpTrait::ConstantLike>() ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
                                       targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}
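
// A typical driver (illustrative sketch, not part of this file) wires these
// entry points into a dialect conversion roughly as follows:
//
//   TypeConverter typeConverter;
//   RewritePatternSet patterns(ctx);
//   ConversionTarget target(*ctx);
//   vector::populateVectorLinearizeTypeConversionsAndLegality(
//       typeConverter, patterns, target, /*targetBitWidth=*/512);
//   vector::populateVectorLinearizeShuffleLikeOpsPatterns(
//       typeConverter, patterns, target, /*targetBitWidth=*/512);
//   if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
//     signalPassFailure(); // when running inside a pass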