MLIR  17.0.0git
VectorDropLeadUnitDim.cpp
Go to the documentation of this file.
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18 
19 using namespace mlir;
20 using namespace mlir::vector;
21 
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
24 static VectorType trimLeadingOneDims(VectorType oldType) {
25  ArrayRef<int64_t> oldShape = oldType.getShape();
26  ArrayRef<int64_t> newShape =
27  oldShape.drop_while([](int64_t dim) { return dim == 1; });
28  // Make sure we have at least 1 dimension per vector type requirements.
29  if (newShape.empty())
30  newShape = oldShape.take_back();
31  return VectorType::get(newShape, oldType.getElementType());
32 }
33 
34 /// Return a smallVector of size `rank` containing all zeros.
35 static SmallVector<int64_t> splatZero(int64_t rank) {
36  return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39 
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
45 
46  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47  PatternRewriter &rewriter) const override {
48  // vector.extract_strided_slice requires the input and output vector to have
49  // the same rank. Here we drop leading one dimensions from the input vector
50  // type to make sure we don't cause mismatch.
51  VectorType oldSrcType = extractOp.getSourceVectorType();
52  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53 
54  if (newSrcType.getRank() == oldSrcType.getRank())
55  return failure();
56 
57  int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58 
59  VectorType oldDstType = extractOp.getType();
60  VectorType newDstType =
61  VectorType::get(oldDstType.getShape().drop_front(dropCount),
62  oldDstType.getElementType());
63 
64  Location loc = extractOp.getLoc();
65 
66  Value newSrcVector = rewriter.create<vector::ExtractOp>(
67  loc, extractOp.getVector(), splatZero(dropCount));
68 
69  // The offsets/sizes/strides attribute can have a less number of elements
70  // than the input vector's rank: it is meant for the leading dimensions.
71  auto newOffsets = rewriter.getArrayAttr(
72  extractOp.getOffsets().getValue().drop_front(dropCount));
73  auto newSizes = rewriter.getArrayAttr(
74  extractOp.getSizes().getValue().drop_front(dropCount));
75  auto newStrides = rewriter.getArrayAttr(
76  extractOp.getStrides().getValue().drop_front(dropCount));
77 
78  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79  loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80 
81  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82  newExtractOp);
83 
84  return success();
85  }
86 };
87 
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91  : public OpRewritePattern<vector::InsertStridedSliceOp> {
93 
94  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95  PatternRewriter &rewriter) const override {
96  VectorType oldSrcType = insertOp.getSourceVectorType();
97  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98  VectorType oldDstType = insertOp.getDestVectorType();
99  VectorType newDstType = trimLeadingOneDims(oldDstType);
100 
101  int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103  if (srcDropCount == 0 && dstDropCount == 0)
104  return failure();
105 
106  // Trim leading one dimensions from both operands.
107  Location loc = insertOp.getLoc();
108 
109  Value newSrcVector = rewriter.create<vector::ExtractOp>(
110  loc, insertOp.getSource(), splatZero(srcDropCount));
111  Value newDstVector = rewriter.create<vector::ExtractOp>(
112  loc, insertOp.getDest(), splatZero(dstDropCount));
113 
114  auto newOffsets = rewriter.getArrayAttr(
115  insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116  auto newStrides = rewriter.getArrayAttr(
117  insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118 
119  auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120  loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121 
122  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123  newInsertOp);
124 
125  return success();
126  }
127 };
128 
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
133 
134  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135  PatternRewriter &rewriter) const override {
136  Type oldSrcType = insertOp.getSourceType();
137  Type newSrcType = oldSrcType;
138  int64_t oldSrcRank = 0, newSrcRank = 0;
139  if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140  newSrcType = trimLeadingOneDims(type);
141  oldSrcRank = type.getRank();
142  newSrcRank = newSrcType.cast<VectorType>().getRank();
143  }
144 
145  VectorType oldDstType = insertOp.getDestVectorType();
146  VectorType newDstType = trimLeadingOneDims(oldDstType);
147 
148  int64_t srcDropCount = oldSrcRank - newSrcRank;
149  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150  if (srcDropCount == 0 && dstDropCount == 0)
151  return failure();
152 
153  // Trim leading one dimensions from both operands.
154  Location loc = insertOp.getLoc();
155 
156  Value newSrcVector = insertOp.getSource();
157  if (oldSrcRank != 0) {
158  newSrcVector = rewriter.create<vector::ExtractOp>(
159  loc, insertOp.getSource(), splatZero(srcDropCount));
160  }
161  Value newDstVector = rewriter.create<vector::ExtractOp>(
162  loc, insertOp.getDest(), splatZero(dstDropCount));
163 
164  unsigned oldPosRank = insertOp.getPosition().getValue().size();
165  unsigned newPosRank = newDstType.getRank() - newSrcRank;
166  SmallVector<Attribute> newPositions = llvm::to_vector(
167  insertOp.getPosition().getValue().take_back(newPosRank));
168  if (newPosRank > oldPosRank) {
169  auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
170  newPositions.resize(newPosRank, zeroAttr);
171  }
172 
173  auto newInsertOp = rewriter.create<vector::InsertOp>(
174  loc, newDstType, newSrcVector, newDstVector,
175  rewriter.getArrayAttr(newPositions));
176 
177  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178  newInsertOp);
179 
180  return success();
181  }
182 };
183 
184 // Turns vector.transfer_read on vector with leading 1 dimensions into
185 // vector.shape_cast followed by vector.transfer_read on vector without leading
186 // 1 dimensions.
187 struct CastAwayTransferReadLeadingOneDim
188  : public OpRewritePattern<vector::TransferReadOp> {
190 
191  LogicalResult matchAndRewrite(vector::TransferReadOp read,
192  PatternRewriter &rewriter) const override {
193  // TODO: support 0-d corner case.
194  if (read.getTransferRank() == 0)
195  return failure();
196 
197  if (read.getMask())
198  return failure();
199 
200  auto shapedType = read.getSource().getType().cast<ShapedType>();
201  if (shapedType.getElementType() != read.getVectorType().getElementType())
202  return failure();
203 
204  VectorType oldType = read.getVectorType();
205  VectorType newType = trimLeadingOneDims(oldType);
206 
207  if (newType == oldType)
208  return failure();
209 
210  AffineMap oldMap = read.getPermutationMap();
211  ArrayRef<AffineExpr> newResults =
212  oldMap.getResults().take_back(newType.getRank());
213  AffineMap newMap =
214  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
215  rewriter.getContext());
216 
217  ArrayAttr inBoundsAttr;
218  if (read.getInBounds())
219  inBoundsAttr = rewriter.getArrayAttr(
220  read.getInBoundsAttr().getValue().take_back(newType.getRank()));
221 
222  auto newRead = rewriter.create<vector::TransferReadOp>(
223  read.getLoc(), newType, read.getSource(), read.getIndices(),
224  AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
225  inBoundsAttr);
226  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
227 
228  return success();
229  }
230 };
231 
232 // Turns vector.transfer_write on vector with leading 1 dimensions into
233 // vector.shape_cast followed by vector.transfer_write on vector without leading
234 // 1 dimensions.
235 struct CastAwayTransferWriteLeadingOneDim
236  : public OpRewritePattern<vector::TransferWriteOp> {
238 
239  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
240  PatternRewriter &rewriter) const override {
241  // TODO: support 0-d corner case.
242  if (write.getTransferRank() == 0)
243  return failure();
244 
245  if (write.getMask())
246  return failure();
247 
248  auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
249  if (shapedType.getElementType() != write.getVectorType().getElementType())
250  return failure();
251 
252  VectorType oldType = write.getVectorType();
253  VectorType newType = trimLeadingOneDims(oldType);
254  if (newType == oldType)
255  return failure();
256  int64_t dropDim = oldType.getRank() - newType.getRank();
257 
258  AffineMap oldMap = write.getPermutationMap();
259  ArrayRef<AffineExpr> newResults =
260  oldMap.getResults().take_back(newType.getRank());
261  AffineMap newMap =
262  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
263  rewriter.getContext());
264 
265  ArrayAttr inBoundsAttr;
266  if (write.getInBounds())
267  inBoundsAttr = rewriter.getArrayAttr(
268  write.getInBoundsAttr().getValue().take_back(newType.getRank()));
269 
270  auto newVector = rewriter.create<vector::ExtractOp>(
271  write.getLoc(), write.getVector(), splatZero(dropDim));
272  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
273  write, newVector, write.getSource(), write.getIndices(),
274  AffineMapAttr::get(newMap), inBoundsAttr);
275 
276  return success();
277  }
278 };
279 
280 /// Turns vector.contract on vector with leading 1 dimensions into
281 /// vector.extract followed by vector.contract on vector without leading
282 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
283 /// prior to extract.
284 struct CastAwayContractionLeadingOneDim
285  : public OpRewritePattern<vector::ContractionOp> {
287 
288  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289  PatternRewriter &rewriter) const override {
290  VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
291  if (oldAccType == nullptr)
292  return failure();
293  if (oldAccType.getRank() < 2)
294  return failure();
295  // TODO: implement masks.
296  if (!contractOp.getMasks().empty())
297  return failure();
298  if (oldAccType.getShape()[0] != 1)
299  return failure();
300  // currently we support only dropping one dim but the pattern can be applied
301  // greedily to drop more.
302  int64_t dropDim = 1;
303 
304  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
305  SmallVector<AffineMap> newIndexingMaps;
306 
307  auto oldIteratorTypes = contractOp.getIteratorTypes();
308  SmallVector<Attribute> newIteratorTypes;
309 
310  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
311 
312  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
313  // only parallel type iterators can be dropped.
314  return failure();
315 
316  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
317  int64_t currDim = it.index();
318  if (currDim == dimToDrop)
319  continue;
320  newIteratorTypes.push_back(it.value());
321  }
322 
323  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
324  contractOp.getAcc()};
325  SmallVector<Value> newOperands;
326 
327  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
328  // Check if the dim to be dropped exists as a leading dim in the operand
329  // if it does then we use vector.extract to drop it.
330  bool validExtract = false;
331  SmallVector<AffineExpr> results;
332  auto map = it.value();
333  int64_t orginalZeroDim = it.value().getDimPosition(0);
334  if (orginalZeroDim != dimToDrop) {
335  // There are two reasons to be in this path, 1. We need to
336  // tranpose the operand to make the dim to be dropped
337  // leading. 2. The dim to be dropped does not exist and in
338  // that case we dont want to add a unit tranpose but we must
339  // check all the indices to make sure this is the case.
340  bool tranposeNeeded = false;
342  SmallVector<AffineExpr> transposeResults;
343 
344  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
345  int64_t currDim = map.getDimPosition(i);
346  if (currDim == dimToDrop) {
347  tranposeNeeded = true;
348  perm.insert(perm.begin(), i);
349  auto targetExpr = rewriter.getAffineDimExpr(currDim);
350  transposeResults.insert(transposeResults.begin(), targetExpr);
351  } else {
352  perm.push_back(i);
353  auto targetExpr = rewriter.getAffineDimExpr(currDim);
354  transposeResults.push_back(targetExpr);
355  }
356  }
357  // Do the tranpose now if needed so that we can drop the
358  // correct dim using extract later.
359  if (tranposeNeeded) {
360  map = AffineMap::get(map.getNumDims(), 0, transposeResults,
361  contractOp.getContext());
362  operands[it.index()] = rewriter.create<vector::TransposeOp>(
363  contractOp.getLoc(), operands[it.index()], perm);
364  }
365  }
366  // We have taken care to have the dim to be dropped be
367  // the leading dim. If its still not leading that means it
368  // does not exist in this operand and hence we do not need
369  // an extract.
370  if (map.getDimPosition(0) == dimToDrop)
371  validExtract = true;
372 
373  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
374  int64_t currDim = map.getDimPosition(i);
375  if (currDim == dimToDrop)
376  // This is the dim we are dropping.
377  continue;
378  auto targetExpr = rewriter.getAffineDimExpr(
379  currDim < dimToDrop ? currDim : currDim - 1);
380  results.push_back(targetExpr);
381  }
382  newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
383  contractOp.getContext()));
384  // Extract if its a valid extraction, otherwise use the operand
385  // without extraction.
386  newOperands.push_back(validExtract
387  ? rewriter.create<vector::ExtractOp>(
388  contractOp.getLoc(), operands[it.index()],
389  splatZero(dropDim))
390  : operands[it.index()]);
391  }
392  auto newContractOp = rewriter.create<vector::ContractionOp>(
393  contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
394  rewriter.getAffineMapArrayAttr(newIndexingMaps),
395  rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
396  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
397  contractOp, contractOp->getResultTypes()[0], newContractOp);
398  return success();
399  }
400 };
401 
402 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
403 public:
404  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
405  PatternBenefit benefit = 1)
406  : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
407 
408  LogicalResult matchAndRewrite(Operation *op,
409  PatternRewriter &rewriter) const override {
411  return failure();
412  auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
413  if (!vecType)
414  return failure();
415  VectorType newVecType = trimLeadingOneDims(vecType);
416  if (newVecType == vecType)
417  return failure();
418  int64_t dropDim = vecType.getRank() - newVecType.getRank();
419  SmallVector<Value, 4> newOperands;
420  for (Value operand : op->getOperands()) {
421  if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
422  newOperands.push_back(rewriter.create<vector::ExtractOp>(
423  op->getLoc(), operand, splatZero(dropDim)));
424  } else {
425  newOperands.push_back(operand);
426  }
427  }
428  Operation *newOp =
429  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
430  newOperands, newVecType, op->getAttrs());
431  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
432  newOp->getResult(0));
433  return success();
434  }
435 };
436 
437 } // namespace
438 
440  RewritePatternSet &patterns, PatternBenefit benefit) {
441  patterns
442  .add<CastAwayExtractStridedSliceLeadingOneDim,
443  CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
444  CastAwayTransferReadLeadingOneDim,
445  CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
446  CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
447  populateShapeCastFoldingPatterns(patterns, benefit);
448 }
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:43
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:328
unsigned getNumDims() const
Definition: AffineMap.cpp:324
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:337
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:212
IntegerType getI64Type()
Definition: Builders.cpp:82
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:339
MLIRContext * getContext() const
Definition: Builders.h:55
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:318
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:259
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:312
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:418
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
result_type_range getResultTypes()
Definition: Operation.h:407
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast() const
Definition: Types.h:311
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1147
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:165
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361