MLIR  20.0.0git
DecomposeGenericByUnfoldingPermutation.cpp
Go to the documentation of this file.
1 //===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
12 #include <map>
13 #include <optional>
14 #include <utility>
15 
16 using namespace mlir;
17 using namespace mlir::linalg;
18 
19 namespace {
20 
21 /// This pattern decomposes the input operand(s) of a linalg.generic that has
22 /// a `transpose`, `broadcast`, or a mixture of two, into explicit transpose
23 /// and broadcast. Having them folded into the linalg.generic is a good
24 /// optimization but sometimes we may want to unwrap, i.e., `unfold` them as
25 /// explicit transpose and broadcast. This rewrite pattern helps do it for
26 /// each input operand. This is useful for instance when trying to recognize
27 /// named ops.
28 ///
29 /// The transpose, broadcast, or mixture of both, are expressed in the affine
30 /// map of the operand. Technically it is essentially `projected permutation`.
31 ///
32 /// Example
33 ///
34 /// ```mlir
35 ///
36 /// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
37 /// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
38 /// ...
39 /// %res = linalg.generic
40 /// { indexing_maps = [#projection, #identity, #identity],
41 /// iterator_types = ["parallel", "parallel", "parallel",
42 /// "parallel", "parallel"]}
43 /// ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
44 /// outs(%z : tensor<5x9x7x8x10xf32>) {
45 /// ^bb0(%in: f32, %in_1: f32, %out: f32):
46 /// %div = arith.divf %in, %in_1 : f32
47 /// linalg.yield %div : f32
48 /// } -> tensor<5x9x7x8x10xf32>
49 /// ```
50 ///
51 /// In the above IR operand `%x` map is a projected-permutation. This can be
52 /// unfolded as:
53 ///
54 /// ```mlir
55 /// ...
56 /// %x_trans = linalg.transpose
57 /// ins(%x : tensor<7x8x9xf32>)
58 /// outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
59 /// ...
60 /// %x_trans_bc = linalg.broadcast
61 /// ins(%x_trans : tensor<9x7x8xf32>)
62 /// outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
63 /// %2 = linalg.div
64 /// ins(%x_trans_bc, %y :
65 /// tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
66 /// outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
67 ///
68 /// Note that linalg.generic has been 'specialized' to linalg.div.
69 ///
70 /// To unfold it, it is more optimal to transpose first and then do the
71 /// broadcast. However, if transpose is done first, the permutation map needs
72 /// to be expressed in terms of reduced dimension as broadcast hasn't happened
73 /// yet. Also, the broadcast dimensions in a linalg.generic come from other
74 /// operands (those not broadcasted along that particular dimension). We work
75 /// this out by computing the convex-polyhedron shape of the linalg.generic
76 /// iteration space from shapes of all the operands, both inputs and outputs.
77 ///
78 struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
80 
81  LogicalResult matchAndRewrite(GenericOp genericOp,
82  PatternRewriter &rewriter) const override;
83 };
84 
85 /// For the given `map`, determine what dimensions are transposed and what
86 /// dimensions are broadcasted.
87 /// Returns :
88 /// transpose-permutation, broadcast-dimensions` (empty if not needed)
89 ///
90 std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
91 computeTransposeBroadcast(AffineMap &map) {
92  assert(map.isProjectedPermutation(false) && "not a projection");
93 
94  // As the map is a projection it likely operates on a smaller set of
95  // dimensions as far as the transpose is concerned (rest are broadcast).
96  int64_t minorSize = map.getNumResults();
97 
98  SmallVector<int64_t> minorResult;
99  for (int64_t i = 0; i < minorSize; ++i) {
100  auto expr = cast<AffineDimExpr>(map.getResults()[i]);
101  minorResult.push_back(expr.getPosition());
102  }
103 
104  // If dims are not monotonically increasing then transpose is present.
105  SmallVector<int64_t> sortedResMap(minorResult);
106  std::sort(sortedResMap.begin(), sortedResMap.end());
107  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
108  sortedResMap.begin(), sortedResMap.end());
109 
110  // Walk the sorted map result to determine which dimensions are broadcasted.
112  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
113  if (j < minorSize && sortedResMap[j] == i) {
114  j++;
115  continue;
116  }
117  broadcast.push_back(i);
118  }
119 
120  SmallVector<int64_t> permutation;
121  if (hasTranspose) {
122  // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
123  // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
124  // `x`s access is both transposed and broadcast. But when specifying
125  // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
126  // specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
127  // refering to d3, d4. Therefore, re-base the transpose dimensions so
128  // that they start from d0.
129  permutation.resize(minorSize);
130  std::map<int64_t, int64_t> minorMap;
131  for (int64_t i = 0; i < minorSize; ++i)
132  minorMap.insert({sortedResMap[i], i});
133 
134  // Re-map the dimensions.
135  SmallVector<int64_t> remappedResult(minorSize);
136  for (int64_t i = 0; i < minorSize; ++i)
137  remappedResult[i] = minorMap[minorResult[i]];
138 
139  /// Calculate the permutation for the transpose.
140  for (unsigned i = 0; i < minorSize; ++i) {
141  permutation[remappedResult[i]] = i;
142  }
143  }
144  return {permutation, broadcast};
145 }
146 
147 LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
148  GenericOp op, PatternRewriter &rewriter) const {
149  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
150  op.isSingleYieldOp() || !op.isAllParallelLoops())
151  return failure();
152 
153  // If the map of an operand is not a `projected permutation` then
154  // it cannot be decomposed to mere transpose and broadcast.
155  // The requirement that all maps be `projected permutation` may be
156  // over-restrictive but since we need to determine shape of the
157  // iteration space as well, reject if any map violates assumption.
158  for (auto &opOperand : op->getOpOperands()) {
159  auto map = op.getMatchingIndexingMap(&opOperand);
160  if (!map.isProjectedPermutation(false))
161  return failure();
162  }
163 
164  // Decomposing linalg.generic involves creating `tensor.empty`
165  // which can have dynamic shapes but then we would have to work
166  // out which operand can supply that runtime-value (tensor.dim).
167  // Leaving it as a future TODO.
168  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
169  auto opType = cast<RankedTensorType>(oper.get().getType());
170  return ShapedType::isDynamicShape(opType.getShape());
171  }))
172  return failure();
173 
174  auto outputShape = op.getStaticLoopRanges();
175 
176  auto loc = op.getLoc();
177  bool isChanged = false;
178  SmallVector<Value> newInitValues = op.getDpsInputs();
179  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
180 
181  // Walk over each input operand and unfold if it is transposed, broadcast
182  // or mix of two via operand's affine-map.
183  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
184  auto &map = newMap[i];
185  auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
186  auto elType = inputRTType.getElementType();
187 
188  /// Nothing to do if map is already an identity.
189  if (map.isIdentity())
190  continue;
191 
192  auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
193 
194  // Does it need transpose?
195  if (!permutation.empty()) {
196  /// linalg.transpose permutes the dimensions of input using
197  /// rule: dim(result, i) = dim(input, permutation[i])
198  SmallVector<int64_t> transposedShape(map.getNumResults());
199  for (int64_t i = 0; i < map.getNumResults(); ++i)
200  transposedShape[i] = inputRTType.getShape()[permutation[i]];
201 
202  Value emptyTensor =
203  rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
204 
205  auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
206  emptyTensor, permutation);
207  newInitValues[i] = transposeOp->getResult(0);
208  isChanged = true;
209  }
210 
211  // Does it require broadcast?
212  if (!broadcastedDims.empty()) {
213  assert(broadcastedDims.size() && "should have non size broadcast");
214  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
215  loc, outputShape, inputRTType.getElementType());
216 
217  auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
218  loc, newInitValues[i], emptyTensor, broadcastedDims);
219 
220  newInitValues[i] = broadcastOp->getResult(0);
221  isChanged = true;
222  }
223  newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
224  }
225 
226  if (isChanged) {
227  SmallVector<Value> operands = op->getOperands();
228  ValueRange operandsRef(operands);
229 
230  auto newOp = rewriter.create<linalg::GenericOp>(
231  /*location=*/op.getLoc(),
232  /*resultTensorTypes=*/op->getResultTypes(),
233  /*inputs=*/newInitValues,
234  /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
235  /*indexingMaps=*/newMap,
236  /*iteratorTypes=*/op.getIteratorTypesArray());
237 
238  newOp.getRegion().takeBody(op->getRegion(0));
239  rewriter.replaceOp(op, newOp->getResults());
240  }
241  return success();
242 }
243 
244 } // namespace
245 
248  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
249 }
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:345
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an operand of an operation.
Definition: Value.h:267
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transposes in the input operands of a genericOp.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.