MLIR  16.0.0git
VectorDropLeadUnitDim.cpp
Go to the documentation of this file.
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18 
19 using namespace mlir;
20 using namespace mlir::vector;
21 
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
24 static VectorType trimLeadingOneDims(VectorType oldType) {
25  ArrayRef<int64_t> oldShape = oldType.getShape();
26  ArrayRef<int64_t> newShape =
27  oldShape.drop_while([](int64_t dim) { return dim == 1; });
28  // Make sure we have at least 1 dimension per vector type requirements.
29  if (newShape.empty())
30  newShape = oldShape.take_back();
31  return VectorType::get(newShape, oldType.getElementType());
32 }
33 
34 /// Return a smallVector of size `rank` containing all zeros.
35 static SmallVector<int64_t> splatZero(int64_t rank) {
36  return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39 
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
45 
46  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47  PatternRewriter &rewriter) const override {
48  // vector.extract_strided_slice requires the input and output vector to have
49  // the same rank. Here we drop leading one dimensions from the input vector
50  // type to make sure we don't cause mismatch.
51  VectorType oldSrcType = extractOp.getVectorType();
52  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53 
54  if (newSrcType.getRank() == oldSrcType.getRank())
55  return failure();
56 
57  int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58 
59  VectorType oldDstType = extractOp.getType();
60  VectorType newDstType =
61  VectorType::get(oldDstType.getShape().drop_front(dropCount),
62  oldDstType.getElementType());
63 
64  Location loc = extractOp.getLoc();
65 
66  Value newSrcVector = rewriter.create<vector::ExtractOp>(
67  loc, extractOp.getVector(), splatZero(dropCount));
68 
69  // The offsets/sizes/strides attribute can have a less number of elements
70  // than the input vector's rank: it is meant for the leading dimensions.
71  auto newOffsets = rewriter.getArrayAttr(
72  extractOp.getOffsets().getValue().drop_front(dropCount));
73  auto newSizes = rewriter.getArrayAttr(
74  extractOp.getSizes().getValue().drop_front(dropCount));
75  auto newStrides = rewriter.getArrayAttr(
76  extractOp.getStrides().getValue().drop_front(dropCount));
77 
78  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79  loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80 
81  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82  newExtractOp);
83 
84  return success();
85  }
86 };
87 
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91  : public OpRewritePattern<vector::InsertStridedSliceOp> {
93 
94  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95  PatternRewriter &rewriter) const override {
96  VectorType oldSrcType = insertOp.getSourceVectorType();
97  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98  VectorType oldDstType = insertOp.getDestVectorType();
99  VectorType newDstType = trimLeadingOneDims(oldDstType);
100 
101  int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103  if (srcDropCount == 0 && dstDropCount == 0)
104  return failure();
105 
106  // Trim leading one dimensions from both operands.
107  Location loc = insertOp.getLoc();
108 
109  Value newSrcVector = rewriter.create<vector::ExtractOp>(
110  loc, insertOp.getSource(), splatZero(srcDropCount));
111  Value newDstVector = rewriter.create<vector::ExtractOp>(
112  loc, insertOp.getDest(), splatZero(dstDropCount));
113 
114  auto newOffsets = rewriter.getArrayAttr(
115  insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116  auto newStrides = rewriter.getArrayAttr(
117  insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118 
119  auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120  loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121 
122  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123  newInsertOp);
124 
125  return success();
126  }
127 };
128 
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
133 
134  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135  PatternRewriter &rewriter) const override {
136  Type oldSrcType = insertOp.getSourceType();
137  Type newSrcType = oldSrcType;
138  int64_t oldSrcRank = 0, newSrcRank = 0;
139  if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140  newSrcType = trimLeadingOneDims(type);
141  oldSrcRank = type.getRank();
142  newSrcRank = newSrcType.cast<VectorType>().getRank();
143  }
144 
145  VectorType oldDstType = insertOp.getDestVectorType();
146  VectorType newDstType = trimLeadingOneDims(oldDstType);
147 
148  int64_t srcDropCount = oldSrcRank - newSrcRank;
149  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150  if (srcDropCount == 0 && dstDropCount == 0)
151  return failure();
152 
153  // Trim leading one dimensions from both operands.
154  Location loc = insertOp.getLoc();
155 
156  Value newSrcVector = insertOp.getSource();
157  if (oldSrcRank != 0) {
158  newSrcVector = rewriter.create<vector::ExtractOp>(
159  loc, insertOp.getSource(), splatZero(srcDropCount));
160  }
161  Value newDstVector = rewriter.create<vector::ExtractOp>(
162  loc, insertOp.getDest(), splatZero(dstDropCount));
163 
164  unsigned oldPosRank = insertOp.getPosition().getValue().size();
165  unsigned newPosRank = newDstType.getRank() - newSrcRank;
166  SmallVector<Attribute> newPositions = llvm::to_vector(
167  insertOp.getPosition().getValue().take_back(newPosRank));
168  if (newPosRank > oldPosRank) {
169  auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
170  newPositions.resize(newPosRank, zeroAttr);
171  }
172 
173  auto newInsertOp = rewriter.create<vector::InsertOp>(
174  loc, newDstType, newSrcVector, newDstVector,
175  rewriter.getArrayAttr(newPositions));
176 
177  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178  newInsertOp);
179 
180  return success();
181  }
182 };
183 
184 // Turns vector.transfer_read on vector with leading 1 dimensions into
185 // vector.shape_cast followed by vector.transfer_read on vector without leading
186 // 1 dimensions.
187 struct CastAwayTransferReadLeadingOneDim
188  : public OpRewritePattern<vector::TransferReadOp> {
190 
191  LogicalResult matchAndRewrite(vector::TransferReadOp read,
192  PatternRewriter &rewriter) const override {
193  // TODO: support 0-d corner case.
194  if (read.getTransferRank() == 0)
195  return failure();
196 
197  if (read.getMask())
198  return failure();
199 
200  auto shapedType = read.getSource().getType().cast<ShapedType>();
201  if (shapedType.getElementType() != read.getVectorType().getElementType())
202  return failure();
203 
204  VectorType oldType = read.getVectorType();
205  VectorType newType = trimLeadingOneDims(oldType);
206 
207  if (newType == oldType)
208  return failure();
209 
210  AffineMap oldMap = read.getPermutationMap();
211  ArrayRef<AffineExpr> newResults =
212  oldMap.getResults().take_back(newType.getRank());
213  AffineMap newMap =
214  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
215  rewriter.getContext());
216 
217  ArrayAttr inBoundsAttr;
218  if (read.getInBounds())
219  inBoundsAttr = rewriter.getArrayAttr(
220  read.getInBoundsAttr().getValue().take_back(newType.getRank()));
221 
222  auto newRead = rewriter.create<vector::TransferReadOp>(
223  read.getLoc(), newType, read.getSource(), read.getIndices(),
224  AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
225  inBoundsAttr);
226  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
227 
228  return success();
229  }
230 };
231 
232 // Turns vector.transfer_write on vector with leading 1 dimensions into
233 // vector.shape_cast followed by vector.transfer_write on vector without leading
234 // 1 dimensions.
235 struct CastAwayTransferWriteLeadingOneDim
236  : public OpRewritePattern<vector::TransferWriteOp> {
238 
239  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
240  PatternRewriter &rewriter) const override {
241  // TODO: support 0-d corner case.
242  if (write.getTransferRank() == 0)
243  return failure();
244 
245  if (write.getMask())
246  return failure();
247 
248  auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
249  if (shapedType.getElementType() != write.getVectorType().getElementType())
250  return failure();
251 
252  VectorType oldType = write.getVectorType();
253  VectorType newType = trimLeadingOneDims(oldType);
254  if (newType == oldType)
255  return failure();
256  int64_t dropDim = oldType.getRank() - newType.getRank();
257 
258  AffineMap oldMap = write.getPermutationMap();
259  ArrayRef<AffineExpr> newResults =
260  oldMap.getResults().take_back(newType.getRank());
261  AffineMap newMap =
262  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
263  rewriter.getContext());
264 
265  ArrayAttr inBoundsAttr;
266  if (write.getInBounds())
267  inBoundsAttr = rewriter.getArrayAttr(
268  write.getInBoundsAttr().getValue().take_back(newType.getRank()));
269 
270  auto newVector = rewriter.create<vector::ExtractOp>(
271  write.getLoc(), write.getVector(), splatZero(dropDim));
272  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
273  write, newVector, write.getSource(), write.getIndices(),
274  AffineMapAttr::get(newMap), inBoundsAttr);
275 
276  return success();
277  }
278 };
279 
280 /// Turns vector.contract on vector with leading 1 dimensions into
281 /// vector.extract followed by vector.contract on vector without leading
282 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
283 /// prior to extract.
284 struct CastAwayContractionLeadingOneDim
285  : public OpRewritePattern<vector::ContractionOp> {
287 
288  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289  PatternRewriter &rewriter) const override {
290  VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
291  if (oldAccType == nullptr)
292  return failure();
293  if (oldAccType.getRank() < 2)
294  return failure();
295  // TODO: implement masks.
296  if (llvm::size(contractOp.getMasks()) != 0)
297  return failure();
298  if (oldAccType.getShape()[0] != 1)
299  return failure();
300  // currently we support only dropping one dim but the pattern can be applied
301  // greedily to drop more.
302  int64_t dropDim = 1;
303 
304  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
305  SmallVector<AffineMap> newIndexingMaps;
306 
307  auto oldIteratorTypes = contractOp.getIteratorTypes();
308  SmallVector<Attribute> newIteratorTypes;
309 
310  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
311 
312  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
313  // only parallel type iterators can be dropped.
314  return failure();
315 
316  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
317  int64_t currDim = it.index();
318  if (currDim == dimToDrop)
319  continue;
320  newIteratorTypes.push_back(it.value());
321  }
322 
323  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
324  contractOp.getAcc()};
325  SmallVector<Value> newOperands;
326 
327  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
328  // Check if the dim to be dropped exists as a leading dim in the operand
329  // if it does then we use vector.extract to drop it.
330  bool validExtract = false;
331  SmallVector<AffineExpr> results;
332  auto map = it.value();
333  int64_t orginalZeroDim = it.value().getDimPosition(0);
334  if (orginalZeroDim != dimToDrop) {
335  // There are two reasons to be in this path, 1. We need to
336  // tranpose the operand to make the dim to be dropped
337  // leading. 2. The dim to be dropped does not exist and in
338  // that case we dont want to add a unit tranpose but we must
339  // check all the indices to make sure this is the case.
340  bool tranposeNeeded = false;
341  SmallVector<int64_t> perm;
342  SmallVector<AffineExpr> transposeResults;
343 
344  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
345  int64_t currDim = map.getDimPosition(i);
346  if (currDim == dimToDrop) {
347  tranposeNeeded = true;
348  perm.insert(perm.begin(), i);
349  auto targetExpr = rewriter.getAffineDimExpr(currDim);
350  transposeResults.insert(transposeResults.begin(), targetExpr);
351  } else {
352  perm.push_back(i);
353  auto targetExpr = rewriter.getAffineDimExpr(currDim);
354  transposeResults.push_back(targetExpr);
355  }
356  }
357  // Do the tranpose now if needed so that we can drop the
358  // correct dim using extract later.
359  if (tranposeNeeded) {
360  map = AffineMap::get(map.getNumDims(), 0, transposeResults,
361  contractOp.getContext());
362  operands[it.index()] = rewriter.create<vector::TransposeOp>(
363  contractOp.getLoc(), operands[it.index()], perm);
364  }
365  }
366  // We have taken care to have the dim to be dropped be
367  // the leading dim. If its still not leading that means it
368  // does not exist in this operand and hence we do not need
369  // an extract.
370  if (map.getDimPosition(0) == dimToDrop)
371  validExtract = true;
372 
373  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
374  int64_t currDim = map.getDimPosition(i);
375  if (currDim == dimToDrop)
376  // This is the dim we are dropping.
377  continue;
378  auto targetExpr = rewriter.getAffineDimExpr(
379  currDim < dimToDrop ? currDim : currDim - 1);
380  results.push_back(targetExpr);
381  }
382  newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
383  contractOp.getContext()));
384  // Extract if its a valid extraction, otherwise use the operand
385  // without extraction.
386  newOperands.push_back(validExtract
387  ? rewriter.create<vector::ExtractOp>(
388  contractOp.getLoc(), operands[it.index()],
389  splatZero(dropDim))
390  : operands[it.index()]);
391  }
392  auto newContractOp = rewriter.create<vector::ContractionOp>(
393  contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
394  rewriter.getAffineMapArrayAttr(newIndexingMaps),
395  rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
396  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
397  contractOp, contractOp->getResultTypes()[0], newContractOp);
398  return success();
399  }
400 };
401 
402 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
403 public:
404  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
405  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
406 
407  LogicalResult matchAndRewrite(Operation *op,
408  PatternRewriter &rewriter) const override {
410  return failure();
411  auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
412  if (!vecType)
413  return failure();
414  VectorType newVecType = trimLeadingOneDims(vecType);
415  if (newVecType == vecType)
416  return failure();
417  int64_t dropDim = vecType.getRank() - newVecType.getRank();
418  SmallVector<Value, 4> newOperands;
419  for (Value operand : op->getOperands()) {
420  if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
421  newOperands.push_back(rewriter.create<vector::ExtractOp>(
422  op->getLoc(), operand, splatZero(dropDim)));
423  } else {
424  newOperands.push_back(operand);
425  }
426  }
427  Operation *newOp =
428  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
429  newOperands, newVecType, op->getAttrs());
430  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
431  newOp->getResult(0));
432  return success();
433  }
434 };
435 
436 } // namespace
437 
439  RewritePatternSet &patterns) {
440  patterns
441  .add<CastAwayExtractStridedSliceLeadingOneDim,
442  CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
443  CastAwayTransferReadLeadingOneDim,
444  CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
445  CastAwayContractionLeadingOneDim>(patterns.getContext());
447 }
Include the generated interface declarations.
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
bool isParallelIterator(Attribute attr)
unsigned getNumSymbols() const
Definition: AffineMap.cpp:298
unsigned getNumDims() const
Definition: AffineMap.cpp:294
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:295
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns)
Collect a set of leading one dimension removal patterns.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1122
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:205
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d)
Helper method to apply dimension ordering permutation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:282
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
U dyn_cast() const
Definition: Types.h:270
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:42
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
IntegerType getI64Type()
Definition: Builders.cpp:56
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns)
Collect a set of vector.shape_cast folding patterns.
static SmallVector< SmallVector< Operation *, 1 > > transposeResults(const SmallVector< SmallVector< Operation *>, 1 > &m)
Helper function: transpose MxN into NxM; assumes that the input is a valid.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:309
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:229
result_type_range getResultTypes()
Definition: Operation.h:345
MLIRContext * getContext() const
U cast() const
Definition: Types.h:278