MLIR  16.0.0git
VectorTransferPermutationMapRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute according to given
24 /// indices.
25 static ArrayAttr
26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27  const SmallVector<unsigned> &permutation) {
28  SmallVector<bool> newInBoundsValues;
29  for (unsigned pos : permutation)
30  newInBoundsValues.push_back(
31  attr.getValue()[pos].cast<BoolAttr>().getValue());
32  return builder.getBoolArrayAttr(newInBoundsValues);
33 }
34 
35 /// Lower transfer_read op with permutation into a transfer_read with a
36 /// permutation map composed of leading zeros followed by a minor identity +
37 /// vector.transpose op.
38 /// Ex:
39 /// vector.transfer_read ...
40 /// permutation_map: (d0, d1, d2) -> (0, d1)
41 /// into:
42 /// %v = vector.transfer_read ...
43 /// permutation_map: (d0, d1, d2) -> (d1, 0)
44 /// vector.transpose %v, [1, 0]
45 ///
46 /// vector.transfer_read ...
47 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
48 /// into:
49 /// %v = vector.transfer_read ...
50 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
51 /// vector.transpose %v, [0, 1, 3, 2, 4]
52 /// Note that an alternative is to transform it to linalg.transpose +
53 /// vector.transfer_read to do the transpose in memory instead.
55  : public OpRewritePattern<vector::TransferReadOp> {
57 
58  LogicalResult matchAndRewrite(vector::TransferReadOp op,
59  PatternRewriter &rewriter) const override {
60  // TODO: support 0-d corner case.
61  if (op.getTransferRank() == 0)
62  return failure();
63 
64  SmallVector<unsigned> permutation;
65  AffineMap map = op.getPermutationMap();
66  if (map.getNumResults() == 0)
67  return failure();
69  return failure();
70  AffineMap permutationMap =
71  map.getPermutationMap(permutation, op.getContext());
72  if (permutationMap.isIdentity())
73  return failure();
74 
75  permutationMap = map.getPermutationMap(permutation, op.getContext());
76  // Caluclate the map of the new read by applying the inverse permutation.
77  permutationMap = inversePermutation(permutationMap);
78  AffineMap newMap = permutationMap.compose(map);
79  // Apply the reverse transpose to deduce the type of the transfer_read.
80  ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
81  SmallVector<int64_t> newVectorShape(originalShape.size());
82  for (const auto &pos : llvm::enumerate(permutation)) {
83  newVectorShape[pos.value()] = originalShape[pos.index()];
84  }
85 
86  // Transpose in_bounds attribute.
87  ArrayAttr newInBoundsAttr =
88  op.getInBounds() ? transposeInBoundsAttr(
89  rewriter, op.getInBounds().value(), permutation)
90  : ArrayAttr();
91 
92  // Generate new transfer_read operation.
93  VectorType newReadType =
94  VectorType::get(newVectorShape, op.getVectorType().getElementType());
95  Value newRead = rewriter.create<vector::TransferReadOp>(
96  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
97  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
98  newInBoundsAttr);
99 
100  // Transpose result of transfer_read.
101  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
102  rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
103  transposePerm);
104  return success();
105  }
106 };
107 
108 /// Lower transfer_write op with permutation into a transfer_write with a
109 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
110 /// Ex:
111 /// vector.transfer_write %v ...
112 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
113 /// into:
114 /// %tmp = vector.transpose %v, [1, 2, 0]
115 /// vector.transfer_write %tmp ...
116 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
117 ///
118 /// vector.transfer_write %v ...
119 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
120 /// into:
121 /// %tmp = vector.transpose %v, [1, 0]
122 /// %v = vector.transfer_write %tmp ...
123 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
125  : public OpRewritePattern<vector::TransferWriteOp> {
127 
128  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
129  PatternRewriter &rewriter) const override {
130  // TODO: support 0-d corner case.
131  if (op.getTransferRank() == 0)
132  return failure();
133 
134  SmallVector<unsigned> permutation;
135  AffineMap map = op.getPermutationMap();
136  if (map.isMinorIdentity())
137  return failure();
138  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
139  return failure();
140 
141  // Remove unused dims from the permutation map. E.g.:
142  // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
143  // comp = (d0, d1, d2) -> (d2, d0, d1)
144  auto comp = compressUnusedDims(map);
145  // Get positions of remaining result dims.
146  SmallVector<int64_t> indices;
147  llvm::transform(comp.getResults(), std::back_inserter(indices),
148  [](AffineExpr expr) {
149  return expr.dyn_cast<AffineDimExpr>().getPosition();
150  });
151 
152  // Transpose in_bounds attribute.
153  ArrayAttr newInBoundsAttr =
154  op.getInBounds() ? transposeInBoundsAttr(
155  rewriter, op.getInBounds().value(), permutation)
156  : ArrayAttr();
157 
158  // Generate new transfer_write operation.
159  Value newVec = rewriter.create<vector::TransposeOp>(
160  op.getLoc(), op.getVector(), indices);
161  auto newMap = AffineMap::getMinorIdentityMap(
162  map.getNumDims(), map.getNumResults(), rewriter.getContext());
163  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
164  op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
165  op.getMask(), newInBoundsAttr);
166 
167  return success();
168  }
169 };
170 
171 /// Lower transfer_read op with broadcast in the leading dimensions into
172 /// transfer_read of lower rank + vector.broadcast.
173 /// Ex: vector.transfer_read ...
174 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
175 /// into:
176 /// %v = vector.transfer_read ...
177 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
178 /// vector.broadcast %v
179 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
181 
182  LogicalResult matchAndRewrite(vector::TransferReadOp op,
183  PatternRewriter &rewriter) const override {
184  // TODO: support 0-d corner case.
185  if (op.getTransferRank() == 0)
186  return failure();
187 
188  AffineMap map = op.getPermutationMap();
189  unsigned numLeadingBroadcast = 0;
190  for (auto expr : map.getResults()) {
191  auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
192  if (!dimExpr || dimExpr.getValue() != 0)
193  break;
194  numLeadingBroadcast++;
195  }
196  // If there are no leading zeros in the map there is nothing to do.
197  if (numLeadingBroadcast == 0)
198  return failure();
199  VectorType originalVecType = op.getVectorType();
200  unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
201  // Calculate new map, vector type and masks without the leading zeros.
202  AffineMap newMap = AffineMap::get(
203  map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
204  op.getContext());
205  // Only remove the leading zeros if the rest of the map is a minor identity
206  // with broadasting. Otherwise we first want to permute the map.
207  if (!newMap.isMinorIdentityWithBroadcasting())
208  return failure();
209 
210  // TODO: support zero-dimension vectors natively. See:
211  // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
212  // In the meantime, lower these to a scalar load when they pop up.
213  if (reducedShapeRank == 0) {
214  Value newRead;
215  if (op.getShapedType().isa<TensorType>()) {
216  newRead = rewriter.create<tensor::ExtractOp>(
217  op.getLoc(), op.getSource(), op.getIndices());
218  } else {
219  newRead = rewriter.create<memref::LoadOp>(
220  op.getLoc(), originalVecType.getElementType(), op.getSource(),
221  op.getIndices());
222  }
223  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
224  newRead);
225  return success();
226  }
227  SmallVector<int64_t> newShape = llvm::to_vector<4>(
228  originalVecType.getShape().take_back(reducedShapeRank));
229  // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
230  if (newShape.empty())
231  return failure();
232  VectorType newReadType =
233  VectorType::get(newShape, originalVecType.getElementType());
234  ArrayAttr newInBoundsAttr =
235  op.getInBounds()
236  ? rewriter.getArrayAttr(
237  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
238  : ArrayAttr();
239  Value newRead = rewriter.create<vector::TransferReadOp>(
240  op.getLoc(), newReadType, op.getSource(), op.getIndices(),
241  AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
242  newInBoundsAttr);
243  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
244  newRead);
245  return success();
246  }
247 };
248 
250  RewritePatternSet &patterns, PatternBenefit benefit) {
253  patterns.getContext(), benefit);
254 }
static ArrayAttr transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute according to given indices.
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:42
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:103
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:110
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
Definition: AffineMap.cpp:118
unsigned getNumDims() const
Definition: AffineMap.cpp:306
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:319
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:158
unsigned getNumResults() const
Definition: AffineMap.cpp:314
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:206
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:455
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:267
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
MLIRContext * getContext() const
Definition: Builders.h:54
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:247
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:251
This class helps build Operations.
Definition: Builders.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:610
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:78
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:669
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:578
Lower transfer_read op with broadcast in the leading dimensions into transfer_read of lower rank + ve...
LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override
Lower transfer_read op with permutation into a transfer_read with a permutation map composed of leadi...
LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override
Lower transfer_write op with permutation into a transfer_write with a minor identity permutation map.
LogicalResult matchAndRewrite(vector::TransferWriteOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:356
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360