MLIR  20.0.0git
VectorLinearize.cpp
Go to the documentation of this file.
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns for linearizing N-D vectors into 1-D vectors.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include <cstdint>
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
29  auto resultTypes = op->getResultTypes();
30  for (auto resType : resultTypes) {
31  VectorType vecType = dyn_cast<VectorType>(resType);
32  // Reject index since getElementTypeBitWidth will abort for Index types.
33  if (!vecType || vecType.getElementType().isIndex())
34  return false;
35  // There are no dimension to fold if it is a 0-D vector.
36  if (vecType.getRank() == 0)
37  return false;
38  unsigned trailingVecDimBitWidth =
39  vecType.getShape().back() * vecType.getElementTypeBitWidth();
40  if (trailingVecDimBitWidth >= targetBitWidth)
41  return false;
42  }
43  return true;
44 }
45 
46 static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
47  VectorType vecType = dyn_cast<VectorType>(t);
48  // Reject index since getElementTypeBitWidth will abort for Index types.
49  if (!vecType || vecType.getElementType().isIndex())
50  return false;
51  // There are no dimension to fold if it is a 0-D vector.
52  if (vecType.getRank() == 0)
53  return false;
54  unsigned trailingVecDimBitWidth =
55  vecType.getShape().back() * vecType.getElementTypeBitWidth();
56  return trailingVecDimBitWidth <= targetBitWidth;
57 }
58 
59 namespace {
60 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
62  LinearizeConstant(
63  const TypeConverter &typeConverter, MLIRContext *context,
64  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
65  PatternBenefit benefit = 1)
66  : OpConversionPattern(typeConverter, context, benefit),
67  targetVectorBitWidth(targetVectBitWidth) {}
68  LogicalResult
69  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
70  ConversionPatternRewriter &rewriter) const override {
71  Location loc = constOp.getLoc();
72  auto resType =
73  getTypeConverter()->convertType<VectorType>(constOp.getType());
74 
75  if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
76  return rewriter.notifyMatchFailure(
77  loc,
78  "Cannot linearize a constant scalable vector that's not a splat");
79 
80  if (!resType)
81  return rewriter.notifyMatchFailure(loc, "can't convert return type");
82  if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
83  return rewriter.notifyMatchFailure(
84  loc, "Can't flatten since targetBitWidth <= OpSize");
85  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
86  if (!dstElementsAttr)
87  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
88 
89  dstElementsAttr = dstElementsAttr.reshape(resType);
90  rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
91  dstElementsAttr);
92  return success();
93  }
94 
95 private:
96  unsigned targetVectorBitWidth;
97 };
98 
99 struct LinearizeVectorizable final
100  : OpTraitConversionPattern<OpTrait::Vectorizable> {
102 
103 public:
104  LinearizeVectorizable(
105  const TypeConverter &typeConverter, MLIRContext *context,
106  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
107  PatternBenefit benefit = 1)
108  : OpTraitConversionPattern(typeConverter, context, benefit),
109  targetVectorBitWidth(targetVectBitWidth) {}
110  LogicalResult
111  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112  ConversionPatternRewriter &rewriter) const override {
113  if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
114  return rewriter.notifyMatchFailure(
115  op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
116  FailureOr<Operation *> newOp =
117  convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
118  if (failed(newOp))
119  return failure();
120 
121  rewriter.replaceOp(op, (*newOp)->getResults());
122  return success();
123  }
124 
125 private:
126  unsigned targetVectorBitWidth;
127 };
128 
129 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
130 /// on a linearized vector.
131 /// Following,
132 /// vector.extract_strided_slice %source
133 /// { offsets = [..], strides = [..], sizes = [..] }
134 /// is converted to :
135 /// %source_1d = vector.shape_cast %source
136 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
137 /// %out_nd = vector.shape_cast %out_1d
138 /// `shuffle_indices_1d` is computed using the offsets and sizes of the
139 /// extraction.
140 struct LinearizeVectorExtractStridedSlice final
141  : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
143  LinearizeVectorExtractStridedSlice(
144  const TypeConverter &typeConverter, MLIRContext *context,
145  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
146  PatternBenefit benefit = 1)
147  : OpConversionPattern(typeConverter, context, benefit),
148  targetVectorBitWidth(targetVectBitWidth) {}
149 
150  LogicalResult
151  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
152  ConversionPatternRewriter &rewriter) const override {
153  VectorType dstType =
154  getTypeConverter()->convertType<VectorType>(extractOp.getType());
155  assert(dstType && "vector type destination expected.");
156  if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
157  return rewriter.notifyMatchFailure(extractOp,
158  "scalable vectors are not supported.");
159  if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
160  return rewriter.notifyMatchFailure(
161  extractOp, "Can't flatten since targetBitWidth <= OpSize");
162 
163  ArrayAttr offsets = extractOp.getOffsets();
164  ArrayAttr sizes = extractOp.getSizes();
165  ArrayAttr strides = extractOp.getStrides();
166  if (!isConstantIntValue(strides[0], 1))
167  return rewriter.notifyMatchFailure(
168  extractOp, "Strided slice with stride != 1 is not supported.");
169  Value srcVector = adaptor.getVector();
170  // If kD offsets are specified for nD source vector (n > k), the granularity
171  // of the extraction is greater than 1. In this case last (n-k) dimensions
172  // form the extraction granularity.
173  // Example :
174  // vector.extract_strided_slice %src {
175  // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
176  // vector<4x8x8xf32> to vector<2x2x8xf32>
177  // Here, extraction granularity is 8.
178  int64_t extractGranularitySize = 1;
179  int64_t nD = extractOp.getSourceVectorType().getRank();
180  int64_t kD = (int64_t)offsets.size();
181  int64_t k = kD;
182  while (k < nD) {
183  extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
184  ++k;
185  }
186  // Get total number of extracted slices.
187  int64_t nExtractedSlices = 1;
188  for (Attribute size : sizes) {
189  nExtractedSlices *= cast<IntegerAttr>(size).getInt();
190  }
191  // Compute the strides of the source vector considering first k dimensions.
192  llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
193  for (int i = kD - 2; i >= 0; --i) {
194  sourceStrides[i] = sourceStrides[i + 1] *
195  extractOp.getSourceVectorType().getShape()[i + 1];
196  }
197  // Final shuffle indices has nExtractedSlices * extractGranularitySize
198  // elements.
199  llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
200  extractGranularitySize);
201  // Compute the strides of the extracted kD vector.
202  llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
203  // Compute extractedStrides.
204  for (int i = kD - 2; i >= 0; --i) {
205  extractedStrides[i] =
206  extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
207  }
208  // Iterate over all extracted slices from 0 to nExtractedSlices - 1
209  // and compute the multi-dimensional index and the corresponding linearized
210  // index within the source vector.
211  for (int64_t i = 0; i < nExtractedSlices; ++i) {
212  int64_t index = i;
213  // Compute the corresponding multi-dimensional index.
214  llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
215  for (int64_t j = 0; j < kD; ++j) {
216  multiDimIndex[j] = (index / extractedStrides[j]);
217  index -= multiDimIndex[j] * extractedStrides[j];
218  }
219  // Compute the corresponding linearized index in the source vector
220  // i.e. shift the multiDimIndex by the offsets.
221  int64_t linearizedIndex = 0;
222  for (int64_t j = 0; j < kD; ++j) {
223  linearizedIndex +=
224  (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
225  sourceStrides[j];
226  }
227  // Fill the indices array form linearizedIndex to linearizedIndex +
228  // extractGranularitySize.
229  for (int64_t j = 0; j < extractGranularitySize; ++j) {
230  indices[i * extractGranularitySize + j] = linearizedIndex + j;
231  }
232  }
233  // Perform a shuffle to extract the kD vector.
234  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
235  extractOp, dstType, srcVector, srcVector, indices);
236  return success();
237  }
238 
239 private:
240  unsigned targetVectorBitWidth;
241 };
242 
243 /// This pattern converts the ShuffleOp that works on nD (n > 1)
244 /// vectors to a ShuffleOp that works on linearized vectors.
245 /// Following,
246 /// vector.shuffle %v1, %v2 [ shuffle_indices ]
247 /// is converted to :
248 /// %v1_1d = vector.shape_cast %v1
249 /// %v2_1d = vector.shape_cast %v2
250 /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
251 /// %out_nd = vector.shape_cast %out_1d
252 // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
253 /// of the original shuffle operation.
254 struct LinearizeVectorShuffle final
255  : public OpConversionPattern<vector::ShuffleOp> {
257  LinearizeVectorShuffle(
258  const TypeConverter &typeConverter, MLIRContext *context,
259  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
260  PatternBenefit benefit = 1)
261  : OpConversionPattern(typeConverter, context, benefit),
262  targetVectorBitWidth(targetVectBitWidth) {}
263 
264  LogicalResult
265  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
266  ConversionPatternRewriter &rewriter) const override {
267  VectorType dstType =
268  getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
269  assert(dstType && "vector type destination expected.");
270  // The assert is used because vector.shuffle does not support scalable
271  // vectors.
272  assert(!(shuffleOp.getV1VectorType().isScalable() ||
273  shuffleOp.getV2VectorType().isScalable() ||
274  dstType.isScalable()) &&
275  "scalable vectors are not supported.");
276  if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
277  return rewriter.notifyMatchFailure(
278  shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
279 
280  Value vec1 = adaptor.getV1();
281  Value vec2 = adaptor.getV2();
282  int shuffleSliceLen = 1;
283  int rank = shuffleOp.getV1().getType().getRank();
284 
285  // If rank > 1, we need to do the shuffle in the granularity of slices
286  // instead of scalars. Size of the slice is equal to the rank-1 innermost
287  // dims. Mask of the shuffle op specifies which slice to take from the
288  // outermost dim.
289  if (rank > 1) {
290  llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
291  for (unsigned i = 1; i < shape.size(); ++i) {
292  shuffleSliceLen *= shape[i];
293  }
294  }
295 
296  // For each value in the mask, we generate the indices of the source vectors
297  // that needs to be shuffled to the destination vector. If shuffleSliceLen >
298  // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
299  // elements) instead of scalars.
300  ArrayRef<int64_t> mask = shuffleOp.getMask();
301  int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
302  llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
303  for (auto [i, value] : llvm::enumerate(mask)) {
304  std::iota(indices.begin() + shuffleSliceLen * i,
305  indices.begin() + shuffleSliceLen * (i + 1),
306  shuffleSliceLen * value);
307  }
308 
309  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
310  vec2, indices);
311  return success();
312  }
313 
314 private:
315  unsigned targetVectorBitWidth;
316 };
317 
318 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
319 /// linearized vector.
320 /// Following,
321 /// vector.extract %source [ position ]
322 /// is converted to :
323 /// %source_1d = vector.shape_cast %source
324 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
325 /// %out_nd = vector.shape_cast %out_1d
326 /// `shuffle_indices_1d` is computed using the position of the original extract.
327 struct LinearizeVectorExtract final
328  : public OpConversionPattern<vector::ExtractOp> {
330  LinearizeVectorExtract(
331  const TypeConverter &typeConverter, MLIRContext *context,
332  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
333  PatternBenefit benefit = 1)
334  : OpConversionPattern(typeConverter, context, benefit),
335  targetVectorBitWidth(targetVectBitWidth) {}
336  LogicalResult
337  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
338  ConversionPatternRewriter &rewriter) const override {
339  Type dstTy = getTypeConverter()->convertType(extractOp.getType());
340  if (!dstTy)
341  return rewriter.notifyMatchFailure(extractOp,
342  "expected n-D vector type.");
343 
344  if (extractOp.getVector().getType().isScalable() ||
345  cast<VectorType>(dstTy).isScalable())
346  return rewriter.notifyMatchFailure(extractOp,
347  "scalable vectors are not supported.");
348  if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
349  return rewriter.notifyMatchFailure(
350  extractOp, "Can't flatten since targetBitWidth <= OpSize");
351 
352  // Dynamic position is not supported.
353  if (extractOp.hasDynamicPosition())
354  return rewriter.notifyMatchFailure(extractOp,
355  "dynamic position is not supported.");
356 
357  llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
358  int64_t size = extractOp.getVector().getType().getNumElements();
359 
360  // Compute linearized offset.
361  int64_t linearizedOffset = 0;
362  llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
363  for (auto [i, off] : llvm::enumerate(offsets)) {
364  size /= shape[i];
365  linearizedOffset += offsets[i] * size;
366  }
367 
368  llvm::SmallVector<int64_t, 2> indices(size);
369  std::iota(indices.begin(), indices.end(), linearizedOffset);
370  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
371  extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
372 
373  return success();
374  }
375 
376 private:
377  unsigned targetVectorBitWidth;
378 };
379 
380 /// This pattern converts the InsertOp to a ShuffleOp that works on a
381 /// linearized vector.
382 /// Following,
383 /// vector.insert %source %destination [ position ]
384 /// is converted to :
385 /// %source_1d = vector.shape_cast %source
386 /// %destination_1d = vector.shape_cast %destination
387 /// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
388 /// ] %out_nd = vector.shape_cast %out_1d
389 /// `shuffle_indices_1d` is computed using the position of the original insert.
390 struct LinearizeVectorInsert final
391  : public OpConversionPattern<vector::InsertOp> {
393  LinearizeVectorInsert(
394  const TypeConverter &typeConverter, MLIRContext *context,
395  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
396  PatternBenefit benefit = 1)
397  : OpConversionPattern(typeConverter, context, benefit),
398  targetVectorBitWidth(targetVectBitWidth) {}
399  LogicalResult
400  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
401  ConversionPatternRewriter &rewriter) const override {
402  VectorType dstTy = getTypeConverter()->convertType<VectorType>(
403  insertOp.getDestVectorType());
404  assert(dstTy && "vector type destination expected.");
405  if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
406  return rewriter.notifyMatchFailure(insertOp,
407  "scalable vectors are not supported.");
408 
409  if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
410  targetVectorBitWidth))
411  return rewriter.notifyMatchFailure(
412  insertOp, "Can't flatten since targetBitWidth < OpSize");
413 
414  // dynamic position is not supported
415  if (insertOp.hasDynamicPosition())
416  return rewriter.notifyMatchFailure(insertOp,
417  "dynamic position is not supported.");
418  auto srcTy = insertOp.getSourceType();
419  auto srcAsVec = dyn_cast<VectorType>(srcTy);
420  uint64_t srcSize = 0;
421  if (srcAsVec) {
422  srcSize = srcAsVec.getNumElements();
423  } else {
424  return rewriter.notifyMatchFailure(insertOp,
425  "scalars are not supported.");
426  }
427 
428  auto dstShape = insertOp.getDestVectorType().getShape();
429  const auto dstSize = insertOp.getDestVectorType().getNumElements();
430  auto dstSizeForOffsets = dstSize;
431 
432  // compute linearized offset
433  int64_t linearizedOffset = 0;
434  auto offsetsNd = insertOp.getStaticPosition();
435  for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
436  dstSizeForOffsets /= dstShape[dim];
437  linearizedOffset += offset * dstSizeForOffsets;
438  }
439 
440  llvm::SmallVector<int64_t, 2> indices(dstSize);
441  auto origValsUntil = indices.begin();
442  std::advance(origValsUntil, linearizedOffset);
443  std::iota(indices.begin(), origValsUntil,
444  0); // original values that remain [0, offset)
445  auto newValsUntil = origValsUntil;
446  std::advance(newValsUntil, srcSize);
447  std::iota(origValsUntil, newValsUntil,
448  dstSize); // new values [offset, offset+srcNumElements)
449  std::iota(newValsUntil, indices.end(),
450  linearizedOffset + srcSize); // the rest of original values
451  // [offset+srcNumElements, end)
452 
453  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
454  insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
455 
456  return success();
457  }
458 
459 private:
460  unsigned targetVectorBitWidth;
461 };
462 } // namespace
463 
465  TypeConverter &typeConverter, RewritePatternSet &patterns,
466  ConversionTarget &target, unsigned targetBitWidth) {
467 
468  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
469  if (!isLinearizableVector(type))
470  return type;
471 
472  return VectorType::get(type.getNumElements(), type.getElementType(),
473  type.isScalable());
474  });
475 
476  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
477  Location loc) -> Value {
478  if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
479  !isa<VectorType>(type))
480  return nullptr;
481 
482  return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
483  };
484  typeConverter.addArgumentMaterialization(materializeCast);
485  typeConverter.addSourceMaterialization(materializeCast);
486  typeConverter.addTargetMaterialization(materializeCast);
488  [=](Operation *op) -> std::optional<bool> {
489  if ((isa<arith::ConstantOp>(op) ||
491  return (isLessThanTargetBitWidth(op, targetBitWidth)
492  ? typeConverter.isLegal(op)
493  : true);
494  }
495  return std::nullopt;
496  });
497 
498  patterns.add<LinearizeConstant, LinearizeVectorizable>(
499  typeConverter, patterns.getContext(), targetBitWidth);
500 }
501 
503  const TypeConverter &typeConverter, RewritePatternSet &patterns,
504  ConversionTarget &target, unsigned int targetBitWidth) {
505  target.addDynamicallyLegalOp<vector::ShuffleOp>(
506  [=](vector::ShuffleOp shuffleOp) -> bool {
507  return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
508  ? (typeConverter.isLegal(shuffleOp) &&
509  cast<mlir::VectorType>(shuffleOp.getResult().getType())
510  .getRank() == 1)
511  : true;
512  });
513  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
514  LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
515  typeConverter, patterns.getContext(), targetBitWidth);
516 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth)
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait tags elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.