MLIR  22.0.0git
DecomposeGenericByUnfoldingPermutation.cpp
Go to the documentation of this file.
1 //===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <map>
#include <utility>
13 
14 using namespace mlir;
15 using namespace mlir::linalg;
16 
17 namespace {
18 
19 /// This pattern decomposes the input operand(s) of a linalg.generic that has
20 /// a `transpose`, `broadcast`, or a mixture of two, into explicit transpose
21 /// and broadcast. Having them folded into the linalg.generic is a good
22 /// optimization but sometimes we may want to unwrap, i.e., `unfold` them as
23 /// explicit transpose and broadcast. This rewrite pattern helps do it for
24 /// each input operand. This is useful for instance when trying to recognize
25 /// named ops.
26 ///
27 /// The transpose, broadcast, or mixture of both, are expressed in the affine
28 /// map of the operand. Technically it is essentially `projected permutation`.
29 ///
30 /// Example
31 ///
32 /// ```mlir
33 ///
34 /// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
35 /// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
36 /// ...
37 /// %res = linalg.generic
38 /// { indexing_maps = [#projection, #identity, #identity],
39 /// iterator_types = ["parallel", "parallel", "parallel",
40 /// "parallel", "parallel"]}
41 /// ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
42 /// outs(%z : tensor<5x9x7x8x10xf32>) {
43 /// ^bb0(%in: f32, %in_1: f32, %out: f32):
44 /// %div = arith.divf %in, %in_1 : f32
45 /// linalg.yield %div : f32
46 /// } -> tensor<5x9x7x8x10xf32>
47 /// ```
48 ///
49 /// In the above IR operand `%x` map is a projected-permutation. This can be
50 /// unfolded as:
51 ///
52 /// ```mlir
53 /// ...
54 /// %x_trans = linalg.transpose
55 /// ins(%x : tensor<7x8x9xf32>)
56 /// outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
57 /// ...
58 /// %x_trans_bc = linalg.broadcast
59 /// ins(%x_trans : tensor<9x7x8xf32>)
60 /// outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
61 /// %2 = linalg.div
62 /// ins(%x_trans_bc, %y :
63 /// tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
64 /// outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
65 ///
66 /// Note that linalg.generic has been 'specialized' to linalg.div.
67 ///
68 /// To unfold it, it is more optimal to transpose first and then do the
69 /// broadcast. However, if transpose is done first, the permutation map needs
70 /// to be expressed in terms of reduced dimension as broadcast hasn't happened
71 /// yet. Also, the broadcast dimensions in a linalg.generic come from other
72 /// operands (those not broadcasted along that particular dimension). We work
73 /// this out by computing the convex-polyhedron shape of the linalg.generic
74 /// iteration space from shapes of all the operands, both inputs and outputs.
75 ///
76 struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
78 
79  LogicalResult matchAndRewrite(GenericOp genericOp,
80  PatternRewriter &rewriter) const override;
81 };
82 
83 /// For the given `map`, determine what dimensions are transposed and what
84 /// dimensions are broadcasted.
85 /// Returns :
86 /// `transpose-permutation`, `broadcast-dimensions` (empty if not needed)
87 ///
88 std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
89 computeTransposeBroadcast(AffineMap &map) {
90  assert(map.isProjectedPermutation(false) && "not a projection");
91 
92  // As the map is a projection it likely operates on a smaller set of
93  // dimensions as far as the transpose is concerned (rest are broadcast).
94  int64_t minorSize = map.getNumResults();
95 
96  SmallVector<int64_t> minorResult;
97  for (int64_t i = 0; i < minorSize; ++i) {
98  auto expr = cast<AffineDimExpr>(map.getResults()[i]);
99  minorResult.push_back(expr.getPosition());
100  }
101 
102  // If dims are not monotonically increasing then transpose is present.
103  SmallVector<int64_t> sortedResMap(minorResult);
104  llvm::sort(sortedResMap);
105  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
106  sortedResMap.begin(), sortedResMap.end());
107 
108  // Walk the sorted map result to determine which dimensions are broadcasted.
110  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
111  if (j < minorSize && sortedResMap[j] == i) {
112  j++;
113  continue;
114  }
115  broadcast.push_back(i);
116  }
117 
118  SmallVector<int64_t> permutation;
119  if (hasTranspose) {
120  // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
121  // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
122  // `x`s access is both transposed and broadcast. But when specifying
123  // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
124  // specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
125  // refering to d3, d4. Therefore, re-base the transpose dimensions so
126  // that they start from d0.
127  permutation.resize(minorSize);
128  std::map<int64_t, int64_t> minorMap;
129  for (int64_t i = 0; i < minorSize; ++i)
130  minorMap.insert({sortedResMap[i], i});
131 
132  // Re-map the dimensions.
133  SmallVector<int64_t> remappedResult(minorSize);
134  for (int64_t i = 0; i < minorSize; ++i)
135  remappedResult[i] = minorMap[minorResult[i]];
136 
137  /// Calculate the permutation for the transpose.
138  for (unsigned i = 0; i < minorSize; ++i) {
139  permutation[remappedResult[i]] = i;
140  }
141  }
142  return {permutation, broadcast};
143 }
144 
145 LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
146  GenericOp op, PatternRewriter &rewriter) const {
147  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
148  op.isSingleYieldOp() || !op.isAllParallelLoops())
149  return failure();
150 
151  // If the map of an operand is not a `projected permutation` then
152  // it cannot be decomposed to mere transpose and broadcast.
153  // The requirement that all maps be `projected permutation` may be
154  // over-restrictive but since we need to determine shape of the
155  // iteration space as well, reject if any map violates assumption.
156  for (auto &opOperand : op->getOpOperands()) {
157  auto map = op.getMatchingIndexingMap(&opOperand);
158  if (!map.isProjectedPermutation(false))
159  return failure();
160  }
161 
162  // Decomposing linalg.generic involves creating `tensor.empty`
163  // which can have dynamic shapes but then we would have to work
164  // out which operand can supply that runtime-value (tensor.dim).
165  // Leaving it as a future TODO.
166  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
167  auto opType = cast<RankedTensorType>(oper.get().getType());
168  return ShapedType::isDynamicShape(opType.getShape());
169  }))
170  return failure();
171 
172  auto outputShape = op.getStaticLoopRanges();
173 
174  auto loc = op.getLoc();
175  bool isChanged = false;
176  SmallVector<Value> newInitValues = op.getDpsInputs();
177  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
178 
179  // Walk over each input operand and unfold if it is transposed, broadcast
180  // or mix of two via operand's affine-map.
181  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
182  auto &map = newMap[i];
183  auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
184  auto elType = inputRTType.getElementType();
185 
186  /// Nothing to do if map is already an identity.
187  if (map.isIdentity())
188  continue;
189 
190  auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
191 
192  // Does it need transpose?
193  if (!permutation.empty()) {
194  /// linalg.transpose permutes the dimensions of input using
195  /// rule: dim(result, i) = dim(input, permutation[i])
196  SmallVector<int64_t> transposedShape(map.getNumResults());
197  for (int64_t i = 0; i < map.getNumResults(); ++i)
198  transposedShape[i] = inputRTType.getShape()[permutation[i]];
199 
200  Value emptyTensor =
201  tensor::EmptyOp::create(rewriter, loc, transposedShape, elType);
202 
203  auto transposeOp = TransposeOp::create(rewriter, loc, newInitValues[i],
204  emptyTensor, permutation);
205  newInitValues[i] = transposeOp->getResult(0);
206  isChanged = true;
207  }
208 
209  // Does it require broadcast?
210  if (!broadcastedDims.empty()) {
211  assert(broadcastedDims.size() && "should have non size broadcast");
212  Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape,
213  inputRTType.getElementType());
214 
215  auto broadcastOp = linalg::BroadcastOp::create(
216  rewriter, loc, newInitValues[i], emptyTensor, broadcastedDims);
217 
218  newInitValues[i] = broadcastOp->getResult(0);
219  isChanged = true;
220  }
221  newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
222  }
223 
224  if (!isChanged)
225  return failure();
226 
227  SmallVector<Value> operands = op->getOperands();
228  ValueRange operandsRef(operands);
229 
230  auto newOp = linalg::GenericOp::create(
231  rewriter,
232  /*location=*/op.getLoc(),
233  /*resultTensorTypes=*/op->getResultTypes(),
234  /*inputs=*/newInitValues,
235  /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
236  /*indexingMaps=*/newMap,
237  /*iteratorTypes=*/op.getIteratorTypesArray());
238  newOp.getRegion().takeBody(op->getRegion(0));
239  rewriter.replaceOp(op, newOp->getResults());
240  return success();
241 }
242 
243 } // namespace
244 
247  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
248 }
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:341
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
This class represents an operand of an operation.
Definition: Value.h:257
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.