MLIR  19.0.0git
VectorLinearize.cpp
Go to the documentation of this file.
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include <cstdint>
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
29  auto resultTypes = op->getResultTypes();
30  for (auto resType : resultTypes) {
31  VectorType vecType = dyn_cast<VectorType>(resType);
32  // Reject index since getElementTypeBitWidth will abort for Index types.
33  if (!vecType || vecType.getElementType().isIndex())
34  return false;
35  // There are no dimension to fold if it is a 0-D vector.
36  if (vecType.getRank() == 0)
37  return false;
38  unsigned trailingVecDimBitWidth =
39  vecType.getShape().back() * vecType.getElementTypeBitWidth();
40  if (trailingVecDimBitWidth >= targetBitWidth)
41  return false;
42  }
43  return true;
44 }
45 
46 static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
47  VectorType vecType = dyn_cast<VectorType>(t);
48  // Reject index since getElementTypeBitWidth will abort for Index types.
49  if (!vecType || vecType.getElementType().isIndex())
50  return false;
51  // There are no dimension to fold if it is a 0-D vector.
52  if (vecType.getRank() == 0)
53  return false;
54  unsigned trailingVecDimBitWidth =
55  vecType.getShape().back() * vecType.getElementTypeBitWidth();
56  return trailingVecDimBitWidth <= targetBitWidth;
57 }
58 
59 namespace {
60 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
62  LinearizeConstant(
63  const TypeConverter &typeConverter, MLIRContext *context,
64  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
65  PatternBenefit benefit = 1)
66  : OpConversionPattern(typeConverter, context, benefit),
67  targetVectorBitWidth(targetVectBitWidth) {}
68  LogicalResult
69  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
70  ConversionPatternRewriter &rewriter) const override {
71  Location loc = constOp.getLoc();
72  auto resType =
73  getTypeConverter()->convertType<VectorType>(constOp.getType());
74 
75  if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
76  return rewriter.notifyMatchFailure(
77  loc,
78  "Cannot linearize a constant scalable vector that's not a splat");
79 
80  if (!resType)
81  return rewriter.notifyMatchFailure(loc, "can't convert return type");
82  if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
83  return rewriter.notifyMatchFailure(
84  loc, "Can't flatten since targetBitWidth <= OpSize");
85  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
86  if (!dstElementsAttr)
87  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
88 
89  dstElementsAttr = dstElementsAttr.reshape(resType);
90  rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
91  dstElementsAttr);
92  return success();
93  }
94 
95 private:
96  unsigned targetVectorBitWidth;
97 };
98 
99 struct LinearizeVectorizable final
100  : OpTraitConversionPattern<OpTrait::Vectorizable> {
102 
103 public:
104  LinearizeVectorizable(
105  const TypeConverter &typeConverter, MLIRContext *context,
106  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
107  PatternBenefit benefit = 1)
108  : OpTraitConversionPattern(typeConverter, context, benefit),
109  targetVectorBitWidth(targetVectBitWidth) {}
110  LogicalResult
111  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112  ConversionPatternRewriter &rewriter) const override {
113  if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
114  return rewriter.notifyMatchFailure(
115  op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
116  FailureOr<Operation *> newOp =
117  convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
118  if (failed(newOp))
119  return failure();
120 
121  rewriter.replaceOp(op, (*newOp)->getResults());
122  return success();
123  }
124 
125 private:
126  unsigned targetVectorBitWidth;
127 };
128 
129 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
130 /// on a linearized vector.
131 /// Following,
132 /// vector.extract_strided_slice %source
133 /// { offsets = [..], strides = [..], sizes = [..] }
134 /// is converted to :
135 /// %source_1d = vector.shape_cast %source
136 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
137 /// %out_nd = vector.shape_cast %out_1d
138 /// `shuffle_indices_1d` is computed using the offsets and sizes of the
139 /// extraction.
140 struct LinearizeVectorExtractStridedSlice final
141  : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
143  LinearizeVectorExtractStridedSlice(
144  const TypeConverter &typeConverter, MLIRContext *context,
145  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
146  PatternBenefit benefit = 1)
147  : OpConversionPattern(typeConverter, context, benefit),
148  targetVectorBitWidth(targetVectBitWidth) {}
149 
150  LogicalResult
151  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
152  ConversionPatternRewriter &rewriter) const override {
153  VectorType dstType =
154  getTypeConverter()->convertType<VectorType>(extractOp.getType());
155  assert(dstType && "vector type destination expected.");
156  if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
157  return rewriter.notifyMatchFailure(extractOp,
158  "scalable vectors are not supported.");
159  if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
160  return rewriter.notifyMatchFailure(
161  extractOp, "Can't flatten since targetBitWidth <= OpSize");
162 
163  ArrayAttr offsets = extractOp.getOffsets();
164  ArrayAttr sizes = extractOp.getSizes();
165  ArrayAttr strides = extractOp.getStrides();
166  if (!isConstantIntValue(strides[0], 1))
167  return rewriter.notifyMatchFailure(
168  extractOp, "Strided slice with stride != 1 is not supported.");
169  Value srcVector = adaptor.getVector();
170  // If kD offsets are specified for nD source vector (n > k), the granularity
171  // of the extraction is greater than 1. In this case last (n-k) dimensions
172  // form the extraction granularity.
173  // Example :
174  // vector.extract_strided_slice %src {
175  // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
176  // vector<4x8x8xf32> to vector<2x2x8xf32>
177  // Here, extraction granularity is 8.
178  int64_t extractGranularitySize = 1;
179  int64_t nD = extractOp.getSourceVectorType().getRank();
180  int64_t kD = (int64_t)offsets.size();
181  int64_t k = kD;
182  while (k < nD) {
183  extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
184  ++k;
185  }
186  // Get total number of extracted slices.
187  int64_t nExtractedSlices = 1;
188  for (Attribute size : sizes) {
189  nExtractedSlices *= cast<IntegerAttr>(size).getInt();
190  }
191  // Compute the strides of the source vector considering first k dimensions.
192  llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
193  for (int i = kD - 2; i >= 0; --i) {
194  sourceStrides[i] = sourceStrides[i + 1] *
195  extractOp.getSourceVectorType().getShape()[i + 1];
196  }
197  // Final shuffle indices has nExtractedSlices * extractGranularitySize
198  // elements.
199  llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
200  extractGranularitySize);
201  // Compute the strides of the extracted kD vector.
202  llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
203  // Compute extractedStrides.
204  for (int i = kD - 2; i >= 0; --i) {
205  extractedStrides[i] =
206  extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
207  }
208  // Iterate over all extracted slices from 0 to nExtractedSlices - 1
209  // and compute the multi-dimensional index and the corresponding linearized
210  // index within the source vector.
211  for (int64_t i = 0; i < nExtractedSlices; ++i) {
212  int64_t index = i;
213  // Compute the corresponding multi-dimensional index.
214  llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
215  for (int64_t j = 0; j < kD; ++j) {
216  multiDimIndex[j] = (index / extractedStrides[j]);
217  index -= multiDimIndex[j] * extractedStrides[j];
218  }
219  // Compute the corresponding linearized index in the source vector
220  // i.e. shift the multiDimIndex by the offsets.
221  int64_t linearizedIndex = 0;
222  for (int64_t j = 0; j < kD; ++j) {
223  linearizedIndex +=
224  (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
225  sourceStrides[j];
226  }
227  // Fill the indices array form linearizedIndex to linearizedIndex +
228  // extractGranularitySize.
229  for (int64_t j = 0; j < extractGranularitySize; ++j) {
230  indices[i * extractGranularitySize + j] = linearizedIndex + j;
231  }
232  }
233  // Perform a shuffle to extract the kD vector.
234  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
235  extractOp, dstType, srcVector, srcVector,
236  rewriter.getI64ArrayAttr(indices));
237  return success();
238  }
239 
240 private:
241  unsigned targetVectorBitWidth;
242 };
243 
244 /// This pattern converts the ShuffleOp that works on nD (n > 1)
245 /// vectors to a ShuffleOp that works on linearized vectors.
246 /// Following,
247 /// vector.shuffle %v1, %v2 [ shuffle_indices ]
248 /// is converted to :
249 /// %v1_1d = vector.shape_cast %v1
250 /// %v2_1d = vector.shape_cast %v2
251 /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
252 /// %out_nd = vector.shape_cast %out_1d
253 // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
254 /// of the original shuffle operation.
255 struct LinearizeVectorShuffle final
256  : public OpConversionPattern<vector::ShuffleOp> {
258  LinearizeVectorShuffle(
259  const TypeConverter &typeConverter, MLIRContext *context,
260  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
261  PatternBenefit benefit = 1)
262  : OpConversionPattern(typeConverter, context, benefit),
263  targetVectorBitWidth(targetVectBitWidth) {}
264 
265  LogicalResult
266  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
267  ConversionPatternRewriter &rewriter) const override {
268  VectorType dstType =
269  getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
270  assert(dstType && "vector type destination expected.");
271  // The assert is used because vector.shuffle does not support scalable
272  // vectors.
273  assert(!(shuffleOp.getV1VectorType().isScalable() ||
274  shuffleOp.getV2VectorType().isScalable() ||
275  dstType.isScalable()) &&
276  "scalable vectors are not supported.");
277  if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
278  return rewriter.notifyMatchFailure(
279  shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
280 
281  Value vec1 = adaptor.getV1();
282  Value vec2 = adaptor.getV2();
283  int shuffleSliceLen = 1;
284  int rank = shuffleOp.getV1().getType().getRank();
285 
286  // If rank > 1, we need to do the shuffle in the granularity of slices
287  // instead of scalars. Size of the slice is equal to the rank-1 innermost
288  // dims. Mask of the shuffle op specifies which slice to take from the
289  // outermost dim.
290  if (rank > 1) {
291  llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
292  for (unsigned i = 1; i < shape.size(); ++i) {
293  shuffleSliceLen *= shape[i];
294  }
295  }
296 
297  // For each value in the mask, we generate the indices of the source vectors
298  // that needs to be shuffled to the destination vector. If shuffleSliceLen >
299  // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
300  // elements) instead of scalars.
301  ArrayAttr mask = shuffleOp.getMask();
302  int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
303  llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
304  for (auto [i, value] :
305  llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
306 
307  int64_t v = value.getZExtValue();
308  std::iota(indices.begin() + shuffleSliceLen * i,
309  indices.begin() + shuffleSliceLen * (i + 1),
310  shuffleSliceLen * v);
311  }
312 
313  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
314  shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
315  return success();
316  }
317 
318 private:
319  unsigned targetVectorBitWidth;
320 };
321 
322 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
323 /// linearized vector.
324 /// Following,
325 /// vector.extract %source [ position ]
326 /// is converted to :
327 /// %source_1d = vector.shape_cast %source
328 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
329 /// %out_nd = vector.shape_cast %out_1d
330 /// `shuffle_indices_1d` is computed using the position of the original extract.
331 struct LinearizeVectorExtract final
332  : public OpConversionPattern<vector::ExtractOp> {
334  LinearizeVectorExtract(
335  const TypeConverter &typeConverter, MLIRContext *context,
336  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
337  PatternBenefit benefit = 1)
338  : OpConversionPattern(typeConverter, context, benefit),
339  targetVectorBitWidth(targetVectBitWidth) {}
340  LogicalResult
341  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
342  ConversionPatternRewriter &rewriter) const override {
343  Type dstTy = getTypeConverter()->convertType(extractOp.getType());
344  if (extractOp.getVector().getType().isScalable() ||
345  cast<VectorType>(dstTy).isScalable())
346  return rewriter.notifyMatchFailure(extractOp,
347  "scalable vectors are not supported.");
348  if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
349  return rewriter.notifyMatchFailure(
350  extractOp, "Can't flatten since targetBitWidth <= OpSize");
351 
352  // Dynamic position is not supported.
353  if (extractOp.hasDynamicPosition())
354  return rewriter.notifyMatchFailure(extractOp,
355  "dynamic position is not supported.");
356 
357  llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
358  int64_t size = extractOp.getVector().getType().getNumElements();
359 
360  // Compute linearized offset.
361  int64_t linearizedOffset = 0;
362  llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
363  for (auto [i, off] : llvm::enumerate(offsets)) {
364  size /= shape[i];
365  linearizedOffset += offsets[i] * size;
366  }
367 
368  llvm::SmallVector<int64_t, 2> indices(size);
369  std::iota(indices.begin(), indices.end(), linearizedOffset);
370  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
371  extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
372  rewriter.getI64ArrayAttr(indices));
373 
374  return success();
375  }
376 
377 private:
378  unsigned targetVectorBitWidth;
379 };
380 
381 /// This pattern converts the InsertOp to a ShuffleOp that works on a
382 /// linearized vector.
383 /// Following,
384 /// vector.insert %source %destination [ position ]
385 /// is converted to :
386 /// %source_1d = vector.shape_cast %source
387 /// %destination_1d = vector.shape_cast %destination
388 /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
389 /// ] %out_nd = vector.shape_cast %out_1d
390 /// `shuffle_indices_1d` is computed using the position of the original insert.
391 struct LinearizeVectorInsert final
392  : public OpConversionPattern<vector::InsertOp> {
394  LinearizeVectorInsert(
395  const TypeConverter &typeConverter, MLIRContext *context,
396  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
397  PatternBenefit benefit = 1)
398  : OpConversionPattern(typeConverter, context, benefit),
399  targetVectorBitWidth(targetVectBitWidth) {}
400  LogicalResult
401  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
402  ConversionPatternRewriter &rewriter) const override {
403  VectorType dstTy = getTypeConverter()->convertType<VectorType>(
404  insertOp.getDestVectorType());
405  assert(dstTy && "vector type destination expected.");
406  if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
407  return rewriter.notifyMatchFailure(insertOp,
408  "scalable vectors are not supported.");
409 
410  if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
411  targetVectorBitWidth))
412  return rewriter.notifyMatchFailure(
413  insertOp, "Can't flatten since targetBitWidth < OpSize");
414 
415  // dynamic position is not supported
416  if (insertOp.hasDynamicPosition())
417  return rewriter.notifyMatchFailure(insertOp,
418  "dynamic position is not supported.");
419  auto srcTy = insertOp.getSourceType();
420  auto srcAsVec = dyn_cast<VectorType>(srcTy);
421  uint64_t srcSize = 0;
422  if (srcAsVec) {
423  srcSize = srcAsVec.getNumElements();
424  } else {
425  return rewriter.notifyMatchFailure(insertOp,
426  "scalars are not supported.");
427  }
428 
429  auto dstShape = insertOp.getDestVectorType().getShape();
430  const auto dstSize = insertOp.getDestVectorType().getNumElements();
431  auto dstSizeForOffsets = dstSize;
432 
433  // compute linearized offset
434  int64_t linearizedOffset = 0;
435  auto offsetsNd = insertOp.getStaticPosition();
436  for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
437  dstSizeForOffsets /= dstShape[dim];
438  linearizedOffset += offset * dstSizeForOffsets;
439  }
440 
441  llvm::SmallVector<int64_t, 2> indices(dstSize);
442  auto origValsUntil = indices.begin();
443  std::advance(origValsUntil, linearizedOffset);
444  std::iota(indices.begin(), origValsUntil,
445  0); // original values that remain [0, offset)
446  auto newValsUntil = origValsUntil;
447  std::advance(newValsUntil, srcSize);
448  std::iota(origValsUntil, newValsUntil,
449  dstSize); // new values [offset, offset+srcNumElements)
450  std::iota(newValsUntil, indices.end(),
451  linearizedOffset + srcSize); // the rest of original values
452  // [offset+srcNumElements, end)
453 
454  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
455  insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
456  rewriter.getI64ArrayAttr(indices));
457 
458  return success();
459  }
460 
461 private:
462  unsigned targetVectorBitWidth;
463 };
464 } // namespace
465 
467  TypeConverter &typeConverter, RewritePatternSet &patterns,
468  ConversionTarget &target, unsigned targetBitWidth) {
469 
470  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
471  if (!isLinearizableVector(type))
472  return type;
473 
474  return VectorType::get(type.getNumElements(), type.getElementType(),
475  type.isScalable());
476  });
477 
478  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
479  Location loc) -> Value {
480  if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
481  !isa<VectorType>(type))
482  return nullptr;
483 
484  return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
485  };
486  typeConverter.addArgumentMaterialization(materializeCast);
487  typeConverter.addSourceMaterialization(materializeCast);
488  typeConverter.addTargetMaterialization(materializeCast);
490  [=](Operation *op) -> std::optional<bool> {
491  if ((isa<arith::ConstantOp>(op) ||
493  return (isLessThanTargetBitWidth(op, targetBitWidth)
494  ? typeConverter.isLegal(op)
495  : true);
496  }
497  return std::nullopt;
498  });
499 
500  patterns.add<LinearizeConstant, LinearizeVectorizable>(
501  typeConverter, patterns.getContext(), targetBitWidth);
502 }
503 
505  TypeConverter &typeConverter, RewritePatternSet &patterns,
506  ConversionTarget &target, unsigned int targetBitWidth) {
507  target.addDynamicallyLegalOp<vector::ShuffleOp>(
508  [=](vector::ShuffleOp shuffleOp) -> bool {
509  return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
510  ? (typeConverter.isLegal(shuffleOp) &&
511  cast<mlir::VectorType>(shuffleOp.getResult().getType())
512  .getRank() == 1)
513  : true;
514  });
515  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
516  LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
517  typeConverter, patterns.getContext(), targetBitWidth);
518 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth)
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait tags Elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.