MLIR  14.0.0git
VectorDropLeadUnitDim.cpp
Go to the documentation of this file.
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #define DEBUG_TYPE "vector-drop-unit-dim"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 // Trims leading one dimensions from `oldType` and returns the result type.
21 // Returns `vector<1xT>` if `oldType` only has one element.
22 static VectorType trimLeadingOneDims(VectorType oldType) {
23  ArrayRef<int64_t> oldShape = oldType.getShape();
24  ArrayRef<int64_t> newShape =
25  oldShape.drop_while([](int64_t dim) { return dim == 1; });
26  // Make sure we have at least 1 dimension per vector type requirements.
27  if (newShape.empty())
28  newShape = oldShape.take_back();
29  return VectorType::get(newShape, oldType.getElementType());
30 }
31 
32 /// Return a smallVector of size `rank` containing all zeros.
33 static SmallVector<int64_t> splatZero(int64_t rank) {
34  return SmallVector<int64_t>(rank, 0);
35 }
36 namespace {
37 
38 // Casts away leading one dimensions in vector.extract_strided_slice's vector
39 // input by inserting vector.shape_cast.
40 struct CastAwayExtractStridedSliceLeadingOneDim
41  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
43 
44  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
45  PatternRewriter &rewriter) const override {
46  // vector.extract_strided_slice requires the input and output vector to have
47  // the same rank. Here we drop leading one dimensions from the input vector
48  // type to make sure we don't cause mismatch.
49  VectorType oldSrcType = extractOp.getVectorType();
50  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
51 
52  if (newSrcType.getRank() == oldSrcType.getRank())
53  return failure();
54 
55  int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
56 
57  VectorType oldDstType = extractOp.getType();
58  VectorType newDstType =
59  VectorType::get(oldDstType.getShape().drop_front(dropCount),
60  oldDstType.getElementType());
61 
62  Location loc = extractOp.getLoc();
63 
64  Value newSrcVector = rewriter.create<vector::ExtractOp>(
65  loc, extractOp.vector(), splatZero(dropCount));
66 
67  // The offsets/sizes/strides attribute can have a less number of elements
68  // than the input vector's rank: it is meant for the leading dimensions.
69  auto newOffsets = rewriter.getArrayAttr(
70  extractOp.offsets().getValue().drop_front(dropCount));
71  auto newSizes = rewriter.getArrayAttr(
72  extractOp.sizes().getValue().drop_front(dropCount));
73  auto newStrides = rewriter.getArrayAttr(
74  extractOp.strides().getValue().drop_front(dropCount));
75 
76  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
77  loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
78 
79  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
80  newExtractOp);
81 
82  return success();
83  }
84 };
85 
86 // Casts away leading one dimensions in vector.extract_strided_slice's vector
87 // inputs by inserting vector.shape_cast.
88 struct CastAwayInsertStridedSliceLeadingOneDim
89  : public OpRewritePattern<vector::InsertStridedSliceOp> {
91 
92  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
93  PatternRewriter &rewriter) const override {
94  VectorType oldSrcType = insertOp.getSourceVectorType();
95  VectorType newSrcType = trimLeadingOneDims(oldSrcType);
96  VectorType oldDstType = insertOp.getDestVectorType();
97  VectorType newDstType = trimLeadingOneDims(oldDstType);
98 
99  int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
100  int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
101  if (srcDropCount == 0 && dstDropCount == 0)
102  return failure();
103 
104  // Trim leading one dimensions from both operands.
105  Location loc = insertOp.getLoc();
106 
107  Value newSrcVector = rewriter.create<vector::ExtractOp>(
108  loc, insertOp.source(), splatZero(srcDropCount));
109  Value newDstVector = rewriter.create<vector::ExtractOp>(
110  loc, insertOp.dest(), splatZero(dstDropCount));
111 
112  auto newOffsets = rewriter.getArrayAttr(
113  insertOp.offsets().getValue().take_back(newDstType.getRank()));
114  auto newStrides = rewriter.getArrayAttr(
115  insertOp.strides().getValue().take_back(newSrcType.getRank()));
116 
117  auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
118  loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
119 
120  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
121  newInsertOp);
122 
123  return success();
124  }
125 };
126 
127 // Turns vector.transfer_read on vector with leading 1 dimensions into
128 // vector.shape_cast followed by vector.transfer_read on vector without leading
129 // 1 dimensions.
130 struct CastAwayTransferReadLeadingOneDim
131  : public OpRewritePattern<vector::TransferReadOp> {
133 
134  LogicalResult matchAndRewrite(vector::TransferReadOp read,
135  PatternRewriter &rewriter) const override {
136  // TODO: support 0-d corner case.
137  if (read.getTransferRank() == 0)
138  return failure();
139 
140  if (read.mask())
141  return failure();
142 
143  auto shapedType = read.source().getType().cast<ShapedType>();
144  if (shapedType.getElementType() != read.getVectorType().getElementType())
145  return failure();
146 
147  VectorType oldType = read.getVectorType();
148  VectorType newType = trimLeadingOneDims(oldType);
149 
150  if (newType == oldType)
151  return failure();
152 
153  AffineMap oldMap = read.permutation_map();
154  ArrayRef<AffineExpr> newResults =
155  oldMap.getResults().take_back(newType.getRank());
156  AffineMap newMap =
157  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
158  rewriter.getContext());
159 
160  ArrayAttr inBoundsAttr;
161  if (read.in_bounds())
162  inBoundsAttr = rewriter.getArrayAttr(
163  read.in_boundsAttr().getValue().take_back(newType.getRank()));
164 
165  auto newRead = rewriter.create<vector::TransferReadOp>(
166  read.getLoc(), newType, read.source(), read.indices(),
167  AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
168  inBoundsAttr);
169  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
170 
171  return success();
172  }
173 };
174 
175 // Turns vector.transfer_write on vector with leading 1 dimensions into
176 // vector.shape_cast followed by vector.transfer_write on vector without leading
177 // 1 dimensions.
178 struct CastAwayTransferWriteLeadingOneDim
179  : public OpRewritePattern<vector::TransferWriteOp> {
181 
182  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
183  PatternRewriter &rewriter) const override {
184  // TODO: support 0-d corner case.
185  if (write.getTransferRank() == 0)
186  return failure();
187 
188  if (write.mask())
189  return failure();
190 
191  auto shapedType = write.source().getType().dyn_cast<ShapedType>();
192  if (shapedType.getElementType() != write.getVectorType().getElementType())
193  return failure();
194 
195  VectorType oldType = write.getVectorType();
196  VectorType newType = trimLeadingOneDims(oldType);
197  if (newType == oldType)
198  return failure();
199  int64_t dropDim = oldType.getRank() - newType.getRank();
200 
201  AffineMap oldMap = write.permutation_map();
202  ArrayRef<AffineExpr> newResults =
203  oldMap.getResults().take_back(newType.getRank());
204  AffineMap newMap =
205  AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
206  rewriter.getContext());
207 
208  ArrayAttr inBoundsAttr;
209  if (write.in_bounds())
210  inBoundsAttr = rewriter.getArrayAttr(
211  write.in_boundsAttr().getValue().take_back(newType.getRank()));
212 
213  auto newVector = rewriter.create<vector::ExtractOp>(
214  write.getLoc(), write.vector(), splatZero(dropDim));
215  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
216  write, newVector, write.source(), write.indices(),
217  AffineMapAttr::get(newMap), inBoundsAttr);
218 
219  return success();
220  }
221 };
222 
223 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
224 public:
225  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
226  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
227 
228  LogicalResult matchAndRewrite(Operation *op,
229  PatternRewriter &rewriter) const override {
231  return failure();
232  auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
233  if (!vecType)
234  return failure();
235  VectorType newVecType = trimLeadingOneDims(vecType);
236  if (newVecType == vecType)
237  return failure();
238  int64_t dropDim = vecType.getRank() - newVecType.getRank();
239  SmallVector<Value, 4> newOperands;
240  for (Value operand : op->getOperands()) {
241  if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
242  newOperands.push_back(rewriter.create<vector::ExtractOp>(
243  op->getLoc(), operand, splatZero(dropDim)));
244  } else {
245  newOperands.push_back(operand);
246  }
247  }
248  OperationState state(op->getLoc(), op->getName());
249  state.addAttributes(op->getAttrs());
250  state.addOperands(newOperands);
251  state.addTypes(newVecType);
252  Operation *newOp = rewriter.createOperation(state);
253  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
254  newOp->getResult(0));
255  return success();
256  }
257 };
258 
259 } // namespace
260 
262  RewritePatternSet &patterns) {
263  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
264  CastAwayInsertStridedSliceLeadingOneDim,
265  CastAwayTransferReadLeadingOneDim,
266  CastAwayTransferWriteLeadingOneDim,
267  CastAwayElementwiseLeadingOneDim>(patterns.getContext());
269 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
MLIRContext * getContext() const
Definition: Builders.h:54
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
unsigned getNumSymbols() const
Definition: AffineMap.cpp:298
unsigned getNumDims() const
Definition: AffineMap.cpp:294
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:247
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns)
Collect a set of leading one dimension removal patterns.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1122
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Operation * createOperation(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
This represents an operation in an abstracted form, suitable for use with the builder APIs...
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:38
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:311
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns)
Collect a set of vector.shape_cast folding patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
result_type_range getResultTypes()
Definition: Operation.h:297
MLIRContext * getContext() const
Definition: PatternMatch.h:906