MLIR  19.0.0git
VectorLinearize.cpp
Go to the documentation of this file.
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include <cstdint>
25 #include <numeric>
26 
27 using namespace mlir;
28 
29 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
30  auto resultTypes = op->getResultTypes();
31  for (auto resType : resultTypes) {
32  VectorType vecType = dyn_cast<VectorType>(resType);
33  // Reject index since getElementTypeBitWidth will abort for Index types.
34  if (!vecType || vecType.getElementType().isIndex())
35  return false;
36  // There are no dimension to fold if it is a 0-D vector.
37  if (vecType.getRank() == 0)
38  return false;
39  unsigned trailingVecDimBitWidth =
40  vecType.getShape().back() * vecType.getElementTypeBitWidth();
41  if (trailingVecDimBitWidth >= targetBitWidth)
42  return false;
43  }
44  return true;
45 }
46 
47 namespace {
48 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
50  LinearizeConstant(
51  const TypeConverter &typeConverter, MLIRContext *context,
52  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
53  PatternBenefit benefit = 1)
54  : OpConversionPattern(typeConverter, context, benefit),
55  targetVectorBitWidth(targetVectBitWidth) {}
57  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
58  ConversionPatternRewriter &rewriter) const override {
59  Location loc = constOp.getLoc();
60  auto resType =
61  getTypeConverter()->convertType<VectorType>(constOp.getType());
62 
63  if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
64  return rewriter.notifyMatchFailure(
65  loc,
66  "Cannot linearize a constant scalable vector that's not a splat");
67 
68  if (!resType)
69  return rewriter.notifyMatchFailure(loc, "can't convert return type");
70  if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
71  return rewriter.notifyMatchFailure(
72  loc, "Can't flatten since targetBitWidth <= OpSize");
73  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
74  if (!dstElementsAttr)
75  return rewriter.notifyMatchFailure(loc, "unsupported attr type");
76 
77  dstElementsAttr = dstElementsAttr.reshape(resType);
78  rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
79  dstElementsAttr);
80  return success();
81  }
82 
83 private:
84  unsigned targetVectorBitWidth;
85 };
86 
87 struct LinearizeVectorizable final
88  : OpTraitConversionPattern<OpTrait::Vectorizable> {
90 
91 public:
92  LinearizeVectorizable(
93  const TypeConverter &typeConverter, MLIRContext *context,
94  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
95  PatternBenefit benefit = 1)
96  : OpTraitConversionPattern(typeConverter, context, benefit),
97  targetVectorBitWidth(targetVectBitWidth) {}
99  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
100  ConversionPatternRewriter &rewriter) const override {
101  if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
102  return rewriter.notifyMatchFailure(
103  op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
104  FailureOr<Operation *> newOp =
105  convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
106  if (failed(newOp))
107  return failure();
108 
109  rewriter.replaceOp(op, (*newOp)->getResults());
110  return success();
111  }
112 
113 private:
114  unsigned targetVectorBitWidth;
115 };
116 
117 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
118 /// on a linearized vector.
119 /// Following,
120 /// vector.extract_strided_slice %source
121 /// { offsets = [..], strides = [..], sizes = [..] }
122 /// is converted to :
123 /// %source_1d = vector.shape_cast %source
124 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
125 /// %out_nd = vector.shape_cast %out_1d
126 /// `shuffle_indices_1d` is computed using the offsets and sizes of the
127 /// extraction.
128 struct LinearizeVectorExtractStridedSlice final
129  : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
131  LinearizeVectorExtractStridedSlice(
132  const TypeConverter &typeConverter, MLIRContext *context,
133  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
134  PatternBenefit benefit = 1)
135  : OpConversionPattern(typeConverter, context, benefit),
136  targetVectorBitWidth(targetVectBitWidth) {}
137 
139  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
140  ConversionPatternRewriter &rewriter) const override {
141  Type dstType = getTypeConverter()->convertType(extractOp.getType());
142  assert(!(extractOp.getVector().getType().isScalable() ||
143  cast<VectorType>(dstType).isScalable()) &&
144  "scalable vectors are not supported.");
145  if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
146  return rewriter.notifyMatchFailure(
147  extractOp, "Can't flatten since targetBitWidth <= OpSize");
148 
149  ArrayAttr offsets = extractOp.getOffsets();
150  ArrayAttr sizes = extractOp.getSizes();
151  ArrayAttr strides = extractOp.getStrides();
152  if (!isConstantIntValue(strides[0], 1))
153  return rewriter.notifyMatchFailure(
154  extractOp, "Strided slice with stride != 1 is not supported.");
155  Value srcVector = adaptor.getVector();
156  // If kD offsets are specified for nD source vector (n > k), the granularity
157  // of the extraction is greater than 1. In this case last (n-k) dimensions
158  // form the extraction granularity.
159  // Example :
160  // vector.extract_strided_slice %src {
161  // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
162  // vector<4x8x8xf32> to vector<2x2x8xf32>
163  // Here, extraction granularity is 8.
164  int64_t extractGranularitySize = 1;
165  int64_t nD = extractOp.getSourceVectorType().getRank();
166  int64_t kD = (int64_t)offsets.size();
167  int64_t k = kD;
168  while (k < nD) {
169  extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
170  ++k;
171  }
172  // Get total number of extracted slices.
173  int64_t nExtractedSlices = 1;
174  for (Attribute size : sizes) {
175  nExtractedSlices *= cast<IntegerAttr>(size).getInt();
176  }
177  // Compute the strides of the source vector considering first k dimensions.
178  llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
179  for (int i = kD - 2; i >= 0; --i) {
180  sourceStrides[i] = sourceStrides[i + 1] *
181  extractOp.getSourceVectorType().getShape()[i + 1];
182  }
183  // Final shuffle indices has nExtractedSlices * extractGranularitySize
184  // elements.
185  llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
186  extractGranularitySize);
187  // Compute the strides of the extracted kD vector.
188  llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
189  // Compute extractedStrides.
190  for (int i = kD - 2; i >= 0; --i) {
191  extractedStrides[i] =
192  extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
193  }
194  // Iterate over all extracted slices from 0 to nExtractedSlices - 1
195  // and compute the multi-dimensional index and the corresponding linearized
196  // index within the source vector.
197  for (int64_t i = 0; i < nExtractedSlices; ++i) {
198  int64_t index = i;
199  // Compute the corresponding multi-dimensional index.
200  llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
201  for (int64_t j = 0; j < kD; ++j) {
202  multiDimIndex[j] = (index / extractedStrides[j]);
203  index -= multiDimIndex[j] * extractedStrides[j];
204  }
205  // Compute the corresponding linearized index in the source vector
206  // i.e. shift the multiDimIndex by the offsets.
207  int64_t linearizedIndex = 0;
208  for (int64_t j = 0; j < kD; ++j) {
209  linearizedIndex +=
210  (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
211  sourceStrides[j];
212  }
213  // Fill the indices array form linearizedIndex to linearizedIndex +
214  // extractGranularitySize.
215  for (int64_t j = 0; j < extractGranularitySize; ++j) {
216  indices[i * extractGranularitySize + j] = linearizedIndex + j;
217  }
218  }
219  // Perform a shuffle to extract the kD vector.
220  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
221  extractOp, dstType, srcVector, srcVector,
222  rewriter.getI64ArrayAttr(indices));
223  return success();
224  }
225 
226 private:
227  unsigned targetVectorBitWidth;
228 };
229 
230 /// This pattern converts the ShuffleOp that works on nD (n > 1)
231 /// vectors to a ShuffleOp that works on linearized vectors.
232 /// Following,
233 /// vector.shuffle %v1, %v2 [ shuffle_indices ]
234 /// is converted to :
235 /// %v1_1d = vector.shape_cast %v1
236 /// %v2_1d = vector.shape_cast %v2
237 /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
238 /// %out_nd = vector.shape_cast %out_1d
239 // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
240 /// of the original shuffle operation.
241 struct LinearizeVectorShuffle final
242  : public OpConversionPattern<vector::ShuffleOp> {
244  LinearizeVectorShuffle(
245  const TypeConverter &typeConverter, MLIRContext *context,
246  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
247  PatternBenefit benefit = 1)
248  : OpConversionPattern(typeConverter, context, benefit),
249  targetVectorBitWidth(targetVectBitWidth) {}
250 
252  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
253  ConversionPatternRewriter &rewriter) const override {
254  Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
255  assert(!(shuffleOp.getV1VectorType().isScalable() ||
256  shuffleOp.getV2VectorType().isScalable() ||
257  cast<VectorType>(dstType).isScalable()) &&
258  "scalable vectors are not supported.");
259  if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
260  return rewriter.notifyMatchFailure(
261  shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
262 
263  Value vec1 = adaptor.getV1();
264  Value vec2 = adaptor.getV2();
265  int shuffleSliceLen = 1;
266  int rank = shuffleOp.getV1().getType().getRank();
267 
268  // If rank > 1, we need to do the shuffle in the granularity of slices
269  // instead of scalars. Size of the slice is equal to the rank-1 innermost
270  // dims. Mask of the shuffle op specifies which slice to take from the
271  // outermost dim.
272  if (rank > 1) {
273  llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
274  for (unsigned i = 1; i < shape.size(); ++i) {
275  shuffleSliceLen *= shape[i];
276  }
277  }
278 
279  // For each value in the mask, we generate the indices of the source vectors
280  // that needs to be shuffled to the destination vector. If shuffleSliceLen >
281  // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
282  // elements) instead of scalars.
283  ArrayAttr mask = shuffleOp.getMask();
284  int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
285  llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
286  for (auto [i, value] :
287  llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
288 
289  int64_t v = value.getZExtValue();
290  std::iota(indices.begin() + shuffleSliceLen * i,
291  indices.begin() + shuffleSliceLen * (i + 1),
292  shuffleSliceLen * v);
293  }
294 
295  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
296  shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
297  return success();
298  }
299 
300 private:
301  unsigned targetVectorBitWidth;
302 };
303 
304 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
305 /// linearized vector.
306 /// Following,
307 /// vector.extract %source [ position ]
308 /// is converted to :
309 /// %source_1d = vector.shape_cast %source
310 /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
311 /// %out_nd = vector.shape_cast %out_1d
312 /// `shuffle_indices_1d` is computed using the position of the original extract.
313 struct LinearizeVectorExtract final
314  : public OpConversionPattern<vector::ExtractOp> {
316  LinearizeVectorExtract(
317  const TypeConverter &typeConverter, MLIRContext *context,
318  unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
319  PatternBenefit benefit = 1)
320  : OpConversionPattern(typeConverter, context, benefit),
321  targetVectorBitWidth(targetVectBitWidth) {}
323  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
324  ConversionPatternRewriter &rewriter) const override {
325  Type dstTy = getTypeConverter()->convertType(extractOp.getType());
326  assert(!(extractOp.getVector().getType().isScalable() ||
327  cast<VectorType>(dstTy).isScalable()) &&
328  "scalable vectors are not supported.");
329  if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
330  return rewriter.notifyMatchFailure(
331  extractOp, "Can't flatten since targetBitWidth <= OpSize");
332 
333  // Dynamic position is not supported.
334  if (extractOp.hasDynamicPosition())
335  return rewriter.notifyMatchFailure(extractOp,
336  "dynamic position is not supported.");
337 
338  llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
339  int64_t size = extractOp.getVector().getType().getNumElements();
340 
341  // Compute linearized offset.
342  int64_t linearizedOffset = 0;
343  llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
344  for (auto [i, off] : llvm::enumerate(offsets)) {
345  size /= shape[i];
346  linearizedOffset += offsets[i] * size;
347  }
348 
349  llvm::SmallVector<int64_t, 2> indices(size);
350  std::iota(indices.begin(), indices.end(), linearizedOffset);
351  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
352  extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
353  rewriter.getI64ArrayAttr(indices));
354 
355  return success();
356  }
357 
358 private:
359  unsigned targetVectorBitWidth;
360 };
361 } // namespace
362 
364  TypeConverter &typeConverter, RewritePatternSet &patterns,
365  ConversionTarget &target, unsigned targetBitWidth) {
366 
367  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
368  if (!isLinearizableVector(type))
369  return type;
370 
371  return VectorType::get(type.getNumElements(), type.getElementType(),
372  type.isScalable());
373  });
374 
375  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
376  Location loc) -> Value {
377  if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
378  !isa<VectorType>(type))
379  return nullptr;
380 
381  return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
382  };
383  typeConverter.addArgumentMaterialization(materializeCast);
384  typeConverter.addSourceMaterialization(materializeCast);
385  typeConverter.addTargetMaterialization(materializeCast);
387  [=](Operation *op) -> std::optional<bool> {
388  if ((isa<arith::ConstantOp>(op) ||
390  return (isLessThanTargetBitWidth(op, targetBitWidth)
391  ? typeConverter.isLegal(op)
392  : true);
393  }
394  return std::nullopt;
395  });
396 
397  patterns.add<LinearizeConstant, LinearizeVectorizable>(
398  typeConverter, patterns.getContext(), targetBitWidth);
399 }
400 
402  TypeConverter &typeConverter, RewritePatternSet &patterns,
403  ConversionTarget &target, unsigned int targetBitWidth) {
404  target.addDynamicallyLegalOp<vector::ShuffleOp>(
405  [=](vector::ShuffleOp shuffleOp) -> bool {
406  return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
407  ? (typeConverter.isLegal(shuffleOp) &&
408  cast<mlir::VectorType>(shuffleOp.getResult().getType())
409  .getRank() == 1)
410  : true;
411  });
412  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
413  LinearizeVectorExtractStridedSlice>(
414  typeConverter, patterns.getContext(), targetBitWidth);
415 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This trait tags Elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.