//===- TosaReduceTransposes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// ----------
// Motivation:
// ----------

// Some legalization pathways introduce redundant tosa.TRANSPOSE
// operations that result in avoidable data movement. For example,
// PyTorch -> TOSA contains a lot of unnecessary transposes due
// to conversions between NCHW and NHWC.

// We wish to remove all the ones that we can, since in general
// it is possible to remove the overwhelming majority.

// -------------------
// High-Level Overview:
// -------------------

// The pass works through the transpose operators in the program. It begins at
// some transpose operator with an associated permutation. It traverses
// upwards through the dependencies of this transpose and verifies that we
// encounter only operators with the TosaElementwiseOperator trait and that
// the chains terminate in either constants, reshapes, or transposes.

// We then evaluate whether there are any additional restrictions (the
// transposes it terminates in must invert the one we began at, and the
// reshapes must be ones into which we can fold the transpose), and then we
// hoist the transpose through the intervening operators, folding it at the
// constants, reshapes, and transposes.

// Finally, we ensure that we do not need both the transposed form (the form
// that had the transpose hoisted through it) and the untransposed form (the
// original), by analyzing the usages of the dependent operators of a given
// transpose we are attempting to hoist and replace.

// If the usages are such that both forms would be necessary, then we do not
// replace the hoisted transpose, causing the new chain to be dead. Otherwise,
// we do, and the old chain (untransposed form) becomes dead. Only one chain
// will ever then be live, resulting in no duplication.

// We then perform a simple one-pass DCE, so no canonicalization is necessary.

// -----------
// Future Work:
// -----------

// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
// hoisted transposes with different permutation tensors.

// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
// N -> 1x1x...x1xNx1x...x1x1.

// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
// those that form the identity.

// (4) Add support for more instructions besides TosaElementwiseOperator as
// the intervening ones (for example, the reduce_* operators).

// (5) Support hoisting transposes up to an input parameter.

//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include <memory>
#include <set>
#include <stack>

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//

namespace {

struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This will collect all the data dependencies for the given Operation
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one.
  // %0 = tosa.transpose %arg0 <- applies to this
  // %1 = tosa.transpose %0 <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if the hoisted TransposeOp can be folded into the ReshapeOp. If
  // so, it creates a new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  // %0 = tosa.const
  // %1 = tosa.transpose
  // %2 = tosa.add %0, %1
  // %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOp we should "replace", turning the converted
  // chains of ops, through which they were propagated, "live", and the old
  // code "dead." Attempts to avoid replacements that would leave the old
  // code "live," resulting in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since the
  // verifier guarantees from TOSA are not in place (and otherwise, if used
  // elsewhere, it should fail).
  // This is a basic API and may benefit from a refactor into the core MLIR
  // APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};

std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {
  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
  RankedTensorType newType =
      RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
                            oldType.getElementType());
  size_t rank = oldType.getRank();

  // These are asserted by the TransposeOp verifier and by TOSA disallowing
  // tensors with a dimension of 0. If they do not hold, something is very
  // wrong.
  if (rank <= 0 || oldType.getNumElements() <= 0) {
    signalPassFailure();
    return std::nullopt;
  }

  if (input.isSplat())
    return input.reshape(newType);

  // The algorithm is approximately as follows:
  // input: perms, input flat array, input tensor type
  // (1/2) determine the strides of input/output if they were strided in
  // row-major order.
  // (3) adjust the strides for the input to be in the same order of indices
  // as the output is written.
  // (4) process dimension by dimension.
  // example: perms 2, 0, 1; input 2x3x4; output 4x2x3
  // for i ... 4, j ... 2, k ... 3:
  //   output[i][j][k] = input[j][k][i]
  //   output[6i + 3j + k] = input[12j + 4k + i]
  // and we adjust the input strides to read input[i + 12j + 4k] so we may
  // process layer-by-layer.

  // Step 1/2: Strides for input. We ignore output since it is row-major and
  // we can just push_back.

  SmallVector<int64_t> originalInputStrides(rank);
  originalInputStrides[rank - 1] = 1;
  // index with int64_t to avoid overflow
  for (int64_t i = rank - 2; i >= 0; i--)
    originalInputStrides[i] =
        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);

  // Step 3: Transpose strides of input to be the same indexing (i, j, k, ...)
  // as the output, which is written in row-major order.

  SmallVector<int64_t> newInputStrides;
  newInputStrides.reserve(rank);
  for (int32_t v : perms)
    newInputStrides.push_back(originalInputStrides[v]);

  // Step 4: Write out the transposed "flat array" dimension by dimension.

  auto inputArray = input.getValues<Attribute>();
  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
  for (size_t i = 0; i < rank; i++)
    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});

  SmallVector<Attribute> resultArray;
  resultArray.reserve(inputArray.size());

  std::function<void(int64_t,
                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
      processTransposeDim = [&](auto accumulatedIndex, auto it) {
        if (it == boundsAndStrides.end()) {
          resultArray.push_back(inputArray[accumulatedIndex]);
          return;
        }

        for (int64_t i = 0; i < it->first; i++) {
          int64_t j = accumulatedIndex + i * it->second;
          processTransposeDim(j, it + 1);
        }
      };

  processTransposeDim(0, boundsAndStrides.begin());

  return DenseElementsAttr::get(newType, resultArray);
}
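
// A worked example of the stride walk above (illustrative only): take a
// 2x3 input with flat values [0, 1, 2, 3, 4, 5] and perms = [1, 0].
// originalInputStrides = [3, 1], so newInputStrides = [1, 3], and the
// output shape is 3x2, giving boundsAndStrides = [(3, 1), (2, 3)].
// The recursion then reads indices 0, 3, 1, 4, 2, 5 in order, producing
// the flat array [0, 3, 1, 4, 2, 5], i.e. [[0, 3], [1, 4], [2, 5]],
// which is the expected transpose.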

// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
// as the sources of the data dependencies, and TosaElementWiseOperator
// after that, if the function returns true.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {
  // Can occur if defined through the parameter to a func.func.
  if (!op)
    return false;

  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
    return false;

  // Prevent extra work if already seen.
  if (collected.contains(op))
    return true;

  // Throw it out now so we don't have to deal with it later.
  if (op->getNumResults() != 1 ||
      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
    return false;

  // We don't wish to traverse up a ReshapeOp, since generally we can't
  // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp
  // will have no in-edges in the data dependency graph we construct for
  // the downstream TransposeOp.
  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
      !llvm::isa<tosa::ConstOp>(op)) {

    if (!llvm::isa<tosa::MulOp>(op) &&
        !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
      return false;

    for (Value operand : op->getOperands()) {
      // If this is a problem in future, think about alternatives to recursion.
      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
        // do not recurse into MulOp's shift operand
        continue;
      }
      if (!collectFanIn(operand.getDefiningOp(), collected))
        return false;
    }
  }

  // Insert in topological order.
  collected.insert(op);

  return true;
}
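
// For instance (an illustrative sketch, not from the source), given:
//   %0 = "tosa.const"() ...
//   %1 = tosa.reshape %arg0, ...
//   %2 = tosa.add %0, %1
//   %3 = tosa.transpose %2 {perms = ...}
// calling collectFanIn on the defining op of %2 collects {const, reshape,
// add}: the const and the reshape terminate the upward traversal, and the
// elementwise add is recorded after its operands, in topological order.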

// Assuming that, due to the verification of TransposeOp, perms arrays are
// permutations of 0 to perms.size() - 1.
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;
  int32_t n = perms1.size();
  for (int32_t i = 0; i < n; i++)
    if (perms2[perms1[i]] != i)
      return false;
  return true;
}
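
// For example (illustrative only): perms1 = [1, 2, 0] and perms2 = [2, 0, 1]
// compose to the identity, since perms2[perms1[0]] = perms2[1] = 0,
// perms2[perms1[1]] = perms2[2] = 1, and perms2[perms1[2]] = perms2[0] = 2.
// By contrast, perms1 = perms2 = [1, 2, 0] fails: perms2[perms1[0]] = 2 != 0.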

// Primary overload for those with the TosaElementwiseOperator trait.
// The other ones handle the case of the operations that occur at the
// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    Operation *op, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (op->getNumResults() != 1 ||
      (!llvm::isa<tosa::MulOp>(op) &&
       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
    return std::nullopt;

  auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
  SmallVector<Value, 3> operands;
  for (Value v : op->getOperands()) {
    if (valuesMap.contains(v)) {
      operands.push_back(valuesMap.at(v));
    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
      // special case for MulOp's shift operand
      operands.push_back(v);
    } else {
      return std::nullopt;
    }
  }

  // Conceptually, we propagate the hoisted TransposeOp through
  // these intervening operations. For example,

  // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
  // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //      tensor<3x2xi32>

  // becomes:
  // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //      tensor<3x2xi32>
  // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>

  // We construct this new tosa.clamp here, but it doesn't
  // turn "live" until the transpose being hoisted through this chain
  // is replaced with the proper value from the new chain.

  return rewriter
      .create(op->getLoc(), op->getName().getIdentifier(), operands,
              RankedTensorType::get(
                  applyTOSAPermutation(resultType.getShape(), hoistedPerms),
                  resultType.getElementType()),
              op->getAttrs())
      ->getResult(0);
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
    return std::nullopt;
  return transposeOp.getInput1();
}
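
// E.g. (illustrative sketch): when hoisting perms = [1, 2, 0] and we reach
//   %0 = tosa.transpose %arg0 {perms = [2, 0, 1]}
// the two permutations compose to the identity, so the hoisted transpose
// simply maps to %arg0 and both transposes disappear from the live chain.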

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto reshapeOutput = reshapeOp.getOutput();
  auto reshapeInputType =
      llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
  // want reshape N -> 1x1x...x1xNx1x...x1x1
  if (!reshapeInputType || reshapeInputType.getShape().size() != 1)
    return std::nullopt;
  auto reshapeInputShape = reshapeInputType.getShape();
  auto reshapeOutputType =
      llvm::cast<RankedTensorType>(reshapeOutput.getType());

  // Instead of inserting a TransposeOp here, we check if we can fold it into
  // the ReshapeOp. There are more complex cases where this is possible, and
  // this check can be extended.

  // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
  auto shape = reshapeOutputType.getShape();
  size_t ones = llvm::count(shape, 1);
  // N == 1 and N != 1
  if (ones != shape.size() - 1 &&
      !(ones == shape.size() && reshapeInputShape[0] == 1))
    return std::nullopt;

  // Do not insert a TransposeOp; instead we fold the reshape and its
  // attribute.
  llvm::SmallVector<int64_t> newShape;
  if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
                                 newShape)) {
    // this means the shape is not constant
    return std::nullopt;
  }
  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
  auto foldedReshape = rewriter.create<ReshapeOp>(
      reshapeOp.getLoc(),
      RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                            reshapeOutputType.getElementType()),
      reshapeOp.getInput1(),
      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
                                                      hoistedPerms)));
  return foldedReshape->getResult(0);
}
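
// As an illustrative sketch (not from the source) of the fold above:
//   %0 = tosa.reshape %arg0, ... : (tensor<8xi32>, ...) -> tensor<1x8x1xi32>
//   %1 = tosa.transpose %0 {perms = [2, 0, 1]} : ... -> tensor<1x1x8xi32>
// Since at most one dimension is non-unit, permuting the 1x8x1 output is
// itself just another reshape of the same flat data, so the pair folds to:
//   %0 = tosa.reshape %arg0, ... : (tensor<8xi32>, ...) -> tensor<1x1x8xi32>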

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
  if (!denseAttr)
    return std::nullopt;
  auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
  if (!maybeNewDenseAttr.has_value())
    return std::nullopt;
  auto newDenseAttr = maybeNewDenseAttr.value();
  auto newConstOp = rewriter.create<ConstOp>(
      constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
  return newConstOp->getResult(0);
}
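
// Illustrative sketch (not from the source): the constant is materialized
// pre-transposed, so no transpose of it remains at runtime:
//   %0 = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5]]>} ...
//   %1 = tosa.transpose %0 {perms = [1, 0]}
// is rebuilt as:
//   %0 = "tosa.const"() {values = dense<[[0, 3], [1, 4], [2, 5]]>} ...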

bool TosaReduceTransposes::convertDependentOps(
    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {

  for (Operation *op : dependentOps) {
    if (!op || op->getNumResults() != 1)
      return false;

    Value priorValue = op->getResult(0);

    // It's possible on a prior transposeOp we had the same dependency and
    // already resolved it.
    if (valuesMap.contains(priorValue))
      continue;

    // Keep converted ops close to the original.
    rewriter.setInsertionPointAfter(op);

    std::optional<Value> maybeValue =
        llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
            .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto rootOp) {
              return buildMappedToValue(rootOp, valuesMap, rewriter,
                                        hoistedPerms);
            })
            .Default([&](Operation *op) {
              return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
            });

    if (!maybeValue.has_value())
      return false;

    valuesMap[priorValue] = maybeValue.value();
  }

  return true;
}

bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
    Operation *user, std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  return llvm::none_of(
      transposeInfo,
      [&validTransposes,
       user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
        const auto &[transposeOp, dependentOps] = info;
        return validTransposes.count(transposeOp) &&
               dependentOps.contains(user);
      });
}

// Dependencies are valid for an operation if none of them occur outside
// of the proper fan-in cones of the hoisted TransposeOp with the same perms
// that we can replace. Described in more detail within.
bool TosaReduceTransposes::dependenciesAreValid(
    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
    std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (Operation *op : dependentOps) {

    // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
    // This can be changed later if we find the memory impact is too high.
    if (llvm::isa<ConstOp>(op))
      continue;

    for (OpOperand &use : op->getUses()) {
      // We want each use to be (1) contained in the dependentOps of other
      // validTransposes, or (2) directly used in a TransposeOp with the
      // same perms. For (2) it means the fan-in is a subset of our
      // dependentOps, so it is also a validTranspose that will eventually be
      // replaced.
      Operation *user = use.getOwner();
      if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
        // We can later think about cases where transpose -> transpose
        // or reshape -> transpose, where the transposes are not necessarily
        // the same perms as the hoisted one, if implementing a more general
        // transform. These could be permitted.
        if (!llvm::equal(perms, otherTranspose.getPerms()))
          return false;
      } else if (userNotContainedInValidTransposeDependencies(
                     user, validTransposes, transposeInfo)) {
        return false;
      }
    }
  }

  return true;
}
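
// For instance (an illustrative sketch, not from the source), in:
//   %0 = tosa.clamp %arg0
//   %1 = tosa.transpose %0 {perms = [1, 0]}
//   %2 = tosa.abs %0
// the clamp is a dependency of the hoisted transpose %1, but it also has a
// user (the abs) outside every candidate fan-in cone, so replacing %1 would
// keep the untransposed clamp live and duplicate work; the replacement is
// rejected.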

// Getting the set of TransposeOp that we can replace without causing
// the old fan-in cones of any TransposeOp to remain "live", i.e., not being
// dead code. This is done by iterating the set until convergence, since
// if you are used outside your own fan-in cone, it's possible to be used
// in another fan-in cone of a TransposeOp that is being replaced -- unless
// we find that that one has a usage outside of it too.
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
    ArrayRef<int32_t> perms,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  // Initially, we assume they are all good to replace,
  // and we whittle them down based on our criteria.
  std::set<TransposeOp> ableToReplace;
  for (const auto &[transposeOp, _] : transposeInfo)
    ableToReplace.insert(transposeOp);

  bool gotRid;
  do {
    gotRid = false;
    for (const auto &[transposeOp, dependentOps] : transposeInfo) {
      // We don't care about it. Already invalidated.
      if (!ableToReplace.count(transposeOp))
        continue;

      // Check for validity.
      if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
                                transposeInfo)) {
        ableToReplace.erase(transposeOp);
        gotRid = true;
        break;
      }
    }

  } while (gotRid);

  return ableToReplace;
}
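
// Invalidation can cascade: if transpose A's cone is used only inside
// transpose B's cone, A is initially fine; but if B is later invalidated
// by an outside use, A's use is no longer covered, so the loop must run
// again -- hence iterating to a fixed point rather than a single pass.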

void TosaReduceTransposes::runOnOperation() {
  // We want to operate only within a single block.
  if (!getOperation().getRegion().hasOneBlock())
    return;

  IRRewriter rewriter(&getContext());
  // For each perms, maintain a mapping for converted ops, to avoid
  // duplication.
  DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
  // For each perms, we keep track of which TransposeOp are eligible
  // for replacement alongside their dependentOps.
  DenseMap<ArrayRef<int32_t>,
           std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
      permsToTransposeInfo;

  // Necessary for lifetime, since the DenseMap keeps a copy of the ArrayRef.
  // Use SmallVector for perms (the common case is <= 4), but std::vector for
  // the outer container, since there is no guarantee of smallness there.
  std::vector<SmallVector<int32_t>> collectedPerms;

  // This keeps track of the order across all eligible-for-replacement
  // TransposeOp and their perms, a necessity for the final replacements.
  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;

  // We want to reserve the space up front, since SmallVector stores some data
  // internally and the ArrayRef can reference that, which we don't want to
  // get invalidated.
  size_t expectedMaxPerms = 0;
  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
  collectedPerms.reserve(expectedMaxPerms);

  getOperation().walk([&](TransposeOp transposeOp) {
    SetVector<Operation *> dependentOps;
    collectedPerms.emplace_back();
    SmallVector<int32_t> &perms = collectedPerms.back();

    // Dynamic shapes are OK, but the incompatible ones will be rejected
    // later.
    auto input = transposeOp.getInput1();
    auto output = transposeOp.getOutput();

    // However, we don't support unranked tensors.
    if (!llvm::isa<RankedTensorType>(input.getType()) ||
        !llvm::isa<RankedTensorType>(output.getType()))
      return;

    llvm::for_each(transposeOp.getPerms(),
                   [&perms](const auto i) { perms.emplace_back(i); });

    // We let --canonicalize deal with identity transpose.
    if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
      return;

    // Can fail if some set of basic invariants we need in order to perform
    // our conversions is not met.
    if (!collectFanIn(input.getDefiningOp(), dependentOps))
      return;

    // We want to share valuesMap across already-converted ops of the same
    // perms, since it's possible multiple hoisted transposes w/ different
    // perms converge on an op, which would result in different
    // transformations.
    DenseMap<Value, Value> &valuesMap = permsToValues[perms];

    // Attempt to perform the conversions and placements into IR
    // without turning inserted code "live". Also fills out valuesMap.
    // Fails if there is an intermediary we do not support.
    if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
      // Some additional operations may have been inserted, but they will be
      // removed by dead code elimination.
      return;

    // This should not happen. If it does -- it's unexpected,
    // so we fail the pass.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    // It's possible the types are not compatible (because of dynamic shapes),
    // and in these cases, we want to resolve dynamic shapes before running
    // the pass.
    if (output.getType() != valuesMap.at(input).getType())
      return;

    auto &transposeInfo = permsToTransposeInfo[perms];

    // In general, we might also want to introduce "newDependentOps"
    // if there are new usages that don't fall inside the original fan-ins
    // (like the TransposeOp we insert for ReshapeOp),
    // but in this case, that is specialized enough and overlaps
    // with another direct-use TransposeOp case we need to cover anyway.
    transposeInfo.push_back({transposeOp, dependentOps});

    // This is for the final replacement across all transposes.
    totalTransposeOrder.push({transposeOp, perms});
  });

  // We want to do a full fan-in analysis on a per-perms level,
  // since if we did it across multiple perms and the fan-ins shared ops (due
  // to a shared dependency on a Reshape), then we would also get duplicate
  // ops. Const is special-cased.
  std::set<TransposeOp> ableToReplace;
  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
    // Gives us back replacements that would never result in any duplicate
    // operations being inserted by us in the IR (i.e., our goal is only to
    // remove transposes, and not create a "new chain" to do so, but to
    // replace the existing chains).
    // Ideally, --canonicalize is run before this pass, since it helps this
    // analysis by removing dead code to allow more potentially acceptable
    // transformations.
    auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
    ableToReplace.insert(goodReplacementsForPerms.begin(),
                         goodReplacementsForPerms.end());
  }

  // We want to do replacement across all transposes
  // in reverse order, due to invalidation of valuesMap mappings
  // if we did it otherwise.
  while (!totalTransposeOrder.empty()) {
    auto [transposeOp, perms] = totalTransposeOrder.top();
    totalTransposeOrder.pop();

    if (ableToReplace.count(transposeOp) == 0)
      continue;

    auto &valuesMap = permsToValues[perms];
    auto input = transposeOp.getInput1();

    // The purpose of this reverse iteration
    // is to avoid valuesMap invalidation. If it happens,
    // something is wrong.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    rewriter.replaceOp(transposeOp, valuesMap.at(input));
  }

  // We can remove all dead code by going in reverse.
  // This is because we would remove usages before we
  // see the users.
  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](Operation *op) {
        if (isOpTriviallyDead(op))
          rewriter.eraseOp(op);
      });
}

} // namespace
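
// Usage note (not part of the original source): this pass is typically
// driven from mlir-opt, e.g. `mlir-opt --tosa-reduce-transposes input.mlir`
// (flag name assumed from the pass definition); running `--canonicalize`
// first tends to expose more replaceable transposes, as noted above.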