MLIR  16.0.0git
VectorTransferPermutationMapRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute according to given
24 /// indices.
25 static ArrayAttr
26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27  const SmallVector<unsigned> &permutation) {
28  SmallVector<bool> newInBoundsValues;
29  for (unsigned pos : permutation)
30  newInBoundsValues.push_back(
31  attr.getValue()[pos].cast<BoolAttr>().getValue());
32  return builder.getBoolArrayAttr(newInBoundsValues);
33 }
34 
35 /// Lower transfer_read op with permutation into a transfer_read with a
36 /// permutation map composed of leading zeros followed by a minor identity +
37 /// vector.transpose op.
38 /// Ex:
39 /// vector.transfer_read ...
40 /// permutation_map: (d0, d1, d2) -> (0, d1)
41 /// into:
42 /// %v = vector.transfer_read ...
43 /// permutation_map: (d0, d1, d2) -> (d1, 0)
44 /// vector.transpose %v, [1, 0]
45 ///
46 /// vector.transfer_read ...
47 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
48 /// into:
49 /// %v = vector.transfer_read ...
50 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
51 /// vector.transpose %v, [0, 1, 3, 2, 4]
52 /// Note that an alternative is to transform it to linalg.transpose +
53 /// vector.transfer_read to do the transpose in memory instead.
55  : public OpRewritePattern<vector::TransferReadOp> {
57 
58  LogicalResult matchAndRewrite(vector::TransferReadOp op,
59  PatternRewriter &rewriter) const override {
60  // TODO: support 0-d corner case.
61  if (op.getTransferRank() == 0)
62  return failure();
63 
64  SmallVector<unsigned> permutation;
65  AffineMap map = op.getPermutationMap();
66  if (map.getNumResults() == 0)
67  return failure();
69  return failure();
70  AffineMap permutationMap =
71  map.getPermutationMap(permutation, op.getContext());
72  if (permutationMap.isIdentity())
73  return failure();
74 
75  permutationMap = map.getPermutationMap(permutation, op.getContext());
76  // Caluclate the map of the new read by applying the inverse permutation.
77  permutationMap = inversePermutation(permutationMap);
78  AffineMap newMap = permutationMap.compose(map);
79  // Apply the reverse transpose to deduce the type of the transfer_read.
80  ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
81  SmallVector<int64_t> newVectorShape(originalShape.size());
82  for (const auto &pos : llvm::enumerate(permutation)) {
83  newVectorShape[pos.value()] = originalShape[pos.index()];
84  }
85 
86  // Transpose mask operand.
87  Value newMask;
88  if (op.getMask()) {
89  // Remove unused dims from the permutation map. E.g.:
90  // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
91  // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
92  auto comp = compressUnusedDims(map);
93  // Get positions of remaining result dims.
94  // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
95  // maskTransposeIndices = [ 2, 1, 0]
96  SmallVector<int64_t> maskTransposeIndices;
97  for (unsigned i = 0; i < comp.getNumResults(); ++i) {
98  if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
99  maskTransposeIndices.push_back(expr.getPosition());
100  }
101 
102  newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
103  maskTransposeIndices);
104  }
105 
106  // Transpose in_bounds attribute.
107  ArrayAttr newInBoundsAttr =
108  op.getInBounds() ? transposeInBoundsAttr(
109  rewriter, op.getInBounds().value(), permutation)
110  : ArrayAttr();
111 
112  // Generate new transfer_read operation.
113  VectorType newReadType =
114  VectorType::get(newVectorShape, op.getVectorType().getElementType());
115  Value newRead = rewriter.create<vector::TransferReadOp>(
116  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
117  AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
118 
119  // Transpose result of transfer_read.
120  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
121  rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
122  transposePerm);
123  return success();
124  }
125 };
126 
127 /// Lower transfer_write op with permutation into a transfer_write with a
128 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
129 /// Ex:
130 /// vector.transfer_write %v ...
131 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
132 /// into:
133 /// %tmp = vector.transpose %v, [2, 0, 1]
134 /// vector.transfer_write %tmp ...
135 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
136 ///
137 /// vector.transfer_write %v ...
138 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
139 /// into:
140 /// %tmp = vector.transpose %v, [1, 0]
141 /// %v = vector.transfer_write %tmp ...
142 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
144  : public OpRewritePattern<vector::TransferWriteOp> {
146 
147  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
148  PatternRewriter &rewriter) const override {
149  // TODO: support 0-d corner case.
150  if (op.getTransferRank() == 0)
151  return failure();
152 
153  SmallVector<unsigned> permutation;
154  AffineMap map = op.getPermutationMap();
155  if (map.isMinorIdentity())
156  return failure();
157  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
158  return failure();
159 
160  // Remove unused dims from the permutation map. E.g.:
161  // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
162  // comp = (d0, d1, d2) -> (d2, d0, d1)
163  auto comp = compressUnusedDims(map);
164  // Get positions of remaining result dims.
165  SmallVector<int64_t> indices;
166  llvm::transform(comp.getResults(), std::back_inserter(indices),
167  [](AffineExpr expr) {
168  return expr.dyn_cast<AffineDimExpr>().getPosition();
169  });
170 
171  // Transpose mask operand.
172  Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
173  op.getLoc(), op.getMask(), indices)
174  : Value();
175 
176  // Transpose in_bounds attribute.
177  ArrayAttr newInBoundsAttr =
178  op.getInBounds() ? transposeInBoundsAttr(
179  rewriter, op.getInBounds().value(), permutation)
180  : ArrayAttr();
181 
182  // Generate new transfer_write operation.
183  Value newVec = rewriter.create<vector::TransposeOp>(
184  op.getLoc(), op.getVector(), indices);
185  auto newMap = AffineMap::getMinorIdentityMap(
186  map.getNumDims(), map.getNumResults(), rewriter.getContext());
187  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
188  op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
189  newMask, newInBoundsAttr);
190 
191  return success();
192  }
193 };
194 
195 /// Lower transfer_read op with broadcast in the leading dimensions into
196 /// transfer_read of lower rank + vector.broadcast.
197 /// Ex: vector.transfer_read ...
198 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
199 /// into:
200 /// %v = vector.transfer_read ...
201 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
202 /// vector.broadcast %v
203 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
205 
206  LogicalResult matchAndRewrite(vector::TransferReadOp op,
207  PatternRewriter &rewriter) const override {
208  // TODO: support 0-d corner case.
209  if (op.getTransferRank() == 0)
210  return failure();
211 
212  AffineMap map = op.getPermutationMap();
213  unsigned numLeadingBroadcast = 0;
214  for (auto expr : map.getResults()) {
215  auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
216  if (!dimExpr || dimExpr.getValue() != 0)
217  break;
218  numLeadingBroadcast++;
219  }
220  // If there are no leading zeros in the map there is nothing to do.
221  if (numLeadingBroadcast == 0)
222  return failure();
223  VectorType originalVecType = op.getVectorType();
224  unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
225  // Calculate new map, vector type and masks without the leading zeros.
226  AffineMap newMap = AffineMap::get(
227  map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
228  op.getContext());
229  // Only remove the leading zeros if the rest of the map is a minor identity
230  // with broadasting. Otherwise we first want to permute the map.
231  if (!newMap.isMinorIdentityWithBroadcasting())
232  return failure();
233 
234  // TODO: support zero-dimension vectors natively. See:
235  // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
236  // In the meantime, lower these to a scalar load when they pop up.
237  if (reducedShapeRank == 0) {
238  Value newRead;
239  if (op.getShapedType().isa<TensorType>()) {
240  newRead = rewriter.create<tensor::ExtractOp>(
241  op.getLoc(), op.getSource(), op.getIndices());
242  } else {
243  newRead = rewriter.create<memref::LoadOp>(
244  op.getLoc(), originalVecType.getElementType(), op.getSource(),
245  op.getIndices());
246  }
247  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
248  newRead);
249  return success();
250  }
251  SmallVector<int64_t> newShape = llvm::to_vector<4>(
252  originalVecType.getShape().take_back(reducedShapeRank));
253  // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
254  if (newShape.empty())
255  return failure();
256  VectorType newReadType =
257  VectorType::get(newShape, originalVecType.getElementType());
258  ArrayAttr newInBoundsAttr =
259  op.getInBounds()
260  ? rewriter.getArrayAttr(
261  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
262  : ArrayAttr();
263  Value newRead = rewriter.create<vector::TransferReadOp>(
264  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
265  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
266  newInBoundsAttr);
267  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
268  newRead);
269  return success();
270  }
271 };
272 
274  RewritePatternSet &patterns) {
277  patterns.getContext());
278 }
Include the generated interface declarations.
LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:653
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:439
unsigned getNumDims() const
Definition: AffineMap.cpp:294
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
Definition: AffineMap.cpp:117
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:205
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:562
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:109
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Lower transfer_read op with broadcast in the leading dimensions into transfer_read of lower rank + ve...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Lower transfer_write op with permutation into a transfer_write with a minor identity permutation map...
LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Base type for affine expression.
Definition: AffineExpr.h:68
unsigned getNumResults() const
Definition: AffineMap.cpp:302
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:42
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:102
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:233
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:157
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
static ArrayAttr transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute according to given indices.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:255
LogicalResult matchAndRewrite(vector::TransferWriteOp op, PatternRewriter &rewriter) const override
Lower transfer_read op with permutation into a transfer_read with a permutation map composed of leadi...
This class helps build Operations.
Definition: Builders.h:192
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:229
MLIRContext * getContext() const